88 lines
2.7 KiB
Python
88 lines
2.7 KiB
Python
"""Private information retrieval helpers.

Two-server XOR PIR (Chor–Goldreich–Kushilevitz–Sudan style):

For a database of n fixed-length records D[0], ..., D[n-1], to retrieve D[i]:

- Pick a uniformly random subset S ⊆ {0, ..., n-1}.
- Server A returns ⊕_{j∈S} D[j].
- Server B returns ⊕_{j∈S⊕{i}} D[j], where S⊕{i} is the symmetric difference
  (i is added to S if absent, removed if present).
- The client XORs the two replies to obtain D[i]: every record other than D[i]
  appears in both folds and cancels out.

Privacy holds only if the two servers do **not** collude. A single host running
both servers sees both queries and can recover i, so use two independent
operators or hosts for real privacy.

Recipients fetch rows via **two-server XOR PIR** (``pir_fold`` on each replica).
There is no full-database download endpoint. Listing all decryptable grants for
a user therefore requires **O(n) PIR queries** (one Chor query per pool index)
when scanning the pool.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import secrets
|
||
from typing import Iterable
|
||
|
||
# Size in bytes of every PIR record. Grant entries are serialized to JSON and
# zero-padded to exactly this length so that records can be XOR-folded
# position by position. 65_536 bytes == 64 KiB.
PIR_RECORD_BYTES = 65_536
|
||
|
||
|
||
def pad_grant_record(entry: dict) -> bytes:
    """Encode *entry* as compact, key-sorted JSON zero-padded to PIR_RECORD_BYTES.

    Keys are sorted and no whitespace is emitted, so the encoding does not
    depend on dict insertion order.

    Raises:
        ValueError: if the UTF-8 encoded JSON is longer than PIR_RECORD_BYTES.
    """
    encoded = json.dumps(entry, sort_keys=True, separators=(",", ":")).encode("utf-8")
    slack = PIR_RECORD_BYTES - len(encoded)
    if slack < 0:
        raise ValueError(f"grant record exceeds PIR_RECORD_BYTES ({PIR_RECORD_BYTES})")
    # bytes(k) is k zero bytes: the fixed-width padding used for XOR folding.
    return encoded + bytes(slack)
|
||
|
||
|
||
def unpad_grant_record(buf: bytes) -> dict:
    """Reverse ``pad_grant_record``: strip trailing zero bytes and parse the JSON.

    ``json.dumps`` escapes NUL characters, so padded JSON never ends in a raw
    ``\\x00``. Stripping trailing zeros therefore only removes padding.
    """
    payload = buf.rstrip(b"\x00")
    text = payload.decode("utf-8")
    return json.loads(text)
|
||
|
||
|
||
def xor_bytes_many(chunks: Iterable[bytes]) -> bytes:
    """XOR-fold equal-length byte strings into a single byte string.

    Args:
        chunks: Byte strings (or other bytes-like objects). All of them must
            have the same length as the first one.

    Returns:
        The byte-wise XOR of every chunk, as long as the first chunk. An
        empty iterable yields ``PIR_RECORD_BYTES`` zero bytes, which is the
        XOR identity for a full PIR record.

    Raises:
        ValueError: if any chunk's length differs from the first chunk's.
    """
    it = iter(chunks)
    first = next(it, None)
    if first is None:
        return b"\x00" * PIR_RECORD_BYTES
    width = len(first)
    # Fold through arbitrary-precision ints. Each XOR runs in C over the whole
    # record, instead of one Python-level step per byte (records are
    # PIR_RECORD_BYTES long, so the old loop cost that many steps per chunk).
    acc = int.from_bytes(first, "big")
    for chunk in it:
        if len(chunk) != width:
            raise ValueError("length mismatch in xor_bytes_many")
        acc ^= int.from_bytes(chunk, "big")
    # An explicit length restores any leading zero bytes lost in the int form.
    return acc.to_bytes(width, "big")
|
||
|
||
|
||
def random_subset(n: int) -> set[int]:
    """Return a uniformly random subset of ``range(n)``.

    Each index is kept independently on a fair coin flip from ``secrets``, so
    all 2**n subsets are equally likely. A non-positive *n* yields the empty set.
    """
    if n <= 0:
        return set()
    chosen: set[int] = set()
    for idx in range(n):
        if secrets.randbelow(2):
            chosen.add(idx)
    return chosen
|
||
|
||
|
||
def symdiff_one(s: set[int], i: int) -> set[int]:
    """Return S ⊕ {i}: a new set equal to *s* with the membership of *i* toggled.

    The input set is left unmodified.
    """
    return set(s) ^ {i}
|
||
|
||
|
||
def pir_query_indices(n: int, i: int) -> tuple[list[int], list[int]]:
    """Return the sorted index lists that servers A and B fold for record *i*.

    Server A receives a uniformly random subset S of ``range(n)``. Server B
    receives S ⊕ {i}. Either list on its own is a uniformly random subset, so
    a single non-colluding server learns nothing about *i*.

    Raises:
        ValueError: if *i* is not in ``range(n)``.
    """
    if not 0 <= i < n:
        raise ValueError("index out of range")
    base = random_subset(n)
    return sorted(base), sorted(symdiff_one(base, i))
|
||
|
||
|
||
def pir_recover(xor_a: bytes, xor_b: bytes) -> bytes:
    """Combine the two server replies into the requested record.

    Server A folds S and server B folds S ⊕ {i}. Every record other than D[i]
    appears in both folds and cancels, so ``xor_a XOR xor_b == D[i]``.

    Raises:
        ValueError: if the two replies have different lengths.
    """
    if len(xor_a) != len(xor_b):
        raise ValueError("xor length mismatch")
    width = len(xor_a)
    # Whole-buffer XOR through ints runs in C rather than one Python step per
    # byte. The explicit length keeps leading zero bytes.
    combined = int.from_bytes(xor_a, "big") ^ int.from_bytes(xor_b, "big")
    return combined.to_bytes(width, "big")
|