ZKAC/cli/zkac_cli/pir.py
everbarry 6e67836e95 v0.4
2026-04-18 01:06:12 +02:00

88 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Private information retrieval helpers.
Two-server XOR PIR (Chor-Goldreich-Kushilevitz-Sudan style):
For a database of n fixed-length records D[0],...,D[n-1], to retrieve D[i]:
- Pick a uniformly random subset S ⊆ {0,...,n-1}.
- Server A returns ⊕_{j∈S} D[j]
- Server B returns ⊕_{j∈S⊕{i}} D[j]
- Client XORs the two replies → D[i].
Privacy holds if the two servers do **not** collude. A single host running both
servers learns both queries and can recover i — use two independent operators
or hosts for real privacy.
Recipients fetch rows via **two-server XOR PIR** (``pir_fold`` on each replica).
There is no full-database download endpoint; listing all decryptable grants for
a user requires **O(n) PIR queries** (one Chor query per pool index) when
scanning the pool.
"""
from __future__ import annotations
import json
import secrets
from typing import Iterable
# Fixed record size for XOR PIR (padded JSON grant entry).
# Every record is padded to exactly this many bytes (64 KiB) so that records
# have equal length and can be XOR-folded together byte-for-byte.
PIR_RECORD_BYTES = 64 * 1024
def pad_grant_record(entry: dict) -> bytes:
    """Encode a grant entry as a fixed-size, NUL-padded JSON record.

    The entry is dumped as compact JSON with sorted keys, encoded as UTF-8,
    and right-padded with NUL bytes up to ``PIR_RECORD_BYTES`` so that every
    record has the same length for XOR folding.

    Raises:
        ValueError: if the encoded entry is longer than ``PIR_RECORD_BYTES``.
    """
    encoded = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode("utf-8")
    padding = PIR_RECORD_BYTES - len(encoded)
    if padding < 0:
        raise ValueError(f"grant record exceeds PIR_RECORD_BYTES ({PIR_RECORD_BYTES})")
    # bytes(k) yields k zero bytes, i.e. the NUL padding.
    return encoded + bytes(padding)
def unpad_grant_record(buf: bytes) -> dict:
    """Decode a padded record back into its grant entry.

    Trailing NUL padding bytes are stripped, the remainder is decoded as
    UTF-8, and the resulting text is parsed as JSON.
    """
    payload = buf.rstrip(b"\x00")
    text = payload.decode("utf-8")
    return json.loads(text)
def xor_bytes_many(chunks: Iterable[bytes]) -> bytes:
    """XOR-fold equal-length byte strings into a single byte string.

    Args:
        chunks: Byte strings to fold; all must have the same length.

    Returns:
        The byte-wise XOR of all chunks, with the length of the first chunk.
        For an empty iterable, ``PIR_RECORD_BYTES`` zero bytes (the XOR
        identity for a record).

    Raises:
        ValueError: if any chunk's length differs from the first chunk's.
    """
    it = iter(chunks)
    first = next(it, None)
    if first is None:
        return b"\x00" * PIR_RECORD_BYTES
    size = len(first)
    # Fold each record as one big integer: a single C-level XOR per record
    # replaces a Python-level loop over every byte of every record.
    acc = int.from_bytes(first, "big")
    for chunk in it:
        if len(chunk) != size:
            raise ValueError("length mismatch in xor_bytes_many")
        acc ^= int.from_bytes(chunk, "big")
    # to_bytes with the original size restores any leading zero bytes.
    return acc.to_bytes(size, "big")
def random_subset(n: int) -> set[int]:
    """Draw a uniformly random subset of {0,...,n-1}.

    Each index is included or excluded by its own ``secrets.randbelow(2)``
    draw. A non-positive ``n`` yields the empty set.
    """
    chosen: set[int] = set()
    if n <= 0:
        return chosen
    for idx in range(n):
        # randbelow(2) is 0 or 1: an unbiased coin flip per index.
        if secrets.randbelow(2):
            chosen.add(idx)
    return chosen
def symdiff_one(s: set[int], i: int) -> set[int]:
    """Return a new set equal to S ⊕ {i}.

    If ``i`` is in ``s`` it is dropped, otherwise it is added. The input is
    copied first, so ``s`` itself is never modified.
    """
    toggled = set(s)
    toggled.symmetric_difference_update((i,))
    return toggled
def pir_query_indices(n: int, i: int) -> tuple[list[int], list[int]]:
    """Produce the index lists sent to servers A and B to retrieve record i.

    Server A's list is a random subset of {0,...,n-1} (from
    ``random_subset``); server B's list is that subset with index ``i``
    toggled (via ``symdiff_one``). Both lists are returned sorted ascending.

    Raises:
        ValueError: if ``i`` is not in ``[0, n)``.
    """
    in_range = 0 <= i < n
    if not in_range:
        raise ValueError("index out of range")
    subset_a = random_subset(n)
    subset_b = symdiff_one(subset_a, i)
    return sorted(subset_a), sorted(subset_b)
def pir_recover(xor_a: bytes, xor_b: bytes) -> bytes:
    """Combine the two server replies into the requested record.

    Returns the byte-wise XOR of the two replies.

    Raises:
        ValueError: if the replies differ in length.
    """
    size = len(xor_a)
    if size != len(xor_b):
        raise ValueError("xor length mismatch")
    # XOR both replies as big integers, then restore the original width so
    # leading zero bytes are kept.
    combined = int.from_bytes(xor_a, "big") ^ int.from_bytes(xor_b, "big")
    return combined.to_bytes(size, "big")