"""Tests for LWE-based SimplePIR via the Python bindings."""

import json
import os

import zkac


def _pad(entry: dict) -> bytes:
    """Serialize *entry* as compact, key-sorted JSON and NUL-pad it to one record.

    The returned value is exactly ``zkac.PIR_RECORD_BYTES`` bytes long.

    Raises:
        ValueError: if the serialized entry is longer than
            ``zkac.PIR_RECORD_BYTES``. Without this check, repeating the
            padding byte a negative number of times yields an empty byte
            string, and an oversized record would slip through silently.
    """
    raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode()
    padding = zkac.PIR_RECORD_BYTES - len(raw)
    if padding < 0:
        raise ValueError(
            f"entry serializes to {len(raw)} bytes, exceeding "
            f"PIR_RECORD_BYTES={zkac.PIR_RECORD_BYTES}"
        )
    return raw + b"\x00" * padding


def _query_and_decode(server, client, target: int) -> bytes:
    """Run one query/answer/decode exchange for index *target* and return the record bytes."""
    query, state = client.query(target)
    answer = server.answer(query)
    return bytes(client.decode(answer, state))


def test_roundtrip_small():
    """Build a small DB, query every index, verify each decode is correct."""
    records = [_pad({"id": i, "data": f"record-{i}"}) for i in range(8)]
    db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)
    server = zkac.PirServer(db)
    hints = bytes(server.hints())
    client = zkac.PirClient(hints, len(records), zkac.PIR_RECORD_BYTES)
    for target, expected in enumerate(records):
        decoded = _query_and_decode(server, client, target)
        assert decoded == expected, f"mismatch at index {target}"


def test_roundtrip_medium():
    """64 records padded to ``PIR_RECORD_BYTES`` (must fit LWE-encoded plaintext).

    Only a sample of indices is queried: the first two, the two straddling
    the midpoint, and the last.
    """
    records = [
        _pad({"id": i, "payload": os.urandom(32).hex()}) for i in range(64)
    ]
    db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)
    server = zkac.PirServer(db)
    hints = bytes(server.hints())
    n = len(records)
    client = zkac.PirClient(hints, n, zkac.PIR_RECORD_BYTES)
    for target in (0, 1, n // 2 - 1, n // 2, n - 1):
        decoded = _query_and_decode(server, client, target)
        assert decoded == records[target], f"mismatch at index {target}"


def test_version_changes_on_rebuild():
    """Rebuilding from the same data gives a different version.

    NOTE(review): this relies on each ``PirDatabase`` build drawing fresh
    randomness (presumably a new seed); confirm against the bindings.
    """
    records = [_pad({"i": i}) for i in range(4)]
    s1 = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    s2 = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    assert bytes(s1.version()) != bytes(s2.version())


def test_hints_serialize_roundtrip():
    """A client built from serialized hint bytes agrees with its server.

    Checks that the client's version matches the server's and that the
    reconstructed client decodes every record correctly.
    """
    records = [_pad({"i": i}) for i in range(4)]
    server = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    hints1 = bytes(server.hints())
    client = zkac.PirClient(hints1, len(records), zkac.PIR_RECORD_BYTES)
    assert bytes(client.version()) == bytes(server.version())
    for target, expected in enumerate(records):
        decoded = _query_and_decode(server, client, target)
        assert decoded == expected, f"mismatch at index {target}"


def test_detection_tag_symmetric():
    """grant_detection_tag(a_sk, b_pk) == grant_detection_tag(b_sk, a_pk)."""
    kp_a = zkac.IssuanceKeypair()
    kp_b = zkac.IssuanceKeypair()
    tag_ab = bytes(
        zkac.grant_detection_tag(kp_a.secret_bytes(), kp_b.public_key_bytes())
    )
    tag_ba = bytes(
        zkac.grant_detection_tag(kp_b.secret_bytes(), kp_a.public_key_bytes())
    )
    assert tag_ab == tag_ba
    assert len(tag_ab) == 16


def test_detection_tag_distinct():
    """Tags for the same secret key paired with different peers differ."""
    kp_a = zkac.IssuanceKeypair()
    kp_b = zkac.IssuanceKeypair()
    kp_c = zkac.IssuanceKeypair()
    t1 = bytes(
        zkac.grant_detection_tag(kp_a.secret_bytes(), kp_b.public_key_bytes())
    )
    t2 = bytes(
        zkac.grant_detection_tag(kp_a.secret_bytes(), kp_c.public_key_bytes())
    )
    assert t1 != t2