"""ZKAC server: all traffic over a single encrypted, server-authenticated channel. Every connection performs an anonymous handshake (X25519 + Schnorr server identity proof). The first encrypted frame selects the mode: {"op": "mgmt"} management commands (create_registry, post_grant, ...) {"op": "auth", "registry_id": hex, "bbs_auth_b64": ...} role authentication The server stores only cryptographically verified opaque blobs: /server_key.json Schnorr keypair /registries/.state raw RegistryState bytes /registries/.cert raw state cert bytes /mailbox/grants_pool.json anonymous append-only grant pool Recipients discover matching grants via a cheap detection-tag index (``pool_tags``). PIR (``pir_query``) returns a full encrypted mailbox row, so no follow-up ``grant_id`` fetch is required. This avoids leaking a stable row identifier during retrieval. Large PIR answers are streamed in slices (same pattern as ``pir_hints``). """ from __future__ import annotations import base64 import hashlib import json import socket import threading import traceback from pathlib import Path import zkac from zkac.tcp import MAX_TCP_FRAME_BYTES, FramedSession, server_handshake_anon # Serialized PIR hints can be huge (64 KiB records × LWE width). Each mgmt reply must # stay under :data:`MAX_TCP_FRAME_BYTES` after encryption; base64 expands ~4/3. 
_PIR_HINT_CHUNK = min(131_072, max(16_384, (MAX_TCP_FRAME_BYTES * 3) // 5))
# (3/5 of a frame) * (4/3 base64 growth) = 4/5 of a frame, leaving ~20% headroom for
# the JSON envelope and encryption overhead. NOTE: the 16 KiB floor assumes
# MAX_TCP_FRAME_BYTES is comfortably above ~22 KiB (16 KiB after base64).


def _b64(data: bytes) -> str:
    """Encode *data* as standard base64 text."""
    return base64.b64encode(data).decode()


def _unb64(s: str) -> bytes:
    """Decode standard base64 text (non-validating ``base64.b64decode``)."""
    return base64.b64decode(s)


def _pir_row_bytes(entry: dict) -> bytes:
    """Serialize a pool entry into one fixed-size PIR record.

    The row is canonical JSON (sorted keys, no whitespace) holding the full
    encrypted mailbox row plus a SHA-256 digest of the decoded ciphertext,
    right-padded with NUL bytes to ``zkac.PIR_RECORD_BYTES``.

    Raises:
        ValueError: if the serialized row does not fit in one PIR record, or
            (as ``binascii.Error``) if ``ciphertext_b64`` is badly padded base64.
    """
    ct_b64 = entry.get("ciphertext_b64", "")
    ct_digest = hashlib.sha256(_unb64(ct_b64)).hexdigest() if ct_b64 else ""
    row = {
        "v": 2,
        "eph_pk_b64": entry.get("eph_pk_b64", ""),
        "to_tag_b64": entry.get("to_tag_b64", ""),
        "ciphertext_b64": ct_b64,
        "ciphertext_sha256": ct_digest,
        "claimed": bool(entry.get("claimed", False)),
    }
    raw = json.dumps(row, separators=(",", ":"), sort_keys=True).encode("utf-8")
    if len(raw) > zkac.PIR_RECORD_BYTES:
        raise ValueError(
            f"PIR row exceeds PIR_RECORD_BYTES ({zkac.PIR_RECORD_BYTES})"
        )
    return raw + b"\x00" * (zkac.PIR_RECORD_BYTES - len(raw))


def _atomic_write(path: Path, data: bytes, mode: int = 0o666) -> None:
    """Atomically replace *path* with *data*.

    The bytes are written and fsync'd to a sibling ``<name>.tmp`` file which is
    then renamed over *path* with ``os.replace``, so a crash mid-write leaves
    either the old or the new content on disk, never a truncated file.

    Args:
        path: Destination file.
        data: Full new file content.
        mode: Permission bits for the newly created file (still filtered by the
            process umask). Use ``0o600`` for secrets.

    Callers must serialize concurrent writes to the same *path*; the temp name
    is fixed per destination.
    """
    import os

    tmp = path.with_name(path.name + ".tmp")
    try:
        # Remove leftovers from a crashed write so O_CREAT applies *mode* afresh.
        tmp.unlink()
    except FileNotFoundError:
        pass
    flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
    fd = os.open(tmp, flags, mode)
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, path)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass
        raise


# ── Opaque server storage ─────────────────────────────────────────────


class _ServerStore:
    """Thread-safe, opaque persistence for registries and anonymous grant pool.

    Layout under ``data_dir``: ``server_key.json``, ``registries/<rid>.state``,
    ``registries/<rid>.cert`` and ``mailbox/grants_pool.json``. All files are
    written atomically via :func:`_atomic_write`. A ``zkac.PirServer`` over the
    pool is rebuilt lazily after each pool mutation.
    """

    def __init__(self, data_dir: Path):
        self._dir = data_dir
        self._reg_dir = data_dir / "registries"
        self._mbox_dir = data_dir / "mailbox"
        self._pool_path = self._mbox_dir / "grants_pool.json"
        self._reg_dir.mkdir(parents=True, exist_ok=True)
        self._mbox_dir.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._pir_server: zkac.PirServer | None = None
        self._pir_dirty = True  # set whenever the pool changes on disk
        self._migrate_legacy_mailbox()

    # ── server key ────────────────────────────────────────────────────

    def load_or_create_keypair(self) -> zkac.Keypair:
        """Load the server's Schnorr keypair, creating and persisting one on first run.

        A newly created key file is written with mode ``0o600`` because it holds
        the secret key.
        """
        kf = self._dir / "server_key.json"
        if kf.exists():
            data = json.loads(kf.read_text())
            return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
        kp = zkac.Keypair()
        payload = json.dumps({
            "secret_b64": _b64(kp.secret_key_bytes()),
            "public_b64": _b64(kp.public_key().to_bytes()),
        }, indent=2)
        _atomic_write(kf, payload.encode("utf-8"), mode=0o600)
        return kp

    # ── registries ────────────────────────────────────────────────────

    def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes) -> None:
        """Persist a registry's raw state and certificate under ``rid_hex``.

        Each file is replaced atomically; the pair as a whole is not, so a crash
        between the two writes can leave a mismatched pair. Presumably
        ``mgr.restore`` rejects such a pair, in which case
        :meth:`load_all_registries` logs and skips it.
        """
        with self._lock:
            _atomic_write(self._reg_dir / f"{rid_hex}.state", state_bytes)
            _atomic_write(self._reg_dir / f"{rid_hex}.cert", cert_bytes)

    def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
        """Restore every ``<rid>.state``/``<rid>.cert`` pair into *mgr*.

        States without a matching cert file are skipped silently; pairs that
        ``mgr.restore`` rejects are logged and skipped.

        Returns:
            The number of registries successfully restored.
        """
        count = 0
        for p in sorted(self._reg_dir.glob("*.state")):
            rid_hex = p.stem
            cert_path = self._reg_dir / f"{rid_hex}.cert"
            if not cert_path.exists():
                continue
            try:
                mgr.restore(p.read_bytes(), cert_path.read_bytes())
                count += 1
            except Exception as exc:
                print(f"[server] skip registry {rid_hex}: {exc}")
        return count

    # ── legacy per-recipient mailbox → anonymous pool ─────────────────

    def _migrate_legacy_mailbox(self) -> None:
        """Fold legacy per-recipient ``mailbox/*.json`` files into the anonymous pool.

        Runs only when no pool file exists yet. The pool is written *before*
        any legacy file is removed, and files that fail to parse are kept on
        disk (and logged) rather than deleted.
        """
        if self._pool_path.exists():
            return
        records: list[dict] = []
        migrated: list[Path] = []
        for p in sorted(self._mbox_dir.glob("*.json")):
            if p.name == "grants_pool.json":
                continue
            try:
                entries = json.loads(p.read_text())
                # Default "claimed" to False while keeping any existing value.
                converted = [{"claimed": False, **e} for e in entries]
            except Exception as exc:
                print(f"[server] skip legacy mailbox {p.name}: {exc}")
                continue
            records.extend(converted)
            migrated.append(p)
        if records:
            self._save_pool(records)
        for p in migrated:
            try:
                p.unlink()
            except OSError:
                pass

    def _load_pool(self) -> list[dict]:
        """Read all pool records from disk (empty list if the pool file is absent)."""
        if not self._pool_path.exists():
            return []
        data = json.loads(self._pool_path.read_text())
        return data.get("records", [])

    def _save_pool(self, records: list[dict]) -> None:
        """Atomically overwrite the pool file with *records*."""
        payload = json.dumps({"version": 1, "records": records}, indent=2)
        _atomic_write(self._pool_path, payload.encode("utf-8"))

    # ── PIR server rebuild ────────────────────────────────────────────

    def _ensure_pir(self) -> zkac.PirServer:
        """Return the PIR server, rebuilding it from the on-disk pool if stale.

        Caller must hold ``self._lock``.
        """
        if self._pir_server is not None and not self._pir_dirty:
            return self._pir_server
        packed = [_pir_row_bytes(r) for r in self._load_pool()]
        db = zkac.PirDatabase(packed, zkac.PIR_RECORD_BYTES)
        self._pir_server = zkac.PirServer(db)
        self._pir_dirty = False
        return self._pir_server

    # ── anonymous grant pool ─────────────────────────────────────────

    def post_grant(self, entry: dict) -> int:
        """Append *entry* to the pool (unclaimed) and return its pool index.

        Raises:
            ValueError: if the entry cannot be packed into one PIR record; this is
                checked before the pool is touched.
        """
        _pir_row_bytes(entry)
        row = {"claimed": False, **entry}
        with self._lock:
            records = self._load_pool()
            pool_index = len(records)
            records.append(row)
            self._save_pool(records)
            self._pir_dirty = True
        return pool_index

    def pool_info(self) -> dict:
        """Return PIR pool metadata: record count, record size and hex pool version."""
        with self._lock:
            pir = self._ensure_pir()
            return {
                "n": pir.n_records,
                "record_bytes": pir.record_bytes,
                "pool_version": bytes(pir.version()).hex(),
            }

    def pool_tags(self) -> tuple[list[tuple[str, str]], str]:
        """Return ``(eph_pk_b64, to_tag_b64)`` per pool row plus the matching pool version.

        Both are read under the same lock so the tags correspond to the version.
        """
        with self._lock:
            records = self._load_pool()
            pir = self._ensure_pir()
            version = bytes(pir.version()).hex()
        tags = [
            (r.get("eph_pk_b64", ""), r.get("to_tag_b64", ""))
            for r in records
        ]
        return tags, version

    def pir_hints(self) -> tuple[bytes, str]:
        """Return the serialized PIR hints and the hex pool version they belong to."""
        with self._lock:
            pir = self._ensure_pir()
            return bytes(pir.hints()), bytes(pir.version()).hex()

    def pir_answer_bytes(self, query_b64: str, pool_version: str) -> bytes | None:
        """Return raw PIR answer bytes, or ``None`` if ``pool_version`` is stale."""
        with self._lock:
            pir = self._ensure_pir()
            current = bytes(pir.version()).hex()
            if current != pool_version:
                return None
            ans = pir.answer(_unb64(query_b64))
            return bytes(ans)


# ── Command dispatch (inside encrypted session) ──────────────────────


def _parse_offset(offset, total: int) -> int | None:
    """Return *offset* as an int in ``[0, total]``, or ``None`` if it is invalid."""
    try:
        off = int(offset)
    except (TypeError, ValueError):
        return None
    if off < 0 or off > total:
        return None
    return off


def _dispatch(
    cmd: dict,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    conn_ctx: dict,
) -> dict:
    """Execute one management command and return its JSON-serializable reply.

    Any exception is converted into ``{"error": str(exc)}`` so a bad command
    never tears down the session. ``conn_ctx`` is per-connection state used to
    stream large PIR answers across several ``pir_query`` calls.
    """
    try:
        action = cmd.get("cmd")

        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}

        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}

        if action == "get_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes, state_cert = mgr.get(rid)
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }

        if action == "update_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Use the canonical lowercase hex (as create_registry does): fromhex()
            # also accepts upper-case digits and whitespace, so the raw client
            # string could name a different file.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}

        if action == "post_grant":
            rid = bytes.fromhex(cmd["registry_id"])
            proof = _unb64(cmd["admin_proof_b64"])
            if not mgr.verify_admin(rid, proof, transcript_hash):
                return {"error": "admin proof failed"}
            entry = {
                "eph_pk_b64": cmd["eph_pk_b64"],
                "ciphertext_b64": cmd["ciphertext_b64"],
                "to_tag_b64": cmd.get("to_tag_b64", ""),
            }
            pool_index = store.post_grant(entry)
            return {"ok": True, "pool_index": pool_index}

        if action == "pool_info":
            info = store.pool_info()
            return {"ok": True, **info}

        if action == "pool_tags":
            tags, version = store.pool_tags()
            entries = [
                {"eph_pk_b64": epk, "to_tag_b64": tag}
                for epk, tag in tags
            ]
            return {"ok": True, "tags": entries, "pool_version": version}

        if action == "pir_hints":
            raw, version = store.pir_hints()
            offset = cmd.get("offset")
            if offset is not None:
                # Sliced download: the client must pin the version it started with.
                if cmd.get("pool_version") != version:
                    return {"error": "stale_version"}
                off = _parse_offset(offset, len(raw))
                if off is None:
                    return {"error": "bad offset"}
                piece = raw[off : off + _PIR_HINT_CHUNK]
                nxt = off + len(piece)
                return {
                    "ok": True,
                    "pool_version": version,
                    "slice_b64": _b64(piece),
                    "offset": off,
                    "returned": len(piece),
                    "done": nxt >= len(raw),
                }
            if len(raw) <= _PIR_HINT_CHUNK:
                return {"ok": True, "hints_b64": _b64(raw), "pool_version": version}
            # Too large for one frame: tell the client to fetch slices by offset.
            return {
                "ok": True,
                "pool_version": version,
                "hints_total": len(raw),
                "chunk": _PIR_HINT_CHUNK,
            }

        if action == "pir_query":
            q_b64 = cmd.get("query_b64", "")
            pv = cmd.get("pool_version", "")
            offset = cmd.get("offset")
            if offset is None:
                # New query: compute the answer and drop any unfinished one.
                ans_bytes = store.pir_answer_bytes(q_b64, pv)
                if ans_bytes is None:
                    return {"error": "stale_version"}
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
                if len(ans_bytes) <= _PIR_HINT_CHUNK:
                    return {"ok": True, "answer_b64": _b64(ans_bytes), "pool_version": pv}
                # Buffer the answer on the connection; client pulls it by offset.
                conn_ctx["pir_answer_buffer"] = ans_bytes
                conn_ctx["pir_answer_pv"] = pv
                return {
                    "ok": True,
                    "pool_version": pv,
                    "answer_total": len(ans_bytes),
                    "chunk": _PIR_HINT_CHUNK,
                }
            # Continuation of a buffered answer.
            if conn_ctx.get("pir_answer_pv") != pv:
                return {"error": "stale_version"}
            buf = conn_ctx.get("pir_answer_buffer")
            if buf is None:
                return {"error": "no PIR answer in progress"}
            off = _parse_offset(offset, len(buf))
            if off is None:
                return {"error": "bad offset"}
            piece = buf[off : off + _PIR_HINT_CHUNK]
            nxt = off + len(piece)
            done = nxt >= len(buf)
            if done:
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
            return {
                "ok": True,
                "pool_version": pv,
                "slice_b64": _b64(piece),
                "offset": off,
                "returned": len(piece),
                "done": done,
            }

        return {"error": f"unknown command: {action}"}
    except Exception as exc:
        return {"error": str(exc)}


# ── Connection handler ────────────────────────────────────────────────


def _handle_conn(conn: socket.socket, addr: tuple, node: zkac.Node,
                 mgr: zkac.RegistryManager, store: _ServerStore,
                 server_pk_b64: str):
    """Serve one client connection until it closes.

    Performs the anonymous handshake, then reads the first encrypted frame to
    pick the mode: ``mgmt`` (command loop via :func:`_dispatch`) or ``auth``
    (role authentication followed by an echo loop). The socket is always closed.
    """
    peer = f"{addr[0]}:{addr[1]}"
    try:
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        transcript_hash = bytes(session.transcript_hash())

        hello = json.loads(framed.recv())
        op = hello.get("op") if isinstance(hello, dict) else None

        if op == "mgmt":
            conn_ctx: dict = {}
            while True:
                try:
                    data = framed.recv()
                except OSError:  # includes ConnectionError
                    break
                try:
                    cmd = json.loads(data)
                except ValueError:  # JSONDecodeError / UnicodeDecodeError
                    resp = {"error": "malformed command"}
                else:
                    resp = _dispatch(cmd, mgr, store, server_pk_b64,
                                     transcript_hash, conn_ctx)
                framed.send(json.dumps(resp).encode())

        elif op == "auth":
            try:
                registry_id = bytes.fromhex(hello["registry_id"])
                role_id = bytes.fromhex(hello["role_id"])
                proof_bytes = _unb64(hello["bbs_auth_b64"])
            except (KeyError, TypeError, ValueError):
                framed.send(json.dumps({"error": "bad auth request"}).encode())
                return
            ok = mgr.verify_presentation(
                registry_id, role_id, proof_bytes, transcript_hash,
            )
            if not ok:
                framed.send(json.dumps({"error": "auth failed"}).encode())
                return
            resp = {
                "status": "authenticated",
                "registry_id": registry_id.hex(),
                "role_id": role_id.hex(),
            }
            framed.send(json.dumps(resp).encode())
            while True:
                try:
                    data = framed.recv()
                except OSError:  # includes ConnectionError
                    break
                framed.send(data)

        else:
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())

    except OSError:  # covers ConnectionError and BrokenPipeError
        pass
    except Exception as exc:
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        conn.close()


# ── Public entry point ────────────────────────────────────────────────


def serve(data_dir: str, host: str = "127.0.0.1", port: int = 9800):
    """Run the ZKAC server on ``host:port`` using ``data_dir`` for persistence.

    Loads (or creates) the server keypair, restores stored registries, then
    accepts connections forever, handling each on its own daemon thread.
    Returns after Ctrl-C (KeyboardInterrupt).
    """
    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)
    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    server_pk_b64 = _b64(kp.public_key().to_bytes())
    node = zkac.Node(kp)
    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)

    print(f"server public key: {_unb64(server_pk_b64).hex()}")
    print(f"loaded {n} registries")
    print(f"listening on {host}:{port}")

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind((host, port))
    sock.listen(8)
    try:
        while True:
            conn, addr = sock.accept()
            threading.Thread(
                target=_handle_conn,
                args=(conn, addr, node, mgr, store, server_pk_b64),
                daemon=True,
            ).start()
    except KeyboardInterrupt:
        print("\nshutdown")
    finally:
        sock.close()