ZKAC/cli/zkac_cli/server.py
2026-05-06 17:39:11 +02:00

347 lines
12 KiB
Python

"""ZKAC server for registry management and role authentication."""
from __future__ import annotations

import base64
import ipaddress
import json
import os
import socket
import sys
import threading
import traceback
from pathlib import Path

import zkac
from zkac.tcp import FramedSession, server_handshake_anon

from .server_debug import ServerDebugState
def _b64(data: bytes) -> str:
    """Return *data* encoded as standard-alphabet base64 text (``+`` and ``/``)."""
    return base64.standard_b64encode(data).decode("ascii")
def _unb64(s: str) -> bytes:
    """Decode standard-alphabet base64 text *s* into raw bytes.

    Decoding is non-strict (``validate`` keeps its default of False), so
    characters outside the base64 alphabet are discarded rather than rejected.
    """
    decoded = base64.standard_b64decode(s)
    return decoded
def _chmod_if_possible(path: Path, mode: int):
    """Best-effort ``os.chmod(path, mode)``.

    Any ``OSError`` (missing file, permission denied, a filesystem that does
    not support it, ...) is ignored, so callers can tighten permissions
    without failing where that is not possible.
    """
    try:
        os.chmod(path, mode)
    except OSError:
        return None
def _write_private_json(path: Path, payload: dict):
    """Write *payload* as indented JSON to *path*, readable by the owner only.

    A new file is created with mode ``0o600`` via ``os.open``. Secret material
    (e.g. the server's secret key) is therefore never briefly exposed with the
    umask-derived default permissions that ``Path.write_text`` would use. An
    existing file is truncated and tightened to ``0o600`` (best effort) before
    the new contents are written.
    """
    # Serialize first so a non-JSON payload fails before any existing file is truncated.
    text = json.dumps(payload, indent=2)
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        # The mode given to O_CREAT applies only to newly created files; an
        # existing file keeps its old permissions unless tightened explicitly.
        _chmod_if_possible(path, 0o600)
        fh.write(text)
def _is_loopback_host(host: str) -> bool:
    """Return True if *host* names the local loopback interface.

    Accepts ``localhost`` (case-insensitive, surrounding whitespace ignored)
    and any IP literal that :mod:`ipaddress` classifies as loopback: all of
    ``127.0.0.0/8`` and ``::1``. Wildcards (``""``, ``0.0.0.0``, ``::``), other
    addresses and other hostnames return False, so callers treat them as
    externally exposed.
    """
    value = host.strip().lower()
    if value == "localhost":
        return True
    try:
        return ipaddress.ip_address(value).is_loopback
    except ValueError:
        # Not an IP literal, e.g. some DNS name other than "localhost".
        return False
# ── Opaque server storage ─────────────────────────────────────────────
class _ServerStore:
    """Opaque on-disk persistence for the server key and registry snapshots.

    Layout under *data_dir*:

    * ``server_key.json``: the server keypair, base64-encoded.
    * ``registries/<rid_hex>.state`` and ``registries/<rid_hex>.cert``: the raw
      state bytes and certificate of each registry, stored as-is.

    Registry writes are serialized by an internal lock.
    """

    def __init__(self, data_dir: Path):
        self._dir = data_dir
        self._reg_dir = data_dir / "registries"
        self._reg_dir.mkdir(parents=True, exist_ok=True)
        # Owner-only directories where the platform supports it.
        _chmod_if_possible(self._dir, 0o700)
        _chmod_if_possible(self._reg_dir, 0o700)
        self._lock = threading.Lock()

    # ── server key ────────────────────────────────────────────────────
    def load_or_create_keypair(self) -> zkac.Keypair:
        """Return the keypair stored in ``server_key.json``, creating and saving one if absent."""
        key_file = self._dir / "server_key.json"
        if key_file.exists():
            data = json.loads(key_file.read_text())
            return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
        kp = zkac.Keypair()
        _write_private_json(key_file, {
            "secret_b64": _b64(kp.secret_key_bytes()),
            "public_b64": _b64(kp.public_key().to_bytes()),
        })
        return kp

    # ── registries ────────────────────────────────────────────────────
    @staticmethod
    def _write_tmp(path: Path, data: bytes) -> Path:
        """Write *data* to ``<path>.tmp``, flushed to disk, and return that temp path."""
        tmp = path.with_name(path.name + ".tmp")
        with open(tmp, "wb") as fh:
            fh.write(data)
            fh.flush()
            os.fsync(fh.fileno())
        return tmp

    def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes):
        """Persist one registry snapshot (state + certificate).

        *rid_hex* is canonicalized to lowercase hex without whitespace. Without
        that, ``"ABCD"`` and ``"abcd"`` would create separate file pairs, and
        ``load_all_registries`` would restore both at startup.

        Raises:
            ValueError: *rid_hex* is empty or not valid hex. This also keeps
                path separators out of the file name.
        """
        canonical = bytes.fromhex(rid_hex).hex()
        if not canonical:
            raise ValueError("empty registry id")
        state_path = self._reg_dir / f"{canonical}.state"
        cert_path = self._reg_dir / f"{canonical}.cert"
        with self._lock:
            # Write both temp files completely before replacing anything, so a
            # crash never leaves a truncated .state/.cert in place. Leftover
            # *.tmp files are not matched by the "*.state" glob used on load.
            state_tmp = self._write_tmp(state_path, state_bytes)
            cert_tmp = self._write_tmp(cert_path, cert_bytes)
            os.replace(state_tmp, state_path)
            os.replace(cert_tmp, cert_path)

    def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
        """Restore every stored registry into *mgr*; return how many succeeded.

        States without a matching ``.cert`` file are skipped. Restores that
        raise are reported on stdout and skipped.
        """
        count = 0
        for state_path in sorted(self._reg_dir.glob("*.state")):
            rid_hex = state_path.stem
            cert_path = self._reg_dir / f"{rid_hex}.cert"
            if not cert_path.exists():
                continue
            try:
                mgr.restore(state_path.read_bytes(), cert_path.read_bytes())
                count += 1
            except Exception as exc:
                print(f"[server] skip registry {rid_hex}: {exc}")
        return count
# ── Command dispatch (inside encrypted session) ──────────────────────
def _dispatch(
    cmd: dict,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    conn_ctx: dict,
) -> dict:
    """Execute one management command and return a JSON-serializable reply.

    Supported ``cmd["cmd"]`` values:

    * ``server_info``: return the server public key (no authorization).
    * ``create_registry``: add a new certified state. The caller must supply
      an admin proof for the registry id that the state defines.
    * ``get_registry`` / ``update_registry``: read or replace a registry.
      ``auth_registry_id`` must equal ``registry_id``, and the admin proof
      must verify for it.

    Admin proofs are verified together with *transcript_hash*, the handshake
    transcript of the current session. *conn_ctx* is accepted but currently
    unused.

    Every failure, including malformed input, is returned as
    ``{"error": <message>}`` rather than raised, so one bad command does not end
    the session.
    """
    try:
        if not isinstance(cmd, dict):
            raise RuntimeError("command must be a JSON object")
        action = cmd.get("cmd")
        auth_rid_hex = cmd.get("auth_registry_id")
        admin_proof_b64 = cmd.get("admin_proof_b64")

        def _require_admin_for_registry(target_rid_hex: str) -> None:
            # The proof must be for exactly the registry this command touches.
            if auth_rid_hex != target_rid_hex:
                raise RuntimeError("auth_registry_id must match command registry_id")
            if not isinstance(admin_proof_b64, str) or not admin_proof_b64:
                raise RuntimeError("missing admin_proof_b64")
            if not mgr.verify_admin(
                bytes.fromhex(target_rid_hex),
                _unb64(admin_proof_b64),
                transcript_hash,
            ):
                raise RuntimeError("admin authorization failed")

        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}

        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            if not isinstance(auth_rid_hex, str):
                raise RuntimeError("missing auth_registry_id")
            if not isinstance(admin_proof_b64, str) or not admin_proof_b64:
                raise RuntimeError("missing admin_proof_b64")
            # Check the certified state and the admin proof in a throwaway
            # manager first, so the live manager only changes once authorized.
            tmp_mgr = zkac.RegistryManager()
            expected_rid = tmp_mgr.create(state_bytes, state_cert).hex()
            if expected_rid != auth_rid_hex:
                raise RuntimeError("auth_registry_id does not match certified state")
            if not tmp_mgr.verify_admin(
                bytes.fromhex(expected_rid),
                _unb64(admin_proof_b64),
                transcript_hash,
            ):
                raise RuntimeError("admin authorization failed for create_registry")
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}

        if action == "get_registry":
            rid_hex_cmd = cmd["registry_id"]
            _require_admin_for_registry(rid_hex_cmd)
            state_bytes, state_cert = mgr.get(bytes.fromhex(rid_hex_cmd))
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }

        if action == "update_registry":
            rid_hex_cmd = cmd["registry_id"]
            _require_admin_for_registry(rid_hex_cmd)
            rid = bytes.fromhex(rid_hex_cmd)
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Persist under the canonical id. bytes.fromhex accepts uppercase
            # and whitespace, so the raw client string could name a second
            # file pair that would also be restored at startup.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}

        return {"error": f"unknown command: {action}"}
    except Exception as exc:
        return {"error": str(exc)}
# ── Connection handler ────────────────────────────────────────────────
def _serve_mgmt(
    framed: FramedSession,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    debug: ServerDebugState | None,
    cid,
) -> None:
    """Answer JSON management commands on *framed* until the peer disconnects."""
    conn_ctx: dict = {}
    if debug and cid:
        debug.update_connection(cid, phase="mgmt_loop")
    while True:
        try:
            data = framed.recv()
        except OSError:
            # ConnectionError and socket timeouts are OSError subclasses: session over.
            break
        try:
            cmd = json.loads(data)
        except ValueError:
            # A malformed frame gets an error reply instead of ending the session.
            framed.send(json.dumps({"error": "malformed command"}).encode())
            continue
        if debug and cid:
            debug.note_mgmt_command(cid, cmd)
        resp = _dispatch(cmd, mgr, store, server_pk_b64, transcript_hash, conn_ctx)
        framed.send(json.dumps(resp).encode())


def _serve_auth(
    framed: FramedSession,
    hello: dict,
    mgr: zkac.RegistryManager,
    transcript_hash: bytes,
    debug: ServerDebugState | None,
    cid,
) -> None:
    """Verify the role presentation in *hello*; if valid, echo frames until disconnect."""
    registry_id = bytes.fromhex(hello["registry_id"])
    role_id = bytes.fromhex(hello["role_id"])
    proof_bytes = _unb64(hello["bbs_auth_b64"])
    if debug and cid:
        debug.update_connection(
            cid,
            phase="auth_verify",
            auth_registry_hex=registry_id.hex(),
            auth_role_hex=role_id.hex(),
        )
    ok = mgr.verify_presentation(
        registry_id, role_id, proof_bytes, transcript_hash,
    )
    if not ok:
        if debug and cid:
            debug.update_connection(cid, phase="auth_failed", auth_ok=False)
        framed.send(json.dumps({"error": "auth failed"}).encode())
        return
    if debug and cid:
        debug.update_connection(cid, phase="auth_ok", auth_ok=True)
    resp = {
        "status": "authenticated",
        "registry_id": registry_id.hex(),
        "role_id": role_id.hex(),
    }
    framed.send(json.dumps(resp).encode())
    if debug and cid:
        debug.update_connection(cid, phase="auth_echo_loop")
    while True:
        try:
            data = framed.recv()
        except OSError:
            break
        if debug and cid:
            debug.note_echo_chunk(cid, len(data))
        framed.send(data)


def _handle_conn(
    conn: socket.socket,
    addr: tuple,
    node: zkac.Node,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    idle_timeout_s: float,
    slots: threading.BoundedSemaphore,
    debug: ServerDebugState | None = None,
):
    """Serve one accepted client connection, then close it and release its slot.

    Runs the anonymous handshake and reads a JSON "hello" frame. Depending on
    its ``op``, it then serves management commands (``"mgmt"``), verifies a
    role presentation and echoes frames (``"auth"``), or replies with an error.
    *idle_timeout_s* is applied with ``conn.settimeout``. OSErrors (disconnects,
    timeouts) end the connection quietly; other exceptions are logged with a
    traceback.

    *slots* must already have been acquired by the caller; it is released
    here exactly once, whatever happens.
    """
    peer = f"{addr[0]}:{addr[1]}"
    cid = debug.open_connection(peer) if debug else None
    err: str | None = None
    try:
        conn.settimeout(idle_timeout_s)
        if debug and cid:
            debug.update_connection(cid, phase="handshake")
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        transcript_hash = bytes(session.transcript_hash())
        if debug and cid:
            debug.update_connection(
                cid,
                phase="post_handshake",
                transcript_hash_hex=transcript_hash.hex(),
            )
        hello = json.loads(framed.recv())
        # A hello that is not a JSON object is answered like an unknown op.
        op = hello.get("op") if isinstance(hello, dict) else None
        if debug and cid:
            debug.update_connection(cid, phase=f"hello:{op}", hello_op=op)
        if op == "mgmt":
            _serve_mgmt(framed, mgr, store, server_pk_b64, transcript_hash, debug, cid)
        elif op == "auth":
            _serve_auth(framed, hello, mgr, transcript_hash, debug, cid)
        else:
            if debug and cid:
                debug.update_connection(cid, phase="unknown_op", error=f"op={op!r}")
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())
    except OSError:
        # ConnectionError, BrokenPipeError and socket timeouts are all OSError:
        # treat them as an ordinary disconnect.
        pass
    except Exception as exc:
        err = str(exc)
        if debug and cid:
            debug.update_connection(cid, phase="error", error=err)
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        # Nested so the socket is closed and the slot returned even if the
        # debug hook (or close itself) raises. A leaked slot would permanently
        # reduce how many clients serve() will accept.
        try:
            if debug and cid:
                debug.close_connection(cid, error=err)
        finally:
            try:
                conn.close()
            finally:
                slots.release()
# ── Public entry point ────────────────────────────────────────────────
def serve(
    data_dir: str,
    host: str = "127.0.0.1",
    port: int = 9800,
    max_connections: int = 64,
    idle_timeout_s: float = 45.0,
    listen_backlog: int = 64,
    *,
    debug: ServerDebugState | None = None,
    allow_non_loopback: bool = False,
):
    """Run the ZKAC server until interrupted with Ctrl-C.

    Loads (or creates) the server keypair and persisted registries from
    *data_dir*, then accepts TCP connections on *host*:*port*. Each connection
    is handled in its own daemon thread. At most *max_connections* are served
    at once; further clients are closed immediately. *idle_timeout_s* is the
    socket timeout applied to each connection, and *listen_backlog* is passed
    to ``listen()``.

    Raises:
        RuntimeError: *host* is not a loopback address and
            *allow_non_loopback* is False. This is checked before *data_dir*
            is touched.
        OSError: the listening socket could not be created or bound.
    """
    # Decide on network exposure first, so a refused bind leaves no data dir,
    # key file or misleading "listening" line behind.
    if not _is_loopback_host(host):
        if not allow_non_loopback:
            raise RuntimeError(
                "refusing to bind outside loopback. "
                "Use --allow-non-loopback only when you intentionally expose this listener."
            )
        print(
            f"[warning] binding outside loopback: {host}:{port}. "
            "Ensure network exposure is intentional.",
            file=sys.stderr,
        )

    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)
    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    server_pk_b64 = _b64(kp.public_key().to_bytes())
    pk_hex = _unb64(server_pk_b64).hex()
    node = zkac.Node(kp)
    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)
    if debug is not None:
        debug.set_listen(host, port)
        debug.set_boot_info(server_pk_hex=pk_hex, registries_loaded=n)
    print(f"server public key: {pk_hex}")
    print(f"loaded {n} registries")

    # An AF_INET socket cannot bind IPv6 literals such as "::1", which the
    # loopback check accepts, so pick the family from the address syntax.
    family = socket.AF_INET6 if ":" in host else socket.AF_INET
    slots = threading.BoundedSemaphore(max_connections)
    # The with-statement closes the listener even if bind/listen fails.
    with socket.socket(family, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(listen_backlog)
        print(f"listening on {host}:{port}")
        try:
            while True:
                try:
                    conn, addr = sock.accept()
                except ConnectionAbortedError:
                    # The peer reset before accept() completed; keep serving others.
                    continue
                # Non-blocking: when every slot is taken, shed the new client
                # immediately instead of stalling the accept loop.
                if not slots.acquire(blocking=False):
                    conn.close()
                    continue
                worker = threading.Thread(
                    target=_handle_conn,
                    args=(conn, addr, node, mgr, store, server_pk_b64, idle_timeout_s, slots, debug),
                    daemon=True,
                )
                try:
                    worker.start()
                except RuntimeError:
                    # The handler never ran, so it cannot close conn or free the slot.
                    conn.close()
                    slots.release()
        except KeyboardInterrupt:
            print("\nshutdown")