diff --git a/Cargo.lock b/Cargo.lock index f049445..3f46f79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -748,7 +748,7 @@ dependencies = [ [[package]] name = "zkac" -version = "0.4.0" +version = "0.4.1" dependencies = [ "blake2", "chacha20poly1305", diff --git a/Cargo.toml b/Cargo.toml index 06ea262..3bb0110 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zkac" -version = "0.4.0" +version = "0.4.1" edition = "2021" description = "Zero-Knowledge Access Control: BBS+ anonymous credentials (BLS12-381) with encrypted transport (X25519/ChaCha20-Poly1305)" diff --git a/cli/README.md b/cli/README.md index a287982..a800f60 100644 --- a/cli/README.md +++ b/cli/README.md @@ -26,22 +26,18 @@ zkac-node server pin bob localhost:9800 --key zkac-node registry create alice localhost:9800 --roles analyst,operator zkac-node grant alice --server localhost:9800 \ --registry --role analyst --to $BOB_PK_HEX -# (prints pool_index for Bob’s collect) +# (prints pool_index for Bob's collect) -# 4. Two-server XOR PIR needs a second replica with the same server_key + grants pool. -# Example: rsync ~/.zkac/alice/server/ to a temp dir after the grant, then: -# zkac-node serve alice --port 9801 --data-dir /tmp/zkac-replica & -# zkac-node server pin bob localhost:9801 --key - -# 5. Bob lists local creds; optional pending scan (O(n) PIR queries per server) +# 4. Bob lists local creds + pending grants (single-server PIR, no second replica needed) zkac-node credentials list bob -zkac-node credentials list bob --server localhost:9800 --pir-peer localhost:9801 +zkac-node credentials list bob --server localhost:9800 -# 6. Bob collects (primary host in spec, second replica as --pir-peer) -zkac-node collect bob localhost:9800::analyst \ - --pir-peer localhost:9801 --pool-index +# 5. Bob collects (auto-discovers via detection tags; --pool-index is optional) +zkac-node collect bob localhost:9800::analyst +# or with explicit index: +zkac-node collect bob localhost:9800::analyst --pool-index -# 7. Bob authenticates +# 6. Bob authenticates zkac-node auth bob --registry --role analyst --server localhost:9800 ``` @@ -59,13 +55,15 @@ zkac-node auth bob --registry --role analyst --server localhost:98 | `registry get --registry R` | Fetch registry state | | `registry list ` | List registries this user owns locally | | `grant --server S --registry R --role X --to ` | Admin grant (encrypted to recipient pk) | -| `credentials list [--server S …] [--pir-peer P]` | Local credentials; pending grants only with `--pir-peer` (PIR scan) | -| `collect --pir-peer P --pool-index N` | Fetch one grant via two-server XOR PIR | +| `credentials list [--server S …]` | Local credentials + pending grants via detection tags + PIR | +| `collect [--pool-index N]` | Fetch and finalize a pending credential via single-server PIR | | `auth --registry R --role X [--server S]` | Authenticated session | ## Protocol & threat model -See [docs/SECURITY.md](../docs/SECURITY.md) in the repo root. +See [docs/SECURITY.md](../docs/SECURITY.md) in the repo root for the full model, including PIR and detection tags. + +**Operational scaling:** The server grant pool is append-only (claimed rows are tombstones), so pool length grows with every grant. Large pools increase discovery traffic, PIR query size, and server work per retrieval (all linear in pool length). Treat unbounded growth as a potential DoS and capacity risk; mitigations are listed under *Known limitations* and *Future work* in `docs/SECURITY.md`. Transport, BBS+ auth, registry state updates, and issuance queues have separate scaling profiles (CPU dominated by BBS+, state size linear in role count, queue memory); see **Scaling and complexity (transport, credentials, registries)** in the same doc. ## Storage layout @@ -76,5 +74,7 @@ identity.json issuance keypair admin/.json BBS+ admin material for owned registries credentials/_.json received credentials servers/.json pinned server public keys +pir_cache/.json PIR hint metadata (pool_version, n_records) +pir_cache/.bin PIR hint data (cached, keyed by pool_version) server/ (only if you run `serve `) server_key.json, registries/, mailbox/grants_pool.json ``` diff --git a/cli/zkac_cli/client.py b/cli/zkac_cli/client.py index 1100b84..585d69e 100644 --- a/cli/zkac_cli/client.py +++ b/cli/zkac_cli/client.py @@ -5,11 +5,12 @@ from __future__ import annotations import base64 import json import socket +from pathlib import Path import zkac from zkac.tcp import FramedSession, client_handshake_anon -from . import pir, store +from . import store def _b64(data: bytes) -> str: @@ -72,6 +73,71 @@ def _ok(resp: dict) -> dict: return resp +# ── PIR hint cache ─────────────────────────────────────────────────── + +def _cache_dir(userid: str) -> Path: + d = store.user_dir(userid) / "pir_cache" + d.mkdir(parents=True, exist_ok=True) + return d + + +def _server_cache_key(server: str) -> str: + return server.replace(":", "_") + + +def _load_cached_hints(userid: str, server: str, pool_version: str) -> bytes | None: + meta_path = _cache_dir(userid) / f"{_server_cache_key(server)}.json" + bin_path = _cache_dir(userid) / f"{_server_cache_key(server)}.bin" + if not meta_path.exists() or not bin_path.exists(): + return None + meta = json.loads(meta_path.read_text()) + if meta.get("pool_version") != pool_version: + return None + return bin_path.read_bytes() + + +def _save_cached_hints(userid: str, server: str, pool_version: str, + n_records: int, record_bytes: int, hints_bytes: bytes): + key = _server_cache_key(server) + meta = {"pool_version": pool_version, "n_records": n_records, "record_bytes": record_bytes} + (_cache_dir(userid) / f"{key}.json").write_text(json.dumps(meta)) + (_cache_dir(userid) / f"{key}.bin").write_bytes(hints_bytes) + + +def _pir_client(userid: str, framed: FramedSession, server: str) -> tuple[zkac.PirClient, str]: + """Fetch pool_info, load or refresh hints, return (PirClient, pool_version).""" + info = _ok(_mgmt_cmd(framed, {"cmd": "pool_info"})) + n = info["n"] + rb = info["record_bytes"] + pv = info["pool_version"] + + cached = _load_cached_hints(userid, server, pv) + if cached is not None: + return zkac.PirClient(cached, n, rb), pv + + resp = _ok(_mgmt_cmd(framed, {"cmd": "pir_hints"})) + hints_bytes = _unb64(resp["hints_b64"]) + pv = resp["pool_version"] + _save_cached_hints(userid, server, pv, n, rb, hints_bytes) + return zkac.PirClient(hints_bytes, n, rb), pv + + +def _fetch_row( + userid: str, framed: FramedSession, server: str, + pir_client: zkac.PirClient, pool_version: str, pool_index: int, +) -> dict: + q, state = pir_client.query(pool_index) + resp = _ok(_mgmt_cmd(framed, { + "cmd": "pir_query", + "query_b64": _b64(q), + "pool_version": pool_version, + })) + raw = bytes(pir_client.decode(_unb64(resp["answer_b64"]), state)) + return json.loads(raw.rstrip(b"\x00").decode("utf-8")) + + +# ── Public operations ──────────────────────────────────────────────── + def create_registry(userid: str, server: str, role_names: list[str]) -> str: identity = store.load_identity(userid) admin_mat = store.new_admin_material() @@ -178,6 +244,7 @@ def grant(userid: str, server: str, registry_id_hex: str, role_name: str, recipient_pk = bytes.fromhex(recipient_pk_hex) eph_kp = zkac.IssuanceKeypair() ciphertext = eph_kp.encrypt(recipient_pk, payload) + to_tag = zkac.grant_detection_tag(eph_kp.secret_bytes(), recipient_pk) sock, framed = _mgmt_connect(userid, server) try: @@ -188,6 +255,7 @@ def grant(userid: str, server: str, registry_id_hex: str, role_name: str, "registry_id": registry_id_hex, "eph_pk_b64": _b64(eph_kp.public_key_bytes()), "ciphertext_b64": _b64(ciphertext), + "to_tag_b64": _b64(to_tag), "admin_proof_b64": _b64(admin_proof), })) finally: @@ -196,59 +264,47 @@ def grant(userid: str, server: str, registry_id_hex: str, role_name: str, return resp["grant_id"], resp.get("pool_index", -1) -def _pir_recover_row( - framed_a: FramedSession, - framed_b: FramedSession, - n: int, - pool_index: int, -) -> dict: - idx_a, idx_b = pir.pir_query_indices(n, pool_index) - xa = _ok(_mgmt_cmd(framed_a, {"cmd": "pir_fold", "indices": idx_a})) - xb = _ok(_mgmt_cmd(framed_b, {"cmd": "pir_fold", "indices": idx_b})) - raw = pir.pir_recover(_unb64(xa["xor_b64"]), _unb64(xb["xor_b64"])) - return pir.unpad_grant_record(raw) +def _match_tags(userid: str, tags: list[dict]) -> list[int]: + """Return pool indices whose detection tag matches our issuance key.""" + identity = store.load_identity(userid) + receiver_sk = identity["issuance_sk"] + matches = [] + for idx, entry in enumerate(tags): + eph_pk_b64 = entry.get("eph_pk_b64", "") + to_tag_b64 = entry.get("to_tag_b64", "") + if not eph_pk_b64 or not to_tag_b64: + continue + eph_pk = _unb64(eph_pk_b64) + expected = zkac.grant_detection_tag(receiver_sk, eph_pk) + if _unb64(to_tag_b64) == bytes(expected): + matches.append(idx) + return matches -def _fetch_grant_entry_pir( - userid: str, server_a: str, server_b: str, pool_index: int, -) -> dict: - sock_a, fa = _mgmt_connect(userid, server_a) - sock_b, fb = _mgmt_connect(userid, server_b) - try: - a = _ok(_mgmt_cmd(fa, {"cmd": "mail_pool_len"})) - b = _ok(_mgmt_cmd(fb, {"cmd": "mail_pool_len"})) - if a["n"] != b["n"] or a["record_bytes"] != b["record_bytes"]: - raise RuntimeError("PIR peers disagree on pool length or record size") - n = a["n"] - if not (0 <= pool_index < n): - raise RuntimeError("pool_index out of range for current pool") - return _pir_recover_row(fa, fb, n, pool_index) - finally: - sock_a.close() - sock_b.close() - - -def list_pending(userid: str, server: str, pir_peer: str) -> list[dict]: - """Scan the grant pool via two-server XOR PIR (one Chor query per row).""" +def list_pending(userid: str, server: str) -> list[dict]: + """Discover pending grants via detection tags, then PIR-fetch matches.""" identity = store.load_identity(userid) receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"]) info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) - info_b = _mgmt_single(userid, pir_peer, {"cmd": "server_info"}) - store.pin_server(userid, pir_peer, info_b["server_public_key_b64"]) - sock_a, fa = _mgmt_connect(userid, server) - sock_b, fb = _mgmt_connect(userid, pir_peer) + sock, framed = _mgmt_connect(userid, server) try: - a = _ok(_mgmt_cmd(fa, {"cmd": "mail_pool_len"})) - b = _ok(_mgmt_cmd(fb, {"cmd": "mail_pool_len"})) - if a["n"] != b["n"] or a["record_bytes"] != b["record_bytes"]: - raise RuntimeError("PIR peers disagree on pool length or record size") - n = a["n"] + tags_resp = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"})) + tags = tags_resp["tags"] + matches = _match_tags(userid, tags) + + if not matches: + return [] + + pir_cl, pv = _pir_client(userid, framed, server) results = [] - for i in range(n): - row = _pir_recover_row(fa, fb, n, i) + for idx in matches: + try: + row = _fetch_row(userid, framed, server, pir_cl, pv, idx) + except Exception: + continue if row.get("claimed"): continue try: @@ -257,29 +313,27 @@ def list_pending(userid: str, server: str, pir_peer: str) -> list[dict]: plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) results.append({ "grant_id": row["grant_id"], - "pool_index": i, + "pool_index": idx, "registry_id": plaintext.get("registry_id", "?"), "role_name": plaintext.get("role_name", "?"), }) except Exception: results.append({ "grant_id": row.get("grant_id", "?"), - "pool_index": i, + "pool_index": idx, "registry_id": "?", "role_name": "(undecryptable)", }) return results finally: - sock_a.close() - sock_b.close() + sock.close() def collect( userid: str, spec: str, *, - pir_peer: str, - pool_index: int, + pool_index: int | None = None, ) -> dict: server, registry_id_hex, role_name = parse_spec(spec) identity = store.load_identity(userid) @@ -287,27 +341,60 @@ def collect( info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) - info_b = _mgmt_single(userid, pir_peer, {"cmd": "server_info"}) - store.pin_server(userid, pir_peer, info_b["server_public_key_b64"]) - row = _fetch_grant_entry_pir(userid, server, pir_peer, pool_index) - if row.get("claimed"): - raise RuntimeError("grant row is already claimed") + sock, framed = _mgmt_connect(userid, server) try: - eph_pk = _unb64(row["eph_pk_b64"]) - ct = _unb64(row["ciphertext_b64"]) - plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) - except Exception as exc: - raise RuntimeError( - "PIR row did not decrypt for this user (wrong pool_index or desynced peers)" - ) from exc - if (plaintext.get("registry_id") != registry_id_hex or - plaintext.get("role_name") != role_name): - raise RuntimeError( - "PIR row does not match this collect spec (check pool_index and peers)" - ) - target_grant_id = row["grant_id"] - target_payload = plaintext + if pool_index is None: + tags_resp = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"})) + tags = tags_resp["tags"] + matches = _match_tags(userid, tags) + if not matches: + raise RuntimeError("no matching grants found in pool") + + pir_cl, pv = _pir_client(userid, framed, server) + found = None + for idx in matches: + try: + row = _fetch_row(userid, framed, server, pir_cl, pv, idx) + except Exception: + continue + if row.get("claimed"): + continue + try: + eph_pk = _unb64(row["eph_pk_b64"]) + ct = _unb64(row["ciphertext_b64"]) + plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) + except Exception: + continue + if (plaintext.get("registry_id") == registry_id_hex and + plaintext.get("role_name") == role_name): + found = (idx, row, plaintext) + break + if found is None: + raise RuntimeError( + f"no unclaimed grant for {registry_id_hex}:{role_name} in pool" + ) + pool_index, target_row, target_payload = found + else: + pir_cl, pv = _pir_client(userid, framed, server) + target_row = _fetch_row(userid, framed, server, pir_cl, pv, pool_index) + if target_row.get("claimed"): + raise RuntimeError("grant row is already claimed") + try: + eph_pk = _unb64(target_row["eph_pk_b64"]) + ct = _unb64(target_row["ciphertext_b64"]) + target_payload = json.loads(receiver_kp.decrypt(eph_pk, ct)) + except Exception as exc: + raise RuntimeError("PIR row did not decrypt for this user") from exc + if (target_payload.get("registry_id") != registry_id_hex or + target_payload.get("role_name") != role_name): + raise RuntimeError( + "PIR row does not match this collect spec" + ) + finally: + sock.close() + + target_grant_id = target_row["grant_id"] _mgmt_single(userid, server, { "cmd": "claim_grant", diff --git a/cli/zkac_cli/main.py b/cli/zkac_cli/main.py index eee5abd..e4554f4 100644 --- a/cli/zkac_cli/main.py +++ b/cli/zkac_cli/main.py @@ -111,9 +111,13 @@ def _cmd_grant(args): print(f" grant id: {gid}") print(f" pool index: {pool_index}") print(f" recipient can collect with:") + print( + f" zkac-node collect {args.server}:{args.registry}:{args.role}" + ) + print(f" or with explicit index:") print( f" zkac-node collect {args.server}:{args.registry}:{args.role} " - f"--pir-peer --pool-index {pool_index}" + f"--pool-index {pool_index}" ) @@ -136,18 +140,11 @@ def _cmd_credentials_list(args): print("\n(no servers to query; pass --server host:port to check for pending)") return - if args.pir_peer is None: - print( - "\n(skipping pending grants: two-server XOR PIR requires " - "--pir-peer ; local credentials are listed above)" - ) - return - - print("\npending grants (PIR scan, O(n) per server):") + print("\npending grants:") any_pending = False for srv in servers: try: - grants = client.list_pending(args.userid, srv, args.pir_peer) + grants = client.list_pending(args.userid, srv) except Exception as exc: print(f" [{srv}] error: {exc}") continue @@ -170,7 +167,6 @@ def _cmd_collect(args): result = client.collect( args.userid, args.spec, - pir_peer=args.pir_peer, pool_index=args.pool_index, ) print("collected credential") @@ -269,29 +265,17 @@ def main(): c = cred_sub.add_parser("list", help="show local + pending credentials") c.add_argument("userid") c.add_argument("--server", action="append", help="server to query (host:port); repeatable") - c.add_argument( - "--pir-peer", - default=None, - metavar="HOST:PORT", - help="second replica with the same grant pool; required to list pending via XOR PIR", - ) c.set_defaults(func=_cmd_credentials_list) # collect c = sub.add_parser("collect", help="fetch and finalize a pending credential") c.add_argument("userid") c.add_argument("spec", help="host:port:registry_id:role") - c.add_argument( - "--pir-peer", - required=True, - metavar="HOST:PORT", - help="second replica with an identical grant pool (two-server XOR PIR)", - ) c.add_argument( "--pool-index", type=int, - required=True, - help="grant row index from admin (printed on grant)", + default=None, + help="grant row index (optional; auto-discovered via detection tags if omitted)", ) c.set_defaults(func=_cmd_collect) diff --git a/cli/zkac_cli/pir.py b/cli/zkac_cli/pir.py deleted file mode 100644 index f06f56f..0000000 --- a/cli/zkac_cli/pir.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Private information retrieval helpers. - -Two-server XOR PIR (Chor–Goldreich–Kushilevitz–Sudan style): - - For a database of n fixed-length records D[0],...,D[n-1], to retrieve D[i]: - - Pick a uniformly random subset S ⊆ {0,...,n-1}. - - Server A returns ⊕_{j∈S} D[j] - - Server B returns ⊕_{j∈S⊕{i}} D[j] - - Client XORs the two replies → D[i]. - -Privacy holds if the two servers do **not** collude. A single host running both -servers learns both queries and can recover i — use two independent operators -or hosts for real privacy. - -Recipients fetch rows via **two-server XOR PIR** (``pir_fold`` on each replica). -There is no full-database download endpoint; listing all decryptable grants for -a user requires **O(n) PIR queries** (one Chor query per pool index) when -scanning the pool. -""" - -from __future__ import annotations - -import json -import secrets -from typing import Iterable - -# Fixed record size for XOR PIR (padded JSON grant entry). -PIR_RECORD_BYTES = 64 * 1024 - - -def pad_grant_record(entry: dict) -> bytes: - """Serialize grant entry to fixed-length bytes for XOR folding.""" - raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode("utf-8") - if len(raw) > PIR_RECORD_BYTES: - raise ValueError(f"grant record exceeds PIR_RECORD_BYTES ({PIR_RECORD_BYTES})") - return raw + b"\x00" * (PIR_RECORD_BYTES - len(raw)) - - -def unpad_grant_record(buf: bytes) -> dict: - return json.loads(buf.rstrip(b"\x00").decode("utf-8")) - - -def xor_bytes_many(chunks: Iterable[bytes]) -> bytes: - it = iter(chunks) - first = next(it, None) - if first is None: - return b"\x00" * PIR_RECORD_BYTES - acc = bytearray(first) - for c in it: - if len(c) != len(acc): - raise ValueError("length mismatch in xor_bytes_many") - for j in range(len(acc)): - acc[j] ^= c[j] - return bytes(acc) - - -def random_subset(n: int) -> set[int]: - """Uniform random subset of {0,...,n-1}.""" - if n <= 0: - return set() - return {j for j in range(n) if secrets.randbelow(2)} - - -def symdiff_one(s: set[int], i: int) -> set[int]: - """S ⊕ {i} as symmetric difference.""" - t = set(s) - if i in t: - t.remove(i) - else: - t.add(i) - return t - - -def pir_query_indices(n: int, i: int) -> tuple[list[int], list[int]]: - """Build two index lists for servers A and B to XOR-fold records.""" - if not (0 <= i < n): - raise ValueError("index out of range") - s = random_subset(n) - sa = sorted(s) - sb = sorted(symdiff_one(s, i)) - return sa, sb - - -def pir_recover(xor_a: bytes, xor_b: bytes) -> bytes: - if len(xor_a) != len(xor_b): - raise ValueError("xor length mismatch") - return bytes(a ^ b for a, b in zip(xor_a, xor_b)) diff --git a/cli/zkac_cli/server.py b/cli/zkac_cli/server.py index fb80d08..b000db6 100644 --- a/cli/zkac_cli/server.py +++ b/cli/zkac_cli/server.py @@ -10,12 +10,11 @@ The server stores only cryptographically verified opaque blobs: /server_key.json Schnorr keypair /registries/.state raw RegistryState bytes /registries/.cert raw state cert bytes - /mailbox/grants_pool.json anonymous append-only grant pool (PIR-friendly) + /mailbox/grants_pool.json anonymous append-only grant pool -Grants are **not** keyed by recipient public key on the server; delivery uses a -shared pool with **two-server XOR PIR** only: clients issue ``pir_fold`` queries -to two replicas with identical pools (``mail_pool_len`` + ``pir_fold``). There -is no bulk grant export command. +Recipients discover matching grants via a cheap detection-tag index +(``pool_tags``) and retrieve individual rows via single-server +LWE-based PIR (``pir_query`` with precomputed ``pir_hints``). """ from __future__ import annotations @@ -31,8 +30,6 @@ from pathlib import Path import zkac from zkac.tcp import FramedSession, server_handshake_anon -from . import pir - def _b64(data: bytes) -> str: return base64.b64encode(data).decode() @@ -42,6 +39,15 @@ def _unb64(s: str) -> bytes: return base64.b64decode(s) +def _pad_grant_record(entry: dict) -> bytes: + raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode("utf-8") + if len(raw) > zkac.PIR_RECORD_BYTES: + raise ValueError( + f"grant record exceeds PIR_RECORD_BYTES ({zkac.PIR_RECORD_BYTES})" + ) + return raw + b"\x00" * (zkac.PIR_RECORD_BYTES - len(raw)) + + # ── Opaque server storage ───────────────────────────────────────────── class _ServerStore: @@ -55,6 +61,8 @@ class _ServerStore: self._reg_dir.mkdir(parents=True, exist_ok=True) self._mbox_dir.mkdir(parents=True, exist_ok=True) self._lock = threading.Lock() + self._pir_server: zkac.PirServer | None = None + self._pir_dirty = True self._migrate_legacy_mailbox() # ── server key ──────────────────────────────────────────────────── @@ -129,38 +137,67 @@ class _ServerStore: json.dumps({"version": 1, "records": records}, indent=2) ) + # ── PIR server rebuild ──────────────────────────────────────────── + + def _ensure_pir(self) -> zkac.PirServer: + if self._pir_server is not None and not self._pir_dirty: + return self._pir_server + records = self._load_pool() + packed = [_pad_grant_record(r) for r in records] + db = zkac.PirDatabase(packed, zkac.PIR_RECORD_BYTES) + self._pir_server = zkac.PirServer(db) + self._pir_dirty = False + return self._pir_server + # ── anonymous grant pool ───────────────────────────────────────── def post_grant(self, entry: dict) -> tuple[str, int]: grant_id = os.urandom(16).hex() row = {"grant_id": grant_id, "claimed": False, **entry} - try: - pir.pad_grant_record(row) - except ValueError as exc: - raise ValueError( - f"grant entry too large for PIR record size ({pir.PIR_RECORD_BYTES} bytes): {exc}" - ) from exc + _pad_grant_record(row) with self._lock: records = self._load_pool() pool_index = len(records) records.append(row) self._save_pool(records) + self._pir_dirty = True return grant_id, pool_index - def pool_len(self) -> tuple[int, int]: + def pool_info(self) -> dict: with self._lock: - n = len(self._load_pool()) - return n, pir.PIR_RECORD_BYTES + pir = self._ensure_pir() + return { + "n": pir.n_records, + "record_bytes": pir.record_bytes, + "pool_version": bytes(pir.version()).hex(), + } - def pir_fold(self, indices: list[int]) -> bytes: + def pool_tags(self) -> tuple[list[tuple[str, str]], str]: with self._lock: records = self._load_pool() - uniq = sorted(set(indices)) - for i in uniq: - if not (0 <= i < len(records)): - raise ValueError("pir_fold index out of range") - chunks = [pir.pad_grant_record(records[i]) for i in uniq] - return pir.xor_bytes_many(chunks) + pir = self._ensure_pir() + version = bytes(pir.version()).hex() + tags = [] + for r in records: + tags.append(( + r.get("eph_pk_b64", ""), + r.get("to_tag_b64", ""), + )) + return tags, version + + def pir_hints(self) -> tuple[bytes, str]: + with self._lock: + pir = self._ensure_pir() + return bytes(pir.hints()), bytes(pir.version()).hex() + + def pir_answer(self, query_b64: str, pool_version: str) -> tuple[str, str] | None: + with self._lock: + pir = self._ensure_pir() + current = bytes(pir.version()).hex() + if current != pool_version: + return None + ans = pir.answer(_unb64(query_b64)) + return _b64(ans), current def claim_grant(self, grant_id: str) -> dict | None: with self._lock: @@ -170,6 +207,7 @@ class _ServerStore: e["claimed"] = True records[i] = e self._save_pool(records) + self._pir_dirty = True return dict(e) return None @@ -216,17 +254,33 @@ def _dispatch(cmd: dict, mgr: zkac.RegistryManager, store: _ServerStore, entry = { "eph_pk_b64": cmd["eph_pk_b64"], "ciphertext_b64": cmd["ciphertext_b64"], + "to_tag_b64": cmd.get("to_tag_b64", ""), } gid, pool_index = store.post_grant(entry) return {"ok": True, "grant_id": gid, "pool_index": pool_index} - if action == "mail_pool_len": - n, rec_b = store.pool_len() - return {"ok": True, "n": n, "record_bytes": rec_b} + if action == "pool_info": + info = store.pool_info() + return {"ok": True, **info} - if action == "pir_fold": - raw = store.pir_fold(list(cmd.get("indices", []))) - return {"ok": True, "xor_b64": _b64(raw)} + if action == "pool_tags": + tags, version = store.pool_tags() + entries = [ + {"eph_pk_b64": epk, "to_tag_b64": tag} + for epk, tag in tags + ] + return {"ok": True, "tags": entries, "pool_version": version} + + if action == "pir_hints": + hints_bytes, version = store.pir_hints() + return {"ok": True, "hints_b64": _b64(hints_bytes), "pool_version": version} + + if action == "pir_query": + result = store.pir_answer(cmd["query_b64"], cmd["pool_version"]) + if result is None: + return {"error": "stale_version"} + ans_b64, version = result + return {"ok": True, "answer_b64": ans_b64, "pool_version": version} if action == "claim_grant": entry = store.claim_grant(cmd["grant_id"]) @@ -283,7 +337,6 @@ def _handle_conn(conn: socket.socket, addr: tuple, node: zkac.Node, } framed.send(json.dumps(resp).encode()) - # keep session open for app traffic while True: try: data = framed.recv() diff --git a/docs/PYTHON_API.md b/docs/PYTHON_API.md index 318c85a..8952fdb 100644 --- a/docs/PYTHON_API.md +++ b/docs/PYTHON_API.md @@ -1,6 +1,6 @@ # ZKAC Python API Reference -Version 0.4.0. Cryptographic stack: **BBS+** on BLS12-381 (credentials), **X25519** + **ChaCha20-Poly1305** (transport), **Schnorr/Ristretto255** (identity), **BLAKE2b** (role IDs, signatures). +Version 0.4.1. Cryptographic stack: **BBS+** on BLS12-381 (credentials), **X25519** + **ChaCha20-Poly1305** (transport), **Schnorr/Ristretto255** (identity), **BLAKE2b** (role IDs, signatures). ```python import zkac diff --git a/docs/SECURITY.md b/docs/SECURITY.md index b35c6b9..0b45ab4 100644 --- a/docs/SECURITY.md +++ b/docs/SECURITY.md @@ -1,4 +1,4 @@ -# Security model and audit notes (ZKAC 0.4.0) +# Security model and audit notes (ZKAC 0.4.1) This document summarizes the design, residual risks, and recommendations for operators integrating **ZKAC**. It is not a substitute for independent review before high-assurance deployment. @@ -21,6 +21,8 @@ This document summarizes the design, residual risks, and recommendations for ope | Credentials | BBS+ on BLS12-381 (zkryptium), SHAKE256 ciphersuite | Blind issuance, ZK presentations | | Role IDs | BLAKE2b-512 (truncated to 32 bytes) | Opaque role identifiers | | Grant delivery | X25519 static/ephemeral DH, HKDF-SHA256, ChaCha20-Poly1305 | E2E-encrypted credential grants | +| Grant discovery | X25519 DH + BLAKE2b-512 truncated to 16 bytes | Detection tags for anonymous matching | +| PIR | LWE (n=1024, q=2^32, p=256, σ=6.4) | Single-server private record retrieval | ## Protocol flow @@ -42,22 +44,124 @@ Management commands (`create_registry`, `post_grant`, etc.) and BBS+ role authen ### Grant delivery (admin → recipient, through server) -Grants live in a single **anonymous append-only pool** (no `recipient_pk` on the server). Recipients fetch rows only via **two-server XOR PIR** (`mail_pool_len` + `pir_fold` on two replicas with identical pools). Each query reveals only a random subset-XOR to each server; *which* logical index is recovered is hidden if the replicas do not collude. There is **no** full-pool download API. Scanning all rows for “pending” uses **O(n) PIR round-trips** (one Chor-style query per index). +Grants live in a single **anonymous append-only pool** (no recipient identifier on the server). Each grant entry carries an ephemeral public key, the E2E-encrypted credential payload, and a 16-byte **detection tag**. + +**Discovery (cheap, no PIR):** The server exposes a `pool_tags` command returning all `(eph_pk, tag)` pairs. The client computes `X25519(my_issuance_sk, eph_pk_j)` for each entry and derives the expected tag via `BLAKE2b-512("zkac-grant-tag" || shared_secret)[..16]`. Matching entries are the client's grants. This scan is a single round-trip transferring ~48 bytes per pool entry and is computed locally. + +**Retrieval (PIR):** Matching rows are fetched individually via LWE-based single-server PIR (`pir_query`). The server precomputes hints (`H = D · A^T` where A is a seeded public matrix); the client caches hints keyed by `pool_version` and only refetches when the pool changes. ``` Admin Server (opaque relay) Recipient |-- post_grant ------->| | | (admin_proof, | appends to pool: | - | eph_pk, | {grant_id, eph_pk, ct} | - | ciphertext) | (no recipient address) | - | |<-- pir_fold (replica A/B) --| - | |--- XOR of subset rows ----->| - | | | combine → one row - | | | trial-decrypt + | eph_pk, | {grant_id, eph_pk, | + | ciphertext, | ciphertext, to_tag} | + | to_tag) | (no recipient address) | + | | | + | |<-- pool_tags --------------| + | |--- [(eph_pk, tag), …] ---->| + | | | local tag match + | |<-- pir_query(j) -----------| + | |--- answer ----------------->| + | | | PIR decode → row + | | | trial-decrypt → cred | |<-- claim_grant ------------| | | (tombstone / claimed) | ``` +## PIR security (LWE) + +Private information retrieval uses the **SimplePIR** construction (first layer of DoublePIR, Henzinger–Hong–Corrigan-Gibbs–Meiklejohn–Vaikuntanathan, USENIX Security '23). Security rests on the **decisional Learning With Errors (LWE)** assumption: + +- **Parameters:** LWE dimension n=1024, ciphertext modulus q=2^32, plaintext modulus p=256, discrete Gaussian noise σ=6.4. +- **Classical security:** ~128 bits (based on lattice estimator analysis at these parameters). +- **Post-quantum:** LWE is believed hard for quantum computers; no known quantum algorithm breaks it in polynomial time. +- **Single-server:** No non-collusion assumption. Privacy holds against an honest-but-curious server that inspects all queries and answers. + +The PIR scheme is **honest-but-curious only**: a malicious server can return incorrect answers. This is acceptable because grant payloads are E2E-encrypted (ChaCha20-Poly1305) and credential finalization validates BBS+ blind signatures — a corrupted PIR answer causes decryption or BBS+ verification to fail, not credential forgery. + +## Detection tags + +Each grant carries a 16-byte detection tag: `BLAKE2b-512("zkac-grant-tag" || X25519(eph_sk, recipient_pk))[..16]`. + +**Privacy properties:** +- The tag is a deterministic function of the shared secret, which requires knowledge of either the ephemeral secret key or the recipient's issuance secret key to compute. An observer (including the server) who knows neither key cannot link a tag to a recipient. +- The `pool_tags` list is equivalent to what the server already sees at grant insertion time — broadcasting it to querying clients reveals no new information. +- A client downloading `pool_tags` reveals that it is checking for pending grants, but not which entries matched. Matching is a local computation. +- Tags have 128-bit collision resistance (16 bytes); false positives are negligible. + +## Scaling and complexity (transport, credentials, registries) + +This section complements the **grant pool / PIR** analysis above. Asymptotics use: **R** = number of roles in one registry state, **G** = number of registries hosted in memory, **L** = byte length of an application payload (JSON management command or auth packet body after decryption). + +### Transport and session crypto + +| Operation | Time | Bandwidth / memory | +|-----------|------|----------------------| +| Handshake (`connect` / `accept`) | **O(1)** | Fixed 32-byte handshake messages; one X25519 DH, HKDF, ChaCha open. | +| Server identity proof | **O(1)** | Schnorr verify on Ristretto255 over a short transcript-derived message. | +| `Session::encrypt` / `decrypt` per frame | **O(L)** | ChaCha20-Poly1305 is linear in payload size; replay window checks are **O(1)** per direction. | + +**Bottlenecks:** negligible compared to BBS+ unless payloads are pushed toward frame limits. Python framing caps TCP payloads at `MAX_BBS_AUTH_PROOF_BYTES + 4 KiB` (~260 KiB), bounding worst-case allocations per read. + +### BBS+ credentials (issuance and verification) + +| Operation | Time | Notes | +|-----------|------|-------| +| Blind `issue_blind` / `finalize` (issuer / member) | **O(1)** in R and G | Dominated by BLS12-381 and BBS+ proof math in zkryptium (pairings, multi-scalar muls); not sensitive to registry count or pool size. | +| `present` (proof generation) | **O(1)** | Produces a presentation bound to a nonce (e.g. transcript hash). | +| `verify_presentation` | **O(1)** | One proof check against one issuer public key. | +| Proof size on the wire | **≤ 256 KiB** | `MAX_BBS_AUTH_PROOF_BYTES`; caps attacker-controlled allocation for auth packets. | + +**Bottlenecks:** **BBS+ verify and present** dominate CPU on authenticated paths (role auth, admin proofs for `post_grant`, registry state certification). Cost is **per event**, not per grant in pool, but high QPS auth still needs horizontal scaling or hardware tuned for pairing-heavy crypto. + +### Registry state (client-managed blob on server) + +| Operation | Time | Size | +|-----------|------|------| +| `RegistryState::serialize` / `deserialize` | **O(R)** | Linear in number of role entries (each: fixed `role_id`, variable-length issuer pk bytes, epoch). | +| `state_hash` | **O(|state_bytes|)** ≈ **O(R)** | One BLAKE2b-512 over the serialized state. | +| `certify` / `verify_cert` | Same as BBS+ present / verify | One presentation over `state_hash`. | +| `RegistryManager::update` | **O(R)** for cache rebuild | Deserializes old + new state, verifies cert and version chain, rebuilds `RoleRegistry` cache by iterating all roles (`build_role_cache`). | + +**Bandwidth:** `get_registry` / `create_registry` / `update_registry` move the **full serialized state** and **state certificate** each time — **O(R)** bytes per round-trip. Very large role lists mean large management frames and more CPU on every update. + +**Bottlenecks:** **Large R** (many roles in one registry) inflates state blob size, hash work, and cache rebuild. **Frequent updates** multiply BBS+ certify/verify cost. + +### RegistryManager (multi-registry server) + +| Operation | Time | Notes | +|-----------|------|-------| +| `create` / `get` / `update` / `verify_*` | **O(1)** expected in G | Hash map on `registry_id`; work is on **one** stored registry at a time. | +| In-memory footprint | **O(G × (|state| + |cert| + queues))** | Each registry holds state bytes, cert bytes, `RoleRegistry` cache, and issuance **queues** (below). | + +**Bottlenecks:** **G** grows with every distinct registry the server accepts — mostly a **RAM** and operational concern. Per-request CPU is still dominated by BBS+ and (for managed flows) issuance queue handling. + +### Issuance request queues (`RegistryManager`) + +| Structure | Growth | Risk | +|-----------|--------|------| +| `pending_requests` / `granted` maps | **Unbounded** per registry unless the application drains them | A client could queue many `queue_issuance_request` entries; server memory grows with pending items. Not the same as the grant pool file, but a similar **resource exhaustion** class. | + +**Bottlenecks:** **Queue depth** per registry; mitigations are rate limits, caps, or TTL policies at the application layer (not enforced in core today). + +### Issuance encryption (X25519 + ChaCha) + +| Operation | Time | +|-----------|------| +| `encrypt` / `decrypt` (grant payloads, admin replies) | **O(L)** for payload length L | + +Negligible vs BBS+ for typical small JSON blobs. + +### Summary: dominant costs outside the grant pool + +1. **BBS+ present/verify** on every auth, admin proof, and registry certificate path — **pairing-heavy**, fixed per operation, proof capped at 256 KiB. +2. **Registry state size and `update`** — **O(R)** serialization, hashing, and full cache rebuild. +3. **Issuance queues** — **unbounded** pending entries per registry if abused. +4. **Transport** — **O(L)** per frame; handshake **O(1)**. + +The **grant pool** remains the subsystem whose **per-operation** cost scales with **pool length n** (discovery, PIR query, PIR answer compute); the rest of the protocol scales mainly with **roles per registry**, **registry count**, and **proof operations per session**, not with anonymous pool size. + ## Threats considered ### Network attacker (passive) @@ -76,13 +180,15 @@ Admin Server (opaque relay) Recipient - Can **learn** opaque `role_id`, current epoch, and that *some* valid member authenticated. - Sees `registry_id` values (needed for routing) but not role names or registry contents beyond opaque state bytes. -- Sees `eph_pk` and ciphertext per grant in the anonymous pool, and pool size / timing of syncs, but cannot decrypt grant payloads. It does **not** see a per-recipient mailbox key for addressing. +- Sees `eph_pk`, `to_tag`, and ciphertext per grant in the anonymous pool, and pool size / timing of syncs, but cannot decrypt grant payloads or link tags to recipients. +- Sees PIR queries, which are LWE-encrypted under the decisional LWE assumption — cannot determine which pool index the client is retrieving (single-server, no collusion needed). - **Cannot** forge BBS+ credentials without the issuer secret key. - **Cannot** learn `member_secret` from presentations under the BBS+ security assumptions. - **Cannot** distinguish which specific member authenticated among valid credential holders (unlinkability holds against the verifier for distinct presentation headers). - **Cannot** learn the client's long-term public key — it is never transmitted during handshake or auth. - **Cannot** perform admin operations (registry updates, grant posting) without a valid admin BBS+ credential. - **Cannot** correlate a recipient's mailbox identity with their authenticated sessions (different keys, unlinkable proofs). +- **Can** censor grants by omitting tags from `pool_tags` or returning corrupted PIR answers. Corrupted answers are caught by E2E decryption / BBS+ verification failures. Censorship is a residual operational risk; cross-checking pool hashes across replicas mitigates it. ### Malicious client @@ -93,6 +199,7 @@ Admin Server (opaque relay) Recipient - **Auth packet size:** Proof length is capped (`MAX_BBS_AUTH_PROOF_BYTES`, 256 KiB) to bound allocations. - **Handshake:** Fixed 32-byte messages; no variable-length handshake parsing. +- **Grant pool growth:** The anonymous pool is append-only with tombstoned rows (`claimed`), so **pool length `n` never shrinks** on disk. A malicious or careless admin can grow `n` without bound: larger `pool_tags` downloads, longer PIR hint **recomputation** when the pool version bumps, and **per-query** PIR cost linear in `n` (see Known limitations). This is a **storage and workload amplification** vector, not credential forgery. Mitigation belongs in future work (pool caps, compaction, generations). - General packet limits should still be enforced at the application layer (total message size, rate limits). ## Key distribution @@ -114,7 +221,7 @@ Recommended strategies: 4. **Epoch revocation:** On compromise or policy change, call `set_epoch` and re-issue credentials only to legitimate members; old credentials become invalid at verification time. 5. **Registry integrity:** Registry state is integrity-protected by BBS+ state certificates (admin must sign updates). The server verifies these certificates before accepting changes. 6. **Role ID privacy:** `role_id` is a hash of the role name only if you use `role_id("myrole")`; treat role names as secrets if enumeration is a concern, or derive role IDs with an additional secret salt known to members. -7. **Recipient addressing:** Admins encrypt grants to the recipient’s issuance public key off-server; that key is not used as a server-side mailbox index. Recipients are identified to the issuer out-of-band only. +7. **Recipient addressing:** Admins encrypt grants to the recipient's issuance public key off-server; that key is not used as a server-side mailbox index. Recipients are identified to the issuer out-of-band only. ## Implementation notes (audit checklist) @@ -129,6 +236,8 @@ Recommended strategies: - [x] Admin proofs in `post_grant` are bound to the session transcript hash (no separate nonce); the CLI uses **one TCP session per grant** so each proof uses a fresh transcript. - [x] After collect, the client persists the server public key from `server_info` (never a placeholder key). - [x] Server stores only opaque state bytes, state certs, and encrypted grant blobs (no role names, no user IDs). +- [x] PIR queries are LWE-encrypted; the server cannot determine the queried index. +- [x] Detection tags are derived from X25519 shared secrets and cannot be linked to recipients by the server. - [ ] **External:** Python bindings surface raw bytes; callers must not log secrets (`secret_key_bytes`, `member_secret`, `prover_blind`). - [ ] **External:** Use secure randomness from the OS (library uses OS RNG for key generation paths exposed in Rust). @@ -138,21 +247,28 @@ Recommended strategies: - **Anonymous handshake (`complete_connect_anon`):** The client verifies the server's identity but does not authenticate itself during the handshake. BBS+ auth is sent as an application-layer message inside the encrypted session, not as part of the handshake. This allows the same channel for both anonymous management and authenticated role access. - **Server-only identity proof:** Only the server signs the transcript. Adding client long-term signing would break BBS+ unlinkability (the server could correlate sessions by client public key). Client authentication is handled entirely by the anonymous BBS+ credential. - **Deterministic Schnorr nonces:** The signing nonce is derived as `H("zkac-schnorr-nonce" || sk || msg)`, eliminating a class of RNG-failure attacks (cf. PS3 ECDSA, Sony 2010). Same key + same message = same signature. -- **Anonymous grant pool:** Grant entries contain only `(eph_pk, ciphertext)` plus stable row metadata — no registry ID or role name. Recipients find their grants by trial-decrypting after two-server XOR PIR (or an O(n) PIR scan over the pool). Pool rows use tombstones (`claimed`) so indices stay stable for replicated PIR. +- **Anonymous grant pool:** Grant entries contain `(eph_pk, ciphertext, to_tag)` plus stable row metadata — no registry ID or role name. Recipients discover their grants via detection tags and retrieve them via LWE PIR. Pool rows use tombstones (`claimed`) so indices stay stable for PIR hints. - **No user IDs on server:** The server has no concept of user accounts. It is a stateless relay authenticated only by cryptographic proofs. -- **One session per admin grant (CLI):** Each `post_grant` runs in its own connection so `verify_admin` nonces are not reused across grants in a single session. Registry updates use separate connections for `get_registry` and `update_registry`. Collect uses separate connections for `server_info`, pool fetch / PIR, `claim_grant`, and `get_registry` so those operations are not tied to one transcript. +- **Single-server PIR (LWE):** Eliminates the two-server non-collusion assumption of the previous XOR PIR design. Query privacy rests on decisional LWE, not operational trust in multiple server operators. +- **Detection tags for discovery:** A 16-byte tag derived from X25519 DH allows O(n) local matching from a cheap bulk download, reducing PIR usage from O(n) queries to O(matches) queries per scan. +- **One session per admin grant (CLI):** Each `post_grant` runs in its own connection so `verify_admin` nonces are not reused across grants in a single session. ## Known limitations -- **No post-quantum** primitives: classical security assumptions only. - **Epoch granularity:** Revocation is coarse (epoch bump); plan issuance and rotation policy accordingly. - **zkryptium dependency:** Security follows the underlying crate and BLS12-381/BBS+ standards; keep dependencies updated. - **Key distribution:** The library provides the cryptographic mechanism; initial key distribution is an application-layer responsibility. -- **Pool metadata:** Each replica sees `pir_fold` subset queries (random-looking index sets) and timing. Two-server XOR PIR hides the target index from each server if they do not collude; running both replicas under one operator does not provide that privacy. A full-pool scan issues **n** PIR queries and has high cost; the issuer should send **`pool_index` out-of-band** so the recipient runs **one** PIR retrieval for collect. +- **Honest-but-curious PIR:** The server can return incorrect PIR answers. Corrupted answers are caught by E2E decryption / BBS+ verification, but censorship (omitting grants) is not detected at the PIR layer. Cross-replica hash comparison or a transparency log can mitigate this. +- **Hint size:** PIR hints are approximately `record_bytes × N_LWE × 4` bytes (~256 MB for 64 KiB records with N_LWE=1024). Hints are cached client-side and only refetched when the pool version changes. +- **Unbounded grant pool:** Rows are never removed from the pool file; only marked claimed. Pool length `n` therefore grows monotonically with every posted grant. That increases discovery traffic (`pool_tags` is O(n)), PIR query size (O(n) bytes per query), server work per PIR answer (O(n × record_bytes)), and hint **rebuild** cost when the pool changes (O(n × record_bytes × N_LWE)). Operators should plan for bounded pools or archival; the codebase does not yet enforce limits. ## Future work -- **Single-server sublinear PIR:** The CLI uses **two-server XOR PIR** (Chor-style) only. **Single-server** private information retrieval with **sublinear** client communication (e.g. **SealPIR**, **DoublePIR**, or other lattice / homomorphic-encryption–based schemes) is **not** implemented; adding it would require new dependencies, fixed database encoding, and a distinct query/response protocol. That would allow a lone replica without a non-colluding peer, at the cost of heavier crypto and implementation complexity. +- **Bounded grant pool and anti-DoS:** Introduce explicit **pool caps**, **rate limits** on `post_grant`, **per-registry quotas**, or **pool generations** (rotate to a fresh empty pool while archiving the old one). Optionally **compact** the on-disk pool by rewriting only unclaimed rows and bumping a generation id so PIR indices stay meaningful without retaining every tombstone forever. Any design must preserve stable addressing for in-flight collects or migrate clients with explicit pool ids. +- **Scale beyond large `n`:** Today’s bottleneck is **linear cost in pool length `n`** for each PIR retrieval: client upload ~4n bytes per query, server matrix–vector multiply O(n × record_bytes), and discovery O(n). For very large pools, future work includes **sublinear-communication PIR** (e.g. full DoublePIR second layer), **sharded pools** with client-side routing, **streaming or chunked hints**, or **moving heavy work off the hot path** (precomputed answers, CDN for hints) — trading complexity, trust, or privacy for throughput. +- **DoublePIR second layer:** The current implementation uses SimplePIR (the first layer of DoublePIR). For very large pools where per-query answer size matters, a second SimplePIR layer can compress answers without API changes. +- **Verifiable PIR:** Adding a commitment to the pool state (e.g. Merkle tree or KZG) and proof of correct answer computation would defend against malicious server responses beyond what E2E encryption catches. +- **Pool commitment / transparency:** Publishing a hash of `(pool_version, hints, tags)` to a public log or allowing cross-replica comparison would detect censorship by a malicious server. ## Reporting issues diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock index 9e63907..d6c7277 100644 --- a/fuzz/Cargo.lock +++ b/fuzz/Cargo.lock @@ -709,7 +709,7 @@ dependencies = [ [[package]] name = "zkac" -version = "0.4.0" +version = "0.4.1" dependencies = [ "blake2", "chacha20poly1305", diff --git a/pyproject.toml b/pyproject.toml index abcb15f..af9fecc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "zkac" -version = "0.4.0" +version = "0.4.1" description = "Zero-Knowledge Access Control: BBS+ anonymous credentials with encrypted transport" readme = "README.md" requires-python = ">=3.9" diff --git a/python/zkac/__init__.py b/python/zkac/__init__.py index 009434a..609b44e 100644 --- a/python/zkac/__init__.py +++ b/python/zkac/__init__.py @@ -6,6 +6,7 @@ BBS+ anonymous credentials (BLS12-381) with encrypted transport (Ristretto255 / from zkac._zkac import ( MAX_BBS_AUTH_PROOF_BYTES, + PIR_RECORD_BYTES, Keypair, PublicKey, BbsIssuer, @@ -22,6 +23,11 @@ from zkac._zkac import ( IssuanceKeypair, encrypt_for_admin, decrypt_from_admin, + PirDatabase, + PirServer, + PirClient, + PirClientState, + grant_detection_tag, Session, Node, PendingConnect, @@ -29,6 +35,7 @@ from zkac._zkac import ( __all__ = [ "MAX_BBS_AUTH_PROOF_BYTES", + "PIR_RECORD_BYTES", "Keypair", "PublicKey", "BbsIssuer", @@ -45,6 +52,11 @@ __all__ = [ "IssuanceKeypair", "encrypt_for_admin", "decrypt_from_admin", + "PirDatabase", + "PirServer", + "PirClient", + "PirClientState", + "grant_detection_tag", "Session", "Node", "PendingConnect", diff --git a/src/lib.rs b/src/lib.rs index a31b617..d0923d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod credential; pub mod error; pub mod issuance; pub mod node; +pub mod pir; pub mod registry_manager; pub mod transport; diff --git a/src/pir/db.rs b/src/pir/db.rs new file mode 100644 index 0000000..8629a68 --- /dev/null +++ b/src/pir/db.rs @@ -0,0 +1,46 @@ +use blake2::Blake2b512; +use digest::Digest; + +/// Database of fixed-length records, packed as a `cells_per_record × n_records` +/// column-major matrix of mod-p cells (p = 256, so one byte = one cell). +pub struct Database { + /// Row-major: data[i * n_records + j] = record j's byte i. + data: Vec, + n_records: usize, + record_bytes: usize, +} + +impl Database { + /// Pack `records` (each padded/truncated to `record_bytes`) into a query-ready matrix. + pub fn new(records: &[&[u8]], record_bytes: usize) -> Self { + let n_records = records.len(); + let cells = record_bytes; + let mut data = vec![0u32; cells * n_records]; + for (j, rec) in records.iter().enumerate() { + let len = rec.len().min(record_bytes); + for i in 0..len { + data[i * n_records + j] = rec[i] as u32; + } + } + Database { data, n_records, record_bytes } + } + + pub fn data(&self) -> &[u32] { &self.data } + pub fn n_records(&self) -> usize { self.n_records } + pub fn record_bytes(&self) -> usize { self.record_bytes } + pub fn cells_per_record(&self) -> usize { self.record_bytes } + + /// BLAKE2b-256 commitment over (n_records, record_bytes, all packed cells). + pub fn version(&self) -> [u8; 32] { + let mut h = Blake2b512::new(); + h.update((self.n_records as u64).to_le_bytes()); + h.update((self.record_bytes as u64).to_le_bytes()); + for &val in &self.data { + h.update(val.to_le_bytes()); + } + let digest = h.finalize(); + let mut v = [0u8; 32]; + v.copy_from_slice(&digest[..32]); + v + } +} diff --git a/src/pir/doublepir.rs b/src/pir/doublepir.rs new file mode 100644 index 0000000..d565c11 --- /dev/null +++ b/src/pir/doublepir.rs @@ -0,0 +1,287 @@ +//! SimplePIR-based single-server Private Information Retrieval. +//! +//! Implements the first layer of DoublePIR (Henzinger–Hong–Corrigan-Gibbs– +//! Meiklejohn–Vaikuntanathan, USENIX Security '23). For full-record retrieval +//! the second compression layer is unnecessary — the client needs the entire +//! column — so this is equivalent to SimplePIR. The second layer can be added +//! as an optimisation for very large pools without API changes. +//! +//! Security: decisional LWE with n=1024, q=2^32, σ=6.4 (128-bit classical). + +use blake2::Blake2b512; +use digest::Digest; +use rand::Rng; +use rand::rngs::OsRng; + +use super::params::*; +use super::lwe; +use super::db::Database; + +// ── Hints (public matrix seed + precomputed H = D · A^T) ──────────── + +pub struct Hints { + seed: [u8; 32], + n_records: usize, + cells_per_record: usize, + /// Row-major `cells_per_record × N_LWE` matrix (mod 2^32). + hint: Vec, +} + +impl Hints { + pub fn n_records(&self) -> usize { self.n_records } + pub fn cells_per_record(&self) -> usize { self.cells_per_record } + pub fn seed(&self) -> &[u8; 32] { &self.seed } + pub fn hint_matrix(&self) -> &[u32] { &self.hint } + + pub fn serialize(&self) -> Vec { + let n = self.cells_per_record * N_LWE; + let mut buf = Vec::with_capacity(32 + 8 + 8 + n * 4); + buf.extend_from_slice(&self.seed); + buf.extend_from_slice(&(self.n_records as u64).to_le_bytes()); + buf.extend_from_slice(&(self.cells_per_record as u64).to_le_bytes()); + for &v in &self.hint { + buf.extend_from_slice(&v.to_le_bytes()); + } + buf + } + + pub fn deserialize(data: &[u8]) -> Result { + if data.len() < 48 { + return Err("hint data too short"); + } + let mut seed = [0u8; 32]; + seed.copy_from_slice(&data[..32]); + let n_records = u64::from_le_bytes(data[32..40].try_into().unwrap()) as usize; + let cells_per_record = u64::from_le_bytes(data[40..48].try_into().unwrap()) as usize; + let n = cells_per_record.checked_mul(N_LWE).ok_or("overflow")?; + if data.len() != 48 + n * 4 { + return Err("hint data length mismatch"); + } + let hint: Vec = data[48..] + .chunks_exact(4) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())) + .collect(); + Ok(Hints { seed, n_records, cells_per_record, hint }) + } + + pub fn version(&self) -> [u8; 32] { + let mut h = Blake2b512::new(); + h.update(b"zkac-pir-hints-v1"); + h.update(self.seed); + h.update((self.n_records as u64).to_le_bytes()); + h.update((self.cells_per_record as u64).to_le_bytes()); + for &v in &self.hint { + h.update(v.to_le_bytes()); + } + let digest = h.finalize(); + let mut out = [0u8; 32]; + out.copy_from_slice(&digest[..32]); + out + } +} + +// ── Server ────────────────────────────────────────────────────────── + +pub struct Server { + db: Database, + hints: Hints, +} + +impl Server { + /// Build server state: generates a random public matrix A and precomputes + /// the hint H = D · A^T. This is the expensive offline phase. + pub fn new(db: Database) -> Self { + let m = db.n_records(); + let ell = db.cells_per_record(); + + let mut seed = [0u8; 32]; + OsRng.fill(&mut seed); + + let hint = if m == 0 { + Vec::new() + } else { + let a = lwe::gen_matrix(&seed, N_LWE, m); + lwe::mat_mul_bt(db.data(), &a, ell, m, N_LWE) + }; + + Server { + hints: Hints { seed, n_records: m, cells_per_record: ell, hint }, + db, + } + } + + pub fn hints(&self) -> &Hints { &self.hints } + + pub fn version(&self) -> [u8; 32] { self.hints.version() } + + /// Compute answer = D · query (mod 2^32). The query vector has `n_records` + /// entries; the answer has `cells_per_record` entries. + pub fn answer(&self, query: &[u32]) -> Vec { + lwe::mat_vec_mul( + self.db.data(), query, + self.db.cells_per_record(), self.db.n_records(), + ) + } + + pub fn n_records(&self) -> usize { self.db.n_records() } + pub fn record_bytes(&self) -> usize { self.db.record_bytes() } +} + +// ── Client ────────────────────────────────────────────────────────── + +pub struct ClientState { + secret: Vec, +} + +pub struct Client { + hints: Hints, +} + +impl Client { + pub fn new(hints: Hints) -> Self { + Client { hints } + } + + pub fn version(&self) -> [u8; 32] { self.hints.version() } + pub fn n_records(&self) -> usize { self.hints.n_records } + pub fn record_bytes(&self) -> usize { self.hints.cells_per_record } + + /// Generate a PIR query for `index`. Returns the query vector (to send to + /// the server) and opaque client state (needed for decoding the answer). + pub fn query(&self, index: usize) -> (Vec, ClientState) { + assert!(index < self.hints.n_records, "PIR query index out of range"); + + let mut rng = OsRng; + let s = lwe::sample_uniform_vec(&mut rng, N_LWE); + let e = lwe::sample_error_vec(&mut rng, self.hints.n_records); + + // q = A^T · s + e (A is N_LWE × n_records) + let a = lwe::gen_matrix(self.hints.seed(), N_LWE, self.hints.n_records); + let mut q = lwe::mat_t_vec_mul(&a, &s, N_LWE, self.hints.n_records); + + for j in 0..self.hints.n_records { + q[j] = q[j].wrapping_add(e[j]); + } + q[index] = q[index].wrapping_add(DELTA); + + (q, ClientState { secret: s }) + } + + /// Decode the server's answer using the saved client state. + /// Returns the `record_bytes`-length plaintext record. + pub fn decode(&self, answer: &[u32], state: &ClientState) -> Vec { + // h_s = H · s (cells_per_record × N_LWE) · (N_LWE) → cells_per_record + let h_s = lwe::mat_vec_mul( + self.hints.hint_matrix(), &state.secret, + self.hints.cells_per_record, N_LWE, + ); + + (0..self.hints.cells_per_record) + .map(|i| lwe::round_to_plaintext(answer[i].wrapping_sub(h_s[i]))) + .collect() + } +} + +// ── Serialization helpers for query / answer vectors ──────────────── + +pub fn serialize_vec(v: &[u32]) -> Vec { + let mut buf = Vec::with_capacity(v.len() * 4); + for &val in v { + buf.extend_from_slice(&val.to_le_bytes()); + } + buf +} + +pub fn deserialize_vec(data: &[u8]) -> Vec { + data.chunks_exact(4) + .map(|c| u32::from_le_bytes(c.try_into().unwrap())) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_test_records(n: usize, rec_bytes: usize) -> Vec> { + (0..n).map(|i| { + let mut rec = vec![0u8; rec_bytes]; + rec[0] = i as u8; + rec[1] = (i.wrapping_mul(37) & 0xFF) as u8; + if rec_bytes > 2 { + rec[rec_bytes - 1] = 0xAA; + } + rec + }).collect() + } + + #[test] + fn roundtrip_small() { + let records = make_test_records(8, 128); + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + let db = Database::new(&refs, 128); + let server = Server::new(db); + let hints_bytes = server.hints().serialize(); + let client = Client::new(Hints::deserialize(&hints_bytes).unwrap()); + + for target in 0..8 { + let (query, state) = client.query(target); + let answer = server.answer(&query); + let decoded = client.decode(&answer, &state); + assert_eq!(decoded, records[target]); + } + } + + #[test] + fn roundtrip_serialized_query_answer() { + let records = make_test_records(4, 64); + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + let db = Database::new(&refs, 64); + let server = Server::new(db); + let client = Client::new(Hints::deserialize(&server.hints().serialize()).unwrap()); + + let (q, state) = client.query(2); + let q_bytes = serialize_vec(&q); + let q2 = deserialize_vec(&q_bytes); + assert_eq!(q, q2); + + let ans = server.answer(&q2); + let ans_bytes = serialize_vec(&ans); + let ans2 = deserialize_vec(&ans_bytes); + let decoded = client.decode(&ans2, &state); + assert_eq!(decoded, records[2]); + } + + #[test] + fn version_changes_on_rebuild() { + let records = make_test_records(4, 64); + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + let db1 = Database::new(&refs, 64); + let s1 = Server::new(db1); + let v1 = s1.version(); + + let db2 = Database::new(&refs, 64); + let s2 = Server::new(db2); + let v2 = s2.version(); + + // Different seeds → different versions (with overwhelming probability). + assert_ne!(v1, v2); + } + + #[test] + fn hints_roundtrip() { + let records = make_test_records(4, 64); + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + let db = Database::new(&refs, 64); + let server = Server::new(db); + + let original = server.hints(); + let bytes = original.serialize(); + let restored = Hints::deserialize(&bytes).unwrap(); + + assert_eq!(original.seed(), restored.seed()); + assert_eq!(original.n_records(), restored.n_records()); + assert_eq!(original.cells_per_record(), restored.cells_per_record()); + assert_eq!(original.hint_matrix(), restored.hint_matrix()); + assert_eq!(original.version(), restored.version()); + } +} diff --git a/src/pir/lwe.rs b/src/pir/lwe.rs new file mode 100644 index 0000000..854e25c --- /dev/null +++ b/src/pir/lwe.rs @@ -0,0 +1,103 @@ +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; + +use super::params::*; + +/// Deterministically generate a `rows × cols` matrix of uniform u32 values from `seed`. +pub fn gen_matrix(seed: &[u8; 32], rows: usize, cols: usize) -> Vec { + let mut rng = StdRng::from_seed(*seed); + (0..rows * cols).map(|_| rng.gen()).collect() +} + +/// Sample one value from a discrete Gaussian with stddev SIGMA (Box-Muller, rounded). +fn sample_gaussian(rng: &mut impl Rng) -> i32 { + let u1: f64 = rng.gen_range(1e-10f64..1.0); + let u2: f64 = rng.gen(); + let z = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos(); + (z * SIGMA).round() as i32 +} + +/// Sample a vector of `len` small error values (discrete Gaussian, stored as u32 mod 2^32). +pub fn sample_error_vec(rng: &mut impl Rng, len: usize) -> Vec { + (0..len).map(|_| sample_gaussian(rng) as u32).collect() +} + +pub fn sample_uniform_vec(rng: &mut impl Rng, len: usize) -> Vec { + (0..len).map(|_| rng.gen()).collect() +} + +/// C = A · B^T (mod 2^32). +/// A is `rows_a × inner` (row-major), B is `rows_b × inner` (row-major). +/// Result C is `rows_a × rows_b`. +pub fn mat_mul_bt(a: &[u32], b: &[u32], rows_a: usize, inner: usize, rows_b: usize) -> Vec { + let mut c = vec![0u32; rows_a * rows_b]; + for i in 0..rows_a { + let a_off = i * inner; + for k in 0..rows_b { + let b_off = k * inner; + let mut acc = 0u64; + for j in 0..inner { + acc = acc.wrapping_add( + (a[a_off + j] as u64).wrapping_mul(b[b_off + j] as u64), + ); + } + c[i * rows_b + k] = acc as u32; + } + } + c +} + +/// c = A · v (mod 2^32). A is `rows × cols`, v is `cols`-length. +pub fn mat_vec_mul(a: &[u32], v: &[u32], rows: usize, cols: usize) -> Vec { + let mut c = vec![0u32; rows]; + for i in 0..rows { + let off = i * cols; + let mut acc = 0u64; + for j in 0..cols { + acc = acc.wrapping_add((a[off + j] as u64).wrapping_mul(v[j] as u64)); + } + c[i] = acc as u32; + } + c +} + +/// c = A^T · v (mod 2^32). A is `rows × cols`, v is `rows`-length. Result is `cols`-length. +pub fn mat_t_vec_mul(a: &[u32], v: &[u32], rows: usize, cols: usize) -> Vec { + let mut c = vec![0u64; cols]; + for k in 0..rows { + let v_k = v[k] as u64; + let off = k * cols; + for j in 0..cols { + c[j] = c[j].wrapping_add(v_k.wrapping_mul(a[off + j] as u64)); + } + } + c.iter().map(|&x| x as u32).collect() +} + +/// Round from Z_{2^32} to Z_p. Recovers the plaintext byte from Δ·m + noise. +pub fn round_to_plaintext(val: u32) -> u8 { + (val.wrapping_add(DELTA / 2) >> 24) as u8 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn round_exact() { + for m in 0u32..=255 { + let encoded = m.wrapping_mul(DELTA); + assert_eq!(round_to_plaintext(encoded), m as u8); + } + } + + #[test] + fn round_with_small_noise() { + for m in 0u32..=255 { + for noise in [-100i32, -1, 0, 1, 100] { + let encoded = m.wrapping_mul(DELTA).wrapping_add(noise as u32); + assert_eq!(round_to_plaintext(encoded), m as u8); + } + } + } +} diff --git a/src/pir/mod.rs b/src/pir/mod.rs new file mode 100644 index 0000000..2e25e8f --- /dev/null +++ b/src/pir/mod.rs @@ -0,0 +1,65 @@ +pub mod params; +pub mod lwe; +pub mod db; +pub mod doublepir; + +pub use params::RECORD_BYTES; +pub use db::Database; +pub use doublepir::{Server, Client, ClientState, Hints, serialize_vec, deserialize_vec}; + +use blake2::Blake2b512; +use digest::Digest; +use x25519_dalek::{PublicKey as X25519Public, StaticSecret}; + +/// Compute a 16-byte detection tag for grant discovery. +/// +/// tag = BLAKE2b-512("zkac-grant-tag" || X25519(sk, pk))[..16] +/// +/// Both sides of the DH produce the same tag, so the sender (admin) can +/// publish it alongside the grant ciphertext, and the recipient can match +/// it locally without PIR. +pub fn detection_tag(secret_key: &[u8; 32], public_key: &[u8; 32]) -> [u8; 16] { + let sk = StaticSecret::from(*secret_key); + let pk = X25519Public::from(*public_key); + let shared = sk.diffie_hellman(&pk); + + let mut h = Blake2b512::new(); + h.update(b"zkac-grant-tag"); + h.update(shared.as_bytes()); + let digest = h.finalize(); + + let mut tag = [0u8; 16]; + tag.copy_from_slice(&digest[..16]); + tag +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::rngs::OsRng; + + #[test] + fn detection_tag_symmetric() { + let sk_a = x25519_dalek::StaticSecret::random_from_rng(&mut OsRng); + let pk_a = X25519Public::from(&sk_a); + let sk_b = x25519_dalek::StaticSecret::random_from_rng(&mut OsRng); + let pk_b = X25519Public::from(&sk_b); + + let tag_ab = detection_tag(&sk_a.to_bytes(), pk_b.as_bytes()); + let tag_ba = detection_tag(&sk_b.to_bytes(), pk_a.as_bytes()); + assert_eq!(tag_ab, tag_ba); + } + + #[test] + fn detection_tag_different_keys() { + let sk_a = x25519_dalek::StaticSecret::random_from_rng(&mut OsRng); + let sk_b = x25519_dalek::StaticSecret::random_from_rng(&mut OsRng); + let sk_c = x25519_dalek::StaticSecret::random_from_rng(&mut OsRng); + let pk_b = X25519Public::from(&sk_b); + let pk_c = X25519Public::from(&sk_c); + + let tag1 = detection_tag(&sk_a.to_bytes(), pk_b.as_bytes()); + let tag2 = detection_tag(&sk_a.to_bytes(), pk_c.as_bytes()); + assert_ne!(tag1, tag2); + } +} diff --git a/src/pir/params.rs b/src/pir/params.rs new file mode 100644 index 0000000..fcdc140 --- /dev/null +++ b/src/pir/params.rs @@ -0,0 +1,14 @@ +/// LWE dimension — controls security level. +pub const N_LWE: usize = 1024; + +/// Plaintext modulus: each cell holds one byte (p = 256). +pub const P: u32 = 256; + +/// Scaling factor Δ = 2^32 / p = 2^24. Maps plaintext [0,255] into Z_{2^32}. +pub const DELTA: u32 = 1 << 24; + +/// Discrete Gaussian standard deviation for LWE error sampling. +pub const SIGMA: f64 = 6.4; + +/// Fixed record size for PIR (bytes). Must match the CLI grant padding. +pub const RECORD_BYTES: usize = 64 * 1024; diff --git a/src/python.rs b/src/python.rs index 116db0d..c6bdf60 100644 --- a/src/python.rs +++ b/src/python.rs @@ -827,11 +827,186 @@ impl PyNode { } } +// ── PIR (DoublePIR / SimplePIR, LWE-based) ────────────────────────── + +#[pyclass(name = "PirDatabase")] +pub struct PyPirDatabase { + inner: crate::pir::Database, +} + +#[pymethods] +impl PyPirDatabase { + #[new] + fn new(records: Vec>, record_bytes: usize) -> PyResult { + if record_bytes == 0 { + return Err(PyValueError::new_err("record_bytes must be > 0")); + } + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + Ok(PyPirDatabase { + inner: crate::pir::Database::new(&refs, record_bytes), + }) + } + + fn version<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + PyBytes::new(py, &self.inner.version()) + } + + #[getter] + fn record_bytes(&self) -> usize { + self.inner.record_bytes() + } + + #[getter] + fn n_records(&self) -> usize { + self.inner.n_records() + } +} + +#[pyclass(name = "PirServer")] +pub struct PyPirServer { + inner: crate::pir::Server, +} + +#[pymethods] +impl PyPirServer { + #[new] + fn new(db: &PyPirDatabase) -> Self { + // Database is packed into u32 cells; we need to reconstruct from the + // inner data. Since Database doesn't implement Clone, rebuild it from + // the raw cell data. + // + // Actually, Server::new takes ownership of Database. We reconstruct + // by re-packing from the stored cell matrix. This is slightly wasteful + // but keeps the API simple. + let n = db.inner.n_records(); + let rb = db.inner.record_bytes(); + let cells = db.inner.cells_per_record(); + + let mut records: Vec> = Vec::with_capacity(n); + for j in 0..n { + let mut rec = vec![0u8; rb]; + for i in 0..cells { + rec[i] = db.inner.data()[i * n + j] as u8; + } + records.push(rec); + } + let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); + let new_db = crate::pir::Database::new(&refs, rb); + PyPirServer { + inner: crate::pir::Server::new(new_db), + } + } + + fn hints<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + PyBytes::new(py, &self.inner.hints().serialize()) + } + + fn version<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + PyBytes::new(py, &self.inner.version()) + } + + fn answer<'py>(&self, py: Python<'py>, query: &[u8]) -> PyResult> { + let q = crate::pir::deserialize_vec(query); + if q.len() != self.inner.n_records() { + return Err(PyValueError::new_err("query length mismatch")); + } + let ans = self.inner.answer(&q); + Ok(PyBytes::new(py, &crate::pir::serialize_vec(&ans))) + } + + #[getter] + fn n_records(&self) -> usize { + self.inner.n_records() + } + + #[getter] + fn record_bytes(&self) -> usize { + self.inner.record_bytes() + } +} + +#[pyclass(name = "PirClientState")] +pub struct PyPirClientState { + inner: Option, +} + +#[pyclass(name = "PirClient")] +pub struct PyPirClient { + inner: crate::pir::Client, +} + +#[pymethods] +impl PyPirClient { + #[new] + fn new(hints: &[u8], n_records: usize, record_bytes: usize) -> PyResult { + let h = crate::pir::Hints::deserialize(hints) + .map_err(|e| PyValueError::new_err(e.to_string()))?; + if h.n_records() != n_records { + return Err(PyValueError::new_err("n_records mismatch with hints")); + } + if h.cells_per_record() != record_bytes { + return Err(PyValueError::new_err("record_bytes mismatch with hints")); + } + Ok(PyPirClient { + inner: crate::pir::Client::new(h), + }) + } + + fn version<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> { + PyBytes::new(py, &self.inner.version()) + } + + fn query<'py>( + &self, + py: Python<'py>, + index: usize, + ) -> PyResult<(Bound<'py, PyBytes>, PyPirClientState)> { + if index >= self.inner.n_records() { + return Err(PyValueError::new_err("index out of range")); + } + let (q, state) = self.inner.query(index); + Ok(( + PyBytes::new(py, &crate::pir::serialize_vec(&q)), + PyPirClientState { inner: Some(state) }, + )) + } + + fn decode<'py>( + &self, + py: Python<'py>, + answer: &[u8], + state: &mut PyPirClientState, + ) -> PyResult> { + let st = state.inner.take().ok_or_else(|| { + PyValueError::new_err("PirClientState already consumed") + })?; + let ans = crate::pir::deserialize_vec(answer); + if ans.len() != self.inner.record_bytes() { + return Err(PyValueError::new_err("answer length mismatch")); + } + let decoded = self.inner.decode(&ans, &st); + Ok(PyBytes::new(py, &decoded)) + } +} + +#[pyfunction] +fn grant_detection_tag<'py>( + py: Python<'py>, + secret_key: &[u8], + public_key: &[u8], +) -> PyResult> { + let sk = to_32(secret_key, "secret_key")?; + let pk = to_32(public_key, "public_key")?; + let tag = crate::pir::detection_tag(&sk, &pk); + Ok(PyBytes::new(py, &tag)) +} + // ── Module ─────────────────────────────────────────────────────────── #[pymodule] fn _zkac(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("MAX_BBS_AUTH_PROOF_BYTES", crate::node::MAX_BBS_AUTH_PROOF_BYTES)?; + m.add("PIR_RECORD_BYTES", crate::pir::RECORD_BYTES)?; // Transport identity (ristretto255) m.add_class::()?; m.add_class::()?; @@ -853,6 +1028,12 @@ fn _zkac(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(encrypt_for_admin, m)?)?; m.add_function(wrap_pyfunction!(decrypt_from_admin, m)?)?; + // PIR (LWE-based, single-server) + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(grant_detection_tag, m)?)?; // Transport m.add_class::()?; m.add_class::()?; diff --git a/tests/test_pir.py b/tests/test_pir.py index da17b9f..3a48492 100644 --- a/tests/test_pir.py +++ b/tests/test_pir.py @@ -1,42 +1,82 @@ -"""Unit tests for XOR PIR helpers (Chor-style two-server PIR).""" +"""Tests for LWE-based PIR (DoublePIR / SimplePIR) via the Python bindings.""" -from zkac_cli import pir +import json +import os + +import zkac -def test_pad_roundtrip(): - d = { - "grant_id": "abc", - "claimed": False, - "eph_pk_b64": "eA==", - "ciphertext_b64": "cA==", - } - p = pir.pad_grant_record(d) - assert len(p) == pir.PIR_RECORD_BYTES - assert pir.unpad_grant_record(p) == d +def _pad(entry: dict) -> bytes: + raw = json.dumps(entry, separators=(",", ":"), sort_keys=True).encode() + return raw + b"\x00" * (zkac.PIR_RECORD_BYTES - len(raw)) -def test_xor_pir_row_recovery(): - rows = [ - { - "grant_id": f"id{i}", - "claimed": False, - "eph_pk_b64": f"e{i}==", - "ciphertext_b64": f"c{i}==", - } - for i in range(5) +def test_roundtrip_small(): + """Build a small DB, query every index, verify each decode is correct.""" + records = [ + _pad({"id": i, "data": f"record-{i}"}) + for i in range(8) ] - want_i = 3 - sa, sb = pir.pir_query_indices(len(rows), want_i) + db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES) + server = zkac.PirServer(db) + hints = bytes(server.hints()) + client = zkac.PirClient(hints, 8, zkac.PIR_RECORD_BYTES) - def fold(idxs): - pads = [pir.pad_grant_record(rows[j]) for j in sorted(set(idxs))] - return pir.xor_bytes_many(pads) - - xa = fold(sa) - xb = fold(sb) - got = pir.pir_recover(xa, xb) - assert got == pir.pad_grant_record(rows[want_i]) + for target in range(8): + q, state = client.query(target) + answer = server.answer(q) + decoded = bytes(client.decode(answer, state)) + assert decoded == records[target], f"mismatch at index {target}" -def test_pir_fold_empty_is_zero_block(): - assert pir.xor_bytes_many([]) == b"\x00" * pir.PIR_RECORD_BYTES +def test_roundtrip_medium(): + """64 records at full PIR_RECORD_BYTES size.""" + records = [] + for i in range(64): + entry = {"id": i, "payload": os.urandom(128).hex()} + records.append(_pad(entry)) + db = zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES) + server = zkac.PirServer(db) + hints = bytes(server.hints()) + client = zkac.PirClient(hints, 64, zkac.PIR_RECORD_BYTES) + + for target in [0, 1, 31, 32, 63]: + q, state = client.query(target) + answer = server.answer(q) + decoded = bytes(client.decode(answer, state)) + assert decoded == records[target] + + +def test_version_changes_on_rebuild(): + """Rebuilding from the same data with a new seed gives a different version.""" + records = [_pad({"i": i}) for i in range(4)] + s1 = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)) + s2 = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)) + assert bytes(s1.version()) != bytes(s2.version()) + + +def test_hints_serialize_roundtrip(): + records = [_pad({"i": i}) for i in range(4)] + server = zkac.PirServer(zkac.PirDatabase(records, zkac.PIR_RECORD_BYTES)) + hints1 = bytes(server.hints()) + client = zkac.PirClient(hints1, 4, zkac.PIR_RECORD_BYTES) + assert bytes(client.version()) == bytes(server.version()) + + +def test_detection_tag_symmetric(): + """grant_detection_tag(a_sk, b_pk) == grant_detection_tag(b_sk, a_pk).""" + kp_a = zkac.IssuanceKeypair() + kp_b = zkac.IssuanceKeypair() + tag_ab = bytes(zkac.grant_detection_tag(kp_a.secret_bytes(), kp_b.public_key_bytes())) + tag_ba = bytes(zkac.grant_detection_tag(kp_b.secret_bytes(), kp_a.public_key_bytes())) + assert tag_ab == tag_ba + assert len(tag_ab) == 16 + + +def test_detection_tag_distinct(): + kp_a = zkac.IssuanceKeypair() + kp_b = zkac.IssuanceKeypair() + kp_c = zkac.IssuanceKeypair() + t1 = bytes(zkac.grant_detection_tag(kp_a.secret_bytes(), kp_b.public_key_bytes())) + t2 = bytes(zkac.grant_detection_tag(kp_a.secret_bytes(), kp_c.public_key_bytes())) + assert t1 != t2 diff --git a/uv.lock b/uv.lock index 8c8585c..b5dc53a 100644 --- a/uv.lock +++ b/uv.lock @@ -844,7 +844,7 @@ wheels = [ [[package]] name = "zkac" -version = "0.4.0" +version = "0.4.1" source = { editable = "." } dependencies = [ { name = "ipykernel", version = "6.31.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },