"""TCP servers: ZKAC management, ZKAC managed app, localhost relay (credential queue).

:func:`serve` starts up to three listeners:

* management (``mgmt_port``): ZKAC handshake against a static registry that only
  admits the ``zkac.mgmt`` role; one framed JSON command per connection.
* managed (``managed_port``): ZKAC handshake against the live ``RegistryManager``;
  one framed JSON command per connection.
* relay (``relay_port``, optional): plaintext newline-terminated JSON used to
  enqueue issuance requests and poll for granted credentials.

Every reply is a JSON object with an ``"ok"`` flag; failed commands reply
``{"ok": False, "error": ...}``.
"""

from __future__ import annotations

import base64
import contextlib
import json
import socket
import threading
import traceback
from pathlib import Path
from typing import Callable

import zkac
from zkac.tcp import FramedSession, server_handshake, server_handshake_managed

from zkac_cli import issuance_util
from zkac_cli import registry_log

MGMT_ROLE = "zkac.mgmt"

# Upper bound (bytes) on one relay request line. The relay is plaintext and
# unauthenticated, so a peer must not be able to grow the read buffer forever
# by never sending a newline.
RELAY_MAX_LINE = 1 << 20


def build_static_registry(issuer_pk: zkac.BbsPublicKey) -> zkac.RoleRegistry:
    """Return a role registry containing only ``MGMT_ROLE``, keyed to *issuer_pk*."""
    reg = zkac.RoleRegistry()
    # NOTE(review): the trailing ``1`` is interpreted by
    # zkac.RoleRegistry.register_role; its meaning is not visible here.
    reg.register_role(zkac.role_id(MGMT_ROLE), issuer_pk, 1)
    return reg


def _b64e(data: bytes) -> str:
    """Base64-encode *data* into an ASCII ``str`` for JSON transport."""
    return base64.b64encode(data).decode()


def _error_reply(exc: BaseException) -> dict[str, object]:
    """Build the ``{"ok": False, "error": ...}`` reply for a failed command."""
    return {"ok": False, "error": f"{type(exc).__name__}: {exc}"}


def _mgmt_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> dict[str, object]:
    """Execute one management command and return its JSON reply body.

    Raises whatever the lookup/decoding or ``mgr`` calls raise (e.g. ``KeyError``
    for a missing field, ``ValueError`` for bad hex); the caller turns that into
    an error reply.
    """
    cmd = msg["cmd"]
    if cmd == "create_registry":
        st = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        rid = mgr.create(st, cert)
        registry_ids.add(rid)
        registry_log.append_event(events_path, "create", st, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}
    if cmd == "update_registry":
        st = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        # The target registry is identified by the id embedded in the new state.
        rid = zkac.RegistryState.deserialize(st).registry_id()
        mgr.update(rid, st, cert)
        registry_log.append_event(events_path, "update", st, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}
    if cmd == "get_registry":
        st, cert = mgr.get(bytes.fromhex(msg["registry_id_hex"]))
        return {"ok": True, "state_b64": _b64e(st), "state_cert_b64": _b64e(cert)}
    if cmd == "list_registries":
        return {"ok": True, "registry_ids_hex": [h.hex() for h in sorted(registry_ids)]}
    if cmd == "issuance_peek":
        rid = bytes.fromhex(msg["registry_id_hex"])
        pending = issuance_util.peek_pending_requests(mgr, rid)
        return {
            "ok": True,
            "pending": [
                {
                    "request_id_hex": request_id.hex(),
                    "role_id_hex": role_id.hex(),
                    "eph_pk_hex": eph_pk.hex(),
                    "payload_b64": _b64e(payload),
                }
                for request_id, role_id, eph_pk, payload in pending
            ],
        }
    if cmd == "issuance_grant":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        blind = base64.b64decode(msg["blind_sig_b64"])
        mgr.grant_credential(rid, req_id, blind)
        save_registry()
        return {"ok": True}
    return {"ok": False, "error": f"unknown cmd {cmd}"}


def handle_mgmt_session(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> None:
    """Serve one management connection: handshake, one command, one reply.

    Only peers whose handshake yields ``MGMT_ROLE`` get a reply; others are
    closed silently. A failing command is answered with an error reply.
    Handshake/transport failures are only logged. *conn* is always closed.
    """
    # Function-scope import kept from the original (presumably avoids an
    # import cycle within zkac_cli — confirm).
    from zkac_cli import creds

    try:
        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        issuer = zkac.BbsIssuer.from_secret_key(
            base64.b64decode(creds.load_json(data_dir / "mgmt_issuer.json")["issuer_secret_b64"])
        )
        static_reg = build_static_registry(issuer.public_key())
        session, role_id = server_handshake(conn, node, static_reg)
        if role_id != zkac.role_id(MGMT_ROLE):
            return
        framed = FramedSession(conn, session)
        raw = framed.recv()
        try:
            msg = json.loads(raw.decode("utf-8"))
            out = _mgmt_command(msg, mgr, registry_ids, events_path, save_registry)
        except Exception as exc:  # command failed: log it and tell the client
            traceback.print_exc()
            out = _error_reply(exc)
        framed.send(json.dumps(out).encode("utf-8"))
    except Exception:  # handshake / transport failure: nothing sensible to reply
        traceback.print_exc()
    finally:
        conn.close()


def _managed_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    registry_id: bytes,
    role_id: bytes,
) -> dict[str, object]:
    """Execute one managed-app command and return its JSON reply body.

    An absent ``cmd`` is treated as ``"ping"``. Any command other than
    ``get_registry`` or ``whoami`` returns the caller's identity plus a note.
    """
    cmd = msg.get("cmd", "ping")
    if cmd == "get_registry":
        # NOTE(review): the requested registry is not compared against the
        # caller's authenticated registry_id, so any managed client can read any
        # registry's state — confirm this is intended.
        st, cert = mgr.get(bytes.fromhex(msg["registry_id_hex"]))
        return {"ok": True, "state_b64": _b64e(st), "state_cert_b64": _b64e(cert)}
    identity: dict[str, object] = {
        "ok": True,
        "registry_id_hex": registry_id.hex(),
        "role_id_hex": role_id.hex(),
    }
    if cmd == "whoami":
        return identity
    return {**identity, "note": "authenticated managed session"}


def run_managed_handler(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
) -> None:
    """Serve one managed-app connection: handshake, one command, one reply.

    The handshake is checked against the live *mgr*. A failing command is
    answered with an error reply. Handshake/transport failures are only logged.
    *conn* is always closed.
    """
    from zkac_cli import creds

    try:
        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        session, registry_id, role_id = server_handshake_managed(conn, node, mgr)
        framed = FramedSession(conn, session)
        raw = framed.recv()
        try:
            msg = json.loads(raw.decode("utf-8"))
            body = _managed_command(msg, mgr, registry_id, role_id)
        except Exception as exc:  # command failed: log it and tell the client
            traceback.print_exc()
            body = _error_reply(exc)
        framed.send(json.dumps(body).encode("utf-8"))
    except Exception:  # handshake / transport failure: nothing sensible to reply
        traceback.print_exc()
    finally:
        conn.close()


def _recv_line(conn: socket.socket, max_line: int) -> bytes | None:
    """Read from *conn* up to the first newline and return the bytes before it.

    Bytes after the newline are discarded. Returns ``None`` if the peer closes
    before sending a newline.

    Raises:
        ValueError: more than *max_line* bytes arrived without a newline.
    """
    buf = bytearray()
    while True:
        chunk = conn.recv(4096)
        if not chunk:
            return None
        # Scan only the new chunk, so total work stays linear in the line length.
        newline = chunk.find(b"\n")
        buf += chunk if newline < 0 else chunk[:newline]
        if len(buf) > max_line:
            raise ValueError(f"relay request line exceeds {max_line} bytes")
        if newline >= 0:
            return bytes(buf)


def _relay_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
) -> dict[str, object]:
    """Execute one relay command (``enqueue`` or ``poll``) and return its reply body."""
    cmd = msg["cmd"]
    if cmd == "enqueue":
        rid = bytes.fromhex(msg["registry_id_hex"])
        role = zkac.role_id(msg["role_name"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        eph = bytes.fromhex(msg["eph_pk_hex"])
        blob = base64.b64decode(msg["payload_b64"])
        mgr.queue_issuance_request(rid, req_id, role, eph, blob)
        save_registry()
        return {"ok": True, "status": "queued"}
    if cmd == "poll":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        # "take" removes the grant from the manager, so save afterwards.
        granted = mgr.take_granted_credential(rid, req_id)
        save_registry()
        if granted is None:
            return {"ok": True, "status": "pending"}
        return {"ok": True, "status": "ready", "blind_sig_b64": _b64e(granted)}
    return {"ok": False, "error": "unknown relay cmd"}


def handle_relay_session(
    conn: socket.socket,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
    *,
    max_line: int = RELAY_MAX_LINE,
) -> None:
    """Serve one relay connection: read a JSON line, write a JSON line reply.

    If the peer closes before sending a full line, no reply is sent. If the
    line exceeds *max_line* bytes, the failure is logged and no reply is sent.
    A failing command is answered with an error reply. *conn* is always closed.
    """
    try:
        line = _recv_line(conn, max_line)
        if line is None:
            return
        try:
            out = _relay_command(json.loads(line.decode("utf-8")), mgr, save_registry)
        except Exception as exc:  # command failed: log it and tell the client
            traceback.print_exc()
            out = _error_reply(exc)
        conn.sendall((json.dumps(out) + "\n").encode("utf-8"))
    except Exception:  # transport failure or oversized request
        traceback.print_exc()
    finally:
        conn.close()


def _listen(host: str, port: int) -> socket.socket:
    """Return a TCP socket bound to (*host*, *port*) and listening (backlog 8).

    The socket is closed again if binding or listening fails.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(8)
    except OSError:
        sock.close()
        raise
    return sock


def _accept_forever(
    listener: socket.socket,
    tag: str,
    handler: Callable[..., None],
    handler_args: tuple,
) -> None:
    """Accept connections on *listener* forever, serving each on a daemon thread.

    Each accepted socket is passed to *handler* as its first argument, followed
    by *handler_args*.
    """
    while True:
        conn, addr = listener.accept()
        print(f"[{tag}] connect {addr}")
        threading.Thread(target=handler, args=(conn, *handler_args), daemon=True).start()


def serve(
    data_dir: Path,
    mgmt_port: int,
    managed_port: int,
    relay_port: int | None,
    relay_bind: str,
) -> None:
    """Run the management, managed and (optional) relay servers; never returns.

    Registry state is rebuilt from ``data_dir/registry_events.json`` when that
    file exists; otherwise an empty ``zkac.RegistryManager`` is used.
    Management and managed listeners bind ``0.0.0.0``. The relay (plaintext)
    binds *relay_bind*, and only starts when *relay_port* is not ``None``.

    Raises:
        OSError: a listener could not be bound. All sockets are bound before
            any server thread starts, so a bad port fails fast instead of
            silently killing one server thread.
    """
    events_path = data_dir / "registry_events.json"
    registry_ids: set[bytes] = set()
    if events_path.is_file():
        # Read the log once; list() makes it safe to iterate twice even if
        # load_events returns a one-shot iterator.
        events = list(registry_log.load_events(events_path))
        mgr = registry_log.replay_manager(events)
        for event in events:
            st = base64.b64decode(event["state_b64"])
            registry_ids.add(zkac.RegistryState.deserialize(st).registry_id())
    else:
        mgr = zkac.RegistryManager()

    def save_registry() -> None:
        # Persistence hook invoked after mutating commands.
        # NOTE(review): currently a no-op. Registry create/update are recorded
        # via registry_log.append_event, but issuance queue/grant changes are
        # not, so they are presumably lost on restart — confirm this is intended.
        return

    with contextlib.ExitStack() as stack:
        mgmt_sock = stack.enter_context(_listen("0.0.0.0", mgmt_port))
        print(f"[mgmt] ZKAC listening on 0.0.0.0:{mgmt_port}")
        managed_sock = stack.enter_context(_listen("0.0.0.0", managed_port))
        print(f"[managed] ZKAC listening on 0.0.0.0:{managed_port}")
        relay_sock = None
        if relay_port is not None:
            relay_sock = stack.enter_context(_listen(relay_bind, relay_port))
            print(f"[relay] plaintext JSON lines on {relay_bind}:{relay_port}")

        loops: list[tuple[socket.socket, str, Callable[..., None], tuple]] = [
            (
                mgmt_sock,
                "mgmt",
                handle_mgmt_session,
                (data_dir, mgr, registry_ids, events_path, save_registry),
            ),
            (managed_sock, "managed", run_managed_handler, (data_dir, mgr)),
        ]
        if relay_sock is not None:
            loops.append((relay_sock, "relay", handle_relay_session, (mgr, save_registry)))

        for loop_args in loops:
            threading.Thread(target=_accept_forever, args=loop_args, daemon=True).start()

        # Block the calling thread forever; the listeners are closed by the
        # ExitStack if this is interrupted (e.g. KeyboardInterrupt).
        threading.Event().wait()