"""ZKAC server for registry management and role authentication."""

from __future__ import annotations

import base64
import ipaddress
import json
import os
import socket
import sys
import threading
import traceback
from pathlib import Path

import zkac
from zkac.tcp import FramedSession, server_handshake_anon

from .server_debug import ServerDebugState


def _b64(data: bytes) -> str:
    """Encode *data* as standard (padded) base64 text."""
    return base64.b64encode(data).decode()


def _unb64(s: str) -> bytes:
    """Decode standard base64 text *s*.

    Uses ``base64.b64decode`` without ``validate=True``, so characters outside
    the base64 alphabet are discarded rather than rejected.
    """
    return base64.b64decode(s)


def _chmod_if_possible(path: Path, mode: int):
    """Best-effort ``chmod``: any ``OSError`` (e.g. unsupported FS) is ignored."""
    try:
        os.chmod(path, mode)
    except OSError:
        pass


def _write_private_json(path: Path, payload: dict):
    """Write *payload* as indented JSON to *path* with owner-only (0o600) access.

    The file is created with mode 0o600 up front, so secret material is never
    briefly exposed with umask-derived permissions between creation and chmod.
    The trailing chmod also tightens a pre-existing file that had broader bits
    (``os.open``'s mode only applies when the file is newly created).
    """
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as fh:
        fh.write(json.dumps(payload, indent=2))
    _chmod_if_possible(path, 0o600)


def _atomic_write_bytes(path: Path, data: bytes):
    """Replace *path* with *data* via a sibling temp file and ``os.replace``.

    On filesystems where rename is atomic, readers (and a restart after a
    crash) observe either the previous contents or the new ones, never a
    truncated file. The temp file is removed if anything fails.
    """
    tmp = path.with_name(path.name + ".tmp")
    try:
        with open(tmp, "wb") as fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        os.replace(tmp, path)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass
        raise


_LOOPBACK_NAMES = frozenset({"127.0.0.1", "::1", "localhost"})


def _is_loopback_host(host: str) -> bool:
    """Return True when *host* names a loopback address.

    Accepts ``localhost`` and any IPv4/IPv6 literal that ``ipaddress``
    classifies as loopback (127.0.0.0/8, ``::1``). Everything else, including
    ``""`` and ``0.0.0.0`` (which bind every interface), is non-loopback.
    """
    value = host.strip().lower()
    if value in _LOOPBACK_NAMES:
        return True
    try:
        return ipaddress.ip_address(value).is_loopback
    except ValueError:
        return False


# ── Opaque server storage ─────────────────────────────────────────────


class _ServerStore:
    """Opaque on-disk persistence for the server key and registry snapshots.

    Layout under *data_dir*: ``server_key.json`` plus
    ``registries/<rid_hex>.state`` / ``registries/<rid_hex>.cert``.
    Registry writes are serialized by an internal lock.
    """

    def __init__(self, data_dir: Path):
        self._dir = data_dir
        self._reg_dir = data_dir / "registries"
        self._reg_dir.mkdir(parents=True, exist_ok=True)
        # Restrict both directories to the owning user where the OS allows it.
        _chmod_if_possible(self._dir, 0o700)
        _chmod_if_possible(self._reg_dir, 0o700)
        self._lock = threading.Lock()

    # ── server key ────────────────────────────────────────────────────

    def load_or_create_keypair(self) -> zkac.Keypair:
        """Load the server keypair from ``server_key.json``, creating it on first run."""
        kf = self._dir / "server_key.json"
        if kf.exists():
            data = json.loads(kf.read_text())
            return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
        kp = zkac.Keypair()
        _write_private_json(kf, {
            "secret_b64": _b64(kp.secret_key_bytes()),
            "public_b64": _b64(kp.public_key().to_bytes()),
        })
        return kp

    # ── registries ────────────────────────────────────────────────────

    def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes):
        """Persist one registry snapshot as ``<rid_hex>.state`` + ``<rid_hex>.cert``.

        Callers should pass a canonical (lowercase, no-whitespace) hex id so a
        registry always maps to exactly one pair of files.
        """
        with self._lock:
            # NOTE: each file is replaced atomically, but a crash between the
            # two replaces can still leave a mismatched pair on disk.
            _atomic_write_bytes(self._reg_dir / f"{rid_hex}.state", state_bytes)
            _atomic_write_bytes(self._reg_dir / f"{rid_hex}.cert", cert_bytes)

    def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
        """Restore every stored registry into *mgr* and return how many succeeded.

        A ``.state`` file without a matching ``.cert`` is skipped silently; a
        snapshot that ``mgr.restore`` rejects is reported and skipped.
        """
        count = 0
        for p in sorted(self._reg_dir.glob("*.state")):
            rid_hex = p.stem
            cert_path = self._reg_dir / f"{rid_hex}.cert"
            if not cert_path.exists():
                continue
            try:
                mgr.restore(p.read_bytes(), cert_path.read_bytes())
                count += 1
            except Exception as exc:
                print(f"[server] skip registry {rid_hex}: {exc}")
        return count


# ── Command dispatch (inside encrypted session) ──────────────────────


def _dispatch(
    cmd: dict,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    conn_ctx: dict,
) -> dict:
    """Execute one management command and return a JSON-serializable reply.

    Commands: ``server_info``, ``create_registry``, ``get_registry``,
    ``update_registry``. Every command except ``server_info`` requires an
    admin proof (``admin_proof_b64``), which is checked against
    *transcript_hash* of the current session.

    ``conn_ctx`` is a per-connection dict reserved for command state; it is
    currently unused.

    Never raises: any failure is returned as ``{"error": str(exc)}``.
    """
    try:
        if not isinstance(cmd, dict):
            raise RuntimeError("command must be a JSON object")
        action = cmd.get("cmd")
        rid_hex = cmd.get("auth_registry_id")
        admin_proof_b64 = cmd.get("admin_proof_b64")

        def _require_admin_for_registry(target_rid_hex: str):
            # The proof is only meaningful for the registry it was made for.
            if rid_hex != target_rid_hex:
                raise RuntimeError("auth_registry_id must match command registry_id")
            if not isinstance(admin_proof_b64, str) or not admin_proof_b64:
                raise RuntimeError("missing admin_proof_b64")
            if not mgr.verify_admin(
                bytes.fromhex(target_rid_hex),
                _unb64(admin_proof_b64),
                transcript_hash,
            ):
                raise RuntimeError("admin authorization failed")

        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}

        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            if not isinstance(rid_hex, str):
                raise RuntimeError("missing auth_registry_id")
            if not isinstance(admin_proof_b64, str) or not admin_proof_b64:
                raise RuntimeError("missing admin_proof_b64")
            # Authorize against a throwaway manager first: the live `mgr` is
            # only touched after the admin proof has been accepted.
            tmp_mgr = zkac.RegistryManager()
            expected_rid = tmp_mgr.create(state_bytes, state_cert).hex()
            if expected_rid != rid_hex:
                raise RuntimeError("auth_registry_id does not match certified state")
            if not tmp_mgr.verify_admin(
                bytes.fromhex(expected_rid),
                _unb64(admin_proof_b64),
                transcript_hash,
            ):
                raise RuntimeError("admin authorization failed for create_registry")
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}

        if action == "get_registry":
            rid_hex_cmd = cmd["registry_id"]
            _require_admin_for_registry(rid_hex_cmd)
            rid = bytes.fromhex(rid_hex_cmd)
            state_bytes, state_cert = mgr.get(rid)
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }

        if action == "update_registry":
            rid_hex_cmd = cmd["registry_id"]
            _require_admin_for_registry(rid_hex_cmd)
            rid = bytes.fromhex(rid_hex_cmd)
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Persist under the canonical hex of the parsed id, not the raw
            # client string. bytes.fromhex also accepts uppercase and
            # whitespace, which would otherwise create a second, stale file
            # pair for the same registry that gets reloaded on restart.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}

        return {"error": f"unknown command: {action}"}
    except Exception as exc:
        return {"error": str(exc)}


# ── Connection handler ────────────────────────────────────────────────


def _debug_update(debug: ServerDebugState | None, cid, **fields):
    """Forward *fields* to ``debug.update_connection`` when tracking this connection."""
    if debug and cid:
        debug.update_connection(cid, **fields)


def _run_mgmt_loop(
    framed: FramedSession,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    debug: ServerDebugState | None,
    cid,
):
    """Answer JSON management commands, one reply per frame, until the peer goes away."""
    conn_ctx: dict = {}
    _debug_update(debug, cid, phase="mgmt_loop")
    while True:
        try:
            data = framed.recv()
        except OSError:  # includes ConnectionError / timeouts: peer is gone
            break
        cmd = json.loads(data)
        if debug and cid:
            debug.note_mgmt_command(cid, cmd)
        resp = _dispatch(cmd, mgr, store, server_pk_b64, transcript_hash, conn_ctx)
        framed.send(json.dumps(resp).encode())


def _run_auth_session(
    framed: FramedSession,
    hello: dict,
    mgr: zkac.RegistryManager,
    transcript_hash: bytes,
    debug: ServerDebugState | None,
    cid,
):
    """Verify a role presentation from *hello*; on success echo frames back until EOF.

    Expects ``registry_id`` / ``role_id`` (hex) and ``bbs_auth_b64`` in *hello*.
    Missing or malformed fields raise and are handled by the caller.
    """
    registry_id = bytes.fromhex(hello["registry_id"])
    role_id = bytes.fromhex(hello["role_id"])
    proof_bytes = _unb64(hello["bbs_auth_b64"])
    _debug_update(
        debug,
        cid,
        phase="auth_verify",
        auth_registry_hex=registry_id.hex(),
        auth_role_hex=role_id.hex(),
    )
    ok = mgr.verify_presentation(
        registry_id,
        role_id,
        proof_bytes,
        transcript_hash,
    )
    if not ok:
        _debug_update(debug, cid, phase="auth_failed", auth_ok=False)
        framed.send(json.dumps({"error": "auth failed"}).encode())
        return

    _debug_update(debug, cid, phase="auth_ok", auth_ok=True)
    resp = {
        "status": "authenticated",
        "registry_id": registry_id.hex(),
        "role_id": role_id.hex(),
    }
    framed.send(json.dumps(resp).encode())

    _debug_update(debug, cid, phase="auth_echo_loop")
    while True:
        try:
            data = framed.recv()
        except OSError:
            break
        if debug and cid:
            debug.note_echo_chunk(cid, len(data))
        framed.send(data)


def _handle_conn(
    conn: socket.socket,
    addr: tuple,
    node: zkac.Node,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    idle_timeout_s: float,
    slots: threading.BoundedSemaphore,
    debug: ServerDebugState | None = None,
):
    """Serve one accepted connection: handshake, hello, then mgmt or auth session.

    Always closes *conn* and releases one permit of *slots* on exit, whatever
    happens inside. Socket-level errors end the connection quietly; other
    errors are logged with a traceback.
    """
    peer = f"{addr[0]}:{addr[1]}"
    cid = None
    err: str | None = None
    try:
        # Registered inside the try so a failing debug hook still reaches the
        # cleanup below (closing the socket and releasing the slot).
        cid = debug.open_connection(peer) if debug else None
        conn.settimeout(idle_timeout_s)
        _debug_update(debug, cid, phase="handshake")
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        transcript_hash = bytes(session.transcript_hash())
        _debug_update(
            debug,
            cid,
            phase="post_handshake",
            transcript_hash_hex=transcript_hash.hex(),
        )

        hello = json.loads(framed.recv())
        # A non-object hello is answered as an unknown op instead of crashing.
        op = hello.get("op") if isinstance(hello, dict) else None
        _debug_update(debug, cid, phase=f"hello:{op}", hello_op=op)

        if op == "mgmt":
            _run_mgmt_loop(framed, mgr, store, server_pk_b64, transcript_hash, debug, cid)
        elif op == "auth":
            _run_auth_session(framed, hello, mgr, transcript_hash, debug, cid)
        else:
            _debug_update(debug, cid, phase="unknown_op", error=f"op={op!r}")
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())
    except OSError:
        # ConnectionError, BrokenPipeError and timeouts are all OSError.
        pass
    except Exception as exc:
        err = str(exc)
        _debug_update(debug, cid, phase="error", error=err)
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        try:
            if debug and cid:
                debug.close_connection(cid, error=err)
        finally:
            try:
                conn.close()
            finally:
                slots.release()


# ── Public entry point ────────────────────────────────────────────────


def serve(
    data_dir: str,
    host: str = "127.0.0.1",
    port: int = 9800,
    max_connections: int = 64,
    idle_timeout_s: float = 45.0,
    listen_backlog: int = 64,
    *,
    debug: ServerDebugState | None = None,
    allow_non_loopback: bool = False,
):
    """Run the ZKAC server on ``host:port`` until interrupted (Ctrl-C).

    Args:
        data_dir: directory for the server key and registry snapshots
            (created if missing).
        host: bind address; IPv6 literals (containing ``:``) use an IPv6
            socket.
        port: TCP port to listen on.
        max_connections: maximum concurrently handled connections; extra
            connections are closed immediately.
        idle_timeout_s: per-connection socket timeout in seconds.
        listen_backlog: backlog passed to ``socket.listen``.
        debug: optional debug-state tracker fed with boot and per-connection
            info.
        allow_non_loopback: permit binding an address that is not loopback.

    Raises:
        RuntimeError: *host* is not loopback and *allow_non_loopback* is
            False. Checked before any file is created.
        OSError: the listening socket cannot be bound.
    """
    # Validate exposure first so a refused start has no on-disk side effects.
    if not _is_loopback_host(host):
        if not allow_non_loopback:
            raise RuntimeError(
                "refusing to bind outside loopback. "
                "Use --allow-non-loopback only when you intentionally expose this listener."
            )
        print(
            f"[warning] binding outside loopback: {host}:{port}. "
            "Ensure network exposure is intentional.",
            file=sys.stderr,
        )

    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)
    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    pk_bytes = bytes(kp.public_key().to_bytes())
    server_pk_b64 = _b64(pk_bytes)
    pk_hex = pk_bytes.hex()
    node = zkac.Node(kp)
    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)

    if debug is not None:
        debug.set_listen(host, port)
        debug.set_boot_info(server_pk_hex=pk_hex, registries_loaded=n)
    print(f"server public key: {pk_hex}")
    print(f"loaded {n} registries")

    # An IPv6 literal such as "::1" (accepted as loopback above) cannot be
    # bound on an AF_INET socket.
    family = socket.AF_INET6 if ":" in host else socket.AF_INET
    slots = threading.BoundedSemaphore(max_connections)
    with socket.socket(family, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(listen_backlog)
        print(f"listening on {host}:{port}")
        try:
            while True:
                conn, addr = sock.accept()
                if not slots.acquire(blocking=False):
                    conn.close()  # at capacity: shed the connection
                    continue
                worker = threading.Thread(
                    target=_handle_conn,
                    args=(conn, addr, node, mgr, store, server_pk_b64,
                          idle_timeout_s, slots, debug),
                    daemon=True,
                )
                try:
                    worker.start()
                except RuntimeError as exc:
                    # Thread creation failed (e.g. OS thread limit): return the
                    # slot and drop this connection instead of leaking both.
                    conn.close()
                    slots.release()
                    print(f"[server] cannot start handler for {addr}: {exc}", file=sys.stderr)
        except KeyboardInterrupt:
            print("\nshutdown")