ZKAC/tests/test_pir.py
2026-04-19 23:19:24 +02:00

83 lines
2.9 KiB
Python

"""Tests for LWE-based SimplePIR via the Python bindings."""
import json
import os
import zkac
def _pad(entry: dict, size=None) -> bytes:
    """Serialize *entry* as compact, key-sorted JSON, NUL-padded to a fixed width.

    Sorting keys and using compact separators makes the encoding canonical:
    equal dicts always produce identical bytes.

    Args:
        entry: JSON-serializable mapping to encode.
        size: Target record width in bytes. Defaults to
            ``zkac.PIR_RECORD_BYTES`` when omitted.

    Returns:
        The encoded JSON followed by NUL bytes, exactly *size* bytes long.

    Raises:
        ValueError: If the encoded JSON is longer than *size*.
    """
    if size is None:
        size = zkac.PIR_RECORD_BYTES
    raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode()
    # Fail loudly: multiplying bytes by a negative count yields b"", which
    # would silently hand back an over-long record and break the fixed-width
    # database layout.
    if len(raw) > size:
        raise ValueError(
            f"encoded record is {len(raw)} bytes; exceeds record size {size}"
        )
    return raw.ljust(size, b"\x00")
def test_roundtrip_small():
    """Every index of an 8-record database must decode back to its exact record.

    The client is constructed only from the server's serialized hints; one
    query is issued per index and the decoded bytes are compared with the
    padded input record.
    """
    count = 8
    width = zkac.PIR_RECORD_BYTES
    expected = [
        _pad({"id": idx, "data": f"record-{idx}"})
        for idx in range(count)
    ]
    srv = zkac.PirServer(zkac.PirDatabase(expected, width))
    cli = zkac.PirClient(bytes(srv.hints()), count, width)
    for idx, want in enumerate(expected):
        query, st = cli.query(idx)
        got = bytes(cli.decode(srv.answer(query), st))
        assert got == want, f"mismatch at index {idx}"
def test_roundtrip_medium():
    """64 records padded to ``PIR_RECORD_BYTES`` (must fit LWE-encoded plaintext).

    Each payload is 32 random bytes, hex-encoded, so the database contents
    differ on every run. A spread of indices is queried: first, second, the
    middle pair, and last. Every decode must match its original record exactly.
    """
    num_records = 64
    records = [
        _pad({"id": i, "payload": os.urandom(32).hex()})
        for i in range(num_records)
    ]
    db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)
    server = zkac.PirServer(db)
    hints = bytes(server.hints())
    client = zkac.PirClient(hints, num_records, zkac.PIR_RECORD_BYTES)
    for target in [0, 1, 31, 32, num_records - 1]:
        q, state = client.query(target)
        answer = server.answer(q)
        decoded = bytes(client.decode(answer, state))
        # Name the failing index so a mismatch is diagnosable from the report.
        assert decoded == records[target], f"mismatch at index {target}"
def test_version_changes_on_rebuild():
    """Two builds from identical records must report different versions.

    NOTE(review): no seed is passed here, so this relies on each
    PirDatabase/PirServer build drawing fresh randomness internally.
    Confirm this against the bindings.
    """
    rows = [_pad({"i": n}) for n in range(4)]
    first, second = (
        zkac.PirServer(zkac.PirDatabase(rows, zkac.PIR_RECORD_BYTES))
        for _ in range(2)
    )
    assert bytes(first.version()) != bytes(second.version())
def test_hints_serialize_roundtrip():
    """A client built from the server's serialized hints agrees with that server.

    The test checks that the client reports the same version as the server.
    It also checks that every record decodes correctly through a client built
    solely from the hint bytes. Without the second check, the hints themselves
    would never be exercised.
    """
    records = [_pad({"i": i}) for i in range(4)]
    server = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    hints = bytes(server.hints())
    client = zkac.PirClient(hints, len(records), zkac.PIR_RECORD_BYTES)
    assert bytes(client.version()) == bytes(server.version())
    for target, expected in enumerate(records):
        q, state = client.query(target)
        decoded = bytes(client.decode(server.answer(q), state))
        assert decoded == expected, f"mismatch at index {target}"
def test_detection_tag_symmetric():
    """grant_detection_tag(a_sk, b_pk) == grant_detection_tag(b_sk, a_pk).

    Each party combines its own secret with the other party's public key, and
    both must arrive at the same tag. The tag must also be 16 bytes long.
    """
    alice, bob = zkac.IssuanceKeypair(), zkac.IssuanceKeypair()

    def tag(own, peer):
        return bytes(
            zkac.grant_detection_tag(own.secret_bytes(), peer.public_key_bytes())
        )

    from_alice = tag(alice, bob)
    from_bob = tag(bob, alice)
    assert from_alice == from_bob
    assert len(from_alice) == 16
def test_detection_tag_distinct():
    """One secret paired with two different public keys must yield different tags."""
    owner = zkac.IssuanceKeypair()
    peer_one = zkac.IssuanceKeypair()
    peer_two = zkac.IssuanceKeypair()
    tags = [
        bytes(zkac.grant_detection_tag(owner.secret_bytes(), peer.public_key_bytes()))
        for peer in (peer_one, peer_two)
    ]
    assert tags[0] != tags[1]