"""Private information retrieval helpers.

Two-server XOR PIR (Chor–Goldreich–Kushilevitz–Sudan style):

For a database of n fixed-length records D[0], ..., D[n-1], to retrieve D[i]:

- Pick a uniformly random subset S ⊆ {0, ..., n-1}.
- Server A returns ⊕_{j∈S} D[j].
- Server B returns ⊕_{j∈S⊕{i}} D[j], where S⊕{i} is the symmetric
  difference (S with the membership of i flipped).
- The client XORs the two replies; every record except D[i] appears in both
  folds and cancels, leaving D[i].

Each server on its own sees a uniformly random subset, so it learns nothing
about i. Privacy therefore holds only if the two servers do **not** collude. A
single host running both servers learns both queries and can recover i — use
two independent operators or hosts for real privacy.

Recipients fetch rows via **two-server XOR PIR** (``pir_fold`` on each
replica). There is no full-database download endpoint; listing all decryptable
grants for a user requires **O(n) PIR queries** (one Chor query per pool
index) when scanning the pool.
"""

from __future__ import annotations

import json
import secrets
from typing import Iterable

# Fixed record size for XOR PIR (padded JSON grant entry). Every record must be
# exactly this long so that XOR-folding records together is well defined.
PIR_RECORD_BYTES = 64 * 1024


def pad_grant_record(entry: dict) -> bytes:
    """Serialize a grant entry to exactly ``PIR_RECORD_BYTES`` bytes.

    The entry is encoded as compact, key-sorted JSON in UTF-8 and right-padded
    with NUL bytes.

    Raises:
        ValueError: if the encoded entry is longer than ``PIR_RECORD_BYTES``.
    """
    raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode("utf-8")
    if len(raw) > PIR_RECORD_BYTES:
        raise ValueError(f"grant record exceeds PIR_RECORD_BYTES ({PIR_RECORD_BYTES})")
    return raw + b"\x00" * (PIR_RECORD_BYTES - len(raw))


def unpad_grant_record(buf: bytes) -> dict:
    """Invert :func:`pad_grant_record`: strip NUL padding and parse the JSON.

    ``json.dumps`` escapes control characters, so the encoded JSON never
    contains a raw NUL byte. Stripping trailing NULs therefore removes exactly
    the padding.

    Raises:
        ValueError: (``UnicodeDecodeError`` / ``json.JSONDecodeError``) if the
            buffer does not hold padded UTF-8 JSON, e.g. an all-zero buffer.
    """
    return json.loads(buf.rstrip(b"\x00").decode("utf-8"))


def xor_bytes_many(chunks: Iterable[bytes]) -> bytes:
    """XOR-fold equal-length byte strings into one byte string.

    An empty iterable folds to ``PIR_RECORD_BYTES`` zero bytes, the XOR
    identity for a record, so a server asked for an empty subset still returns
    a full-size reply.

    Raises:
        ValueError: if any chunk's length differs from the first chunk's.
    """
    it = iter(chunks)
    first = next(it, None)
    if first is None:
        return b"\x00" * PIR_RECORD_BYTES
    width = len(first)
    # Fold as one big integer: XOR of arbitrary-precision ints runs in C,
    # instead of an interpreted per-byte loop over every 64 KiB record.
    acc = int.from_bytes(first, "big")
    for c in it:
        if len(c) != width:
            raise ValueError("length mismatch in xor_bytes_many")
        acc ^= int.from_bytes(c, "big")
    # XOR never sets bits beyond `width` bytes, and to_bytes keeps leading
    # zero bytes, so the result has exactly `width` bytes.
    return acc.to_bytes(width, "big")


def random_subset(n: int) -> set[int]:
    """Return a uniform random subset of {0, ..., n-1}; empty when n <= 0.

    Each index is included independently with probability 1/2, using the
    ``secrets`` CSPRNG, since query unpredictability is what hides i.
    """
    if n <= 0:
        return set()
    return {j for j in range(n) if secrets.randbelow(2)}


def symdiff_one(s: set[int], i: int) -> set[int]:
    """Return S ⊕ {i}: a new set equal to *s* with i's membership flipped.

    *s* itself is not modified.
    """
    return set(s) ^ {i}


def pir_query_indices(n: int, i: int) -> tuple[list[int], list[int]]:
    """Build the two index lists for servers A and B to XOR-fold records.

    Returns ``(sa, sb)``, where ``sa`` is a uniformly random subset S and
    ``sb`` is S ⊕ {i}. Both are sorted so that list order carries no
    information beyond set membership.

    Raises:
        ValueError: if i is not in ``range(n)``.
    """
    if not (0 <= i < n):
        raise ValueError("index out of range")
    s = random_subset(n)
    sa = sorted(s)
    sb = sorted(symdiff_one(s, i))
    return sa, sb


def pir_recover(xor_a: bytes, xor_b: bytes) -> bytes:
    """Combine the two server replies into the requested record.

    Raises:
        ValueError: if the replies have different lengths.
    """
    if len(xor_a) != len(xor_b):
        raise ValueError("xor length mismatch")
    width = len(xor_a)
    # Same big-integer XOR as xor_bytes_many; to_bytes preserves leading zeros.
    combined = int.from_bytes(xor_a, "big") ^ int.from_bytes(xor_b, "big")
    return combined.to_bytes(width, "big")