83 lines
2.9 KiB
Python
83 lines
2.9 KiB
Python
"""Tests for LWE-based SimplePIR via the Python bindings."""
|
|
|
|
import json
|
|
import os
|
|
|
|
import zkac
|
|
|
|
|
|
def _pad(entry: dict) -> bytes:
|
|
raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode()
|
|
return raw + b"\x00" * (zkac.PIR_RECORD_BYTES - len(raw))
|
|
|
|
|
|
def test_roundtrip_small():
    """Query every index of an 8-record database and check each decode.

    Each record must come back byte-for-byte identical to what was stored.
    """
    count = 8
    records = [_pad({"id": idx, "data": f"record-{idx}"}) for idx in range(count)]
    server = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    client = zkac.PirClient(bytes(server.hints()), count, zkac.PIR_RECORD_BYTES)

    for idx, expected in enumerate(records):
        query, state = client.query(idx)
        reply = server.answer(query)
        got = bytes(client.decode(reply, state))
        assert got == expected, f"mismatch at index {idx}"
|
|
|
|
|
|
def test_roundtrip_medium():
    """64 records padded to ``PIR_RECORD_BYTES`` (must fit LWE-encoded plaintext).

    Payloads come from ``os.urandom`` and differ on every run, so each
    assertion reports the index that failed to decode.
    """
    count = 64
    records = [
        _pad({"id": i, "payload": os.urandom(32).hex()}) for i in range(count)
    ]
    db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)
    server = zkac.PirServer(db)
    hints = bytes(server.hints())
    client = zkac.PirClient(hints, count, zkac.PIR_RECORD_BYTES)

    # First, second, both sides of the midpoint, and the last index.
    for target in (0, 1, count // 2 - 1, count // 2, count - 1):
        q, state = client.query(target)
        answer = server.answer(q)
        decoded = bytes(client.decode(answer, state))
        assert decoded == records[target], f"mismatch at index {target}"
|
|
|
|
|
|
def test_version_changes_on_rebuild():
    """Two servers built from identical records must report different versions.

    No seed is passed here, so this relies on each PirDatabase/PirServer
    construction drawing its own fresh randomness (presumably a new seed —
    confirm against the bindings).
    """
    records = [_pad({"i": i}) for i in range(4)]
    first, second = (
        zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
        for _ in range(2)
    )
    assert bytes(first.version()) != bytes(second.version())
|
|
|
|
|
|
def test_hints_serialize_roundtrip():
    """A client built from the server's serialized hints agrees on the version."""
    n_records = 4
    records = [_pad({"i": i}) for i in range(n_records)]
    server = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES))
    serialized = bytes(server.hints())
    client = zkac.PirClient(serialized, n_records, zkac.PIR_RECORD_BYTES)
    client_version = bytes(client.version())
    server_version = bytes(server.version())
    assert client_version == server_version
|
|
|
|
|
|
def test_detection_tag_symmetric():
    """grant_detection_tag(a_sk, b_pk) == grant_detection_tag(b_sk, a_pk).

    Both parties of a pair must derive the same 16-byte tag.
    """
    alice = zkac.IssuanceKeypair()
    bob = zkac.IssuanceKeypair()

    def tag(own, peer):
        # Own secret key combined with the other party's public key.
        return bytes(zkac.grant_detection_tag(own.secret_bytes(), peer.public_key_bytes()))

    from_alice = tag(alice, bob)
    from_bob = tag(bob, alice)
    assert from_alice == from_bob
    assert len(from_alice) == 16
|
|
|
|
|
|
def test_detection_tag_distinct():
    """One secret key paired with two different peers yields two different tags."""
    owner = zkac.IssuanceKeypair()
    peer_one = zkac.IssuanceKeypair()
    peer_two = zkac.IssuanceKeypair()
    tags = [
        bytes(zkac.grant_detection_tag(owner.secret_bytes(), peer.public_key_bytes()))
        for peer in (peer_one, peer_two)
    ]
    assert tags[0] != tags[1]
|