ZKAC/cli/zkac_cli/server_app.py
2026-04-16 01:29:59 +02:00

273 lines
9.2 KiB
Python

"""TCP servers: ZKAC management, ZKAC managed app, localhost relay (credential queue)."""
from __future__ import annotations
import base64
import json
import socket
import threading
import traceback
from pathlib import Path
from typing import Callable
import zkac
from zkac.tcp import FramedSession, server_handshake, server_handshake_managed
from zkac_cli import issuance_util
from zkac_cli import registry_log
# Role name that clients of the management port are checked against after the handshake.
MGMT_ROLE = "zkac.mgmt"


def build_static_registry(issuer_pk: zkac.BbsPublicKey) -> zkac.RoleRegistry:
    """Return a new RoleRegistry containing only the management role.

    The role id derived from ``MGMT_ROLE`` is registered against *issuer_pk*.

    Args:
        issuer_pk: BBS public key of the management credential issuer.

    Returns:
        A freshly created ``zkac.RoleRegistry`` with that single role registered.
    """
    registry = zkac.RoleRegistry()
    mgmt_role_id = zkac.role_id(MGMT_ROLE)
    # NOTE(review): the meaning of the trailing ``1`` is not visible from here — confirm.
    registry.register_role(mgmt_role_id, issuer_pk, 1)
    return registry
def _dispatch_mgmt_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> dict:
    """Execute one decoded management request and return its JSON-serialisable reply.

    Raises whatever the field lookups, hex/base64 decoders, registry manager or event
    log raise; the caller converts such failures into an error reply.
    """
    cmd = msg["cmd"]
    if cmd == "create_registry":
        state = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        rid = mgr.create(state, cert)
        registry_ids.add(rid)
        registry_log.append_event(events_path, "create", state, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}
    if cmd == "update_registry":
        state = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        # The registry being updated is identified by the id embedded in the new state.
        rid = zkac.RegistryState.deserialize(state).registry_id()
        mgr.update(rid, state, cert)
        registry_log.append_event(events_path, "update", state, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}
    if cmd == "get_registry":
        rid = bytes.fromhex(msg["registry_id_hex"])
        state, cert = mgr.get(rid)
        return {
            "ok": True,
            "state_b64": base64.b64encode(state).decode(),
            "state_cert_b64": base64.b64encode(cert).decode(),
        }
    if cmd == "list_registries":
        return {"ok": True, "registry_ids_hex": [r.hex() for r in sorted(registry_ids)]}
    if cmd == "issuance_peek":
        rid = bytes.fromhex(msg["registry_id_hex"])
        pending = issuance_util.peek_pending_requests(mgr, rid)
        return {
            "ok": True,
            "pending": [
                {
                    "request_id_hex": request_id.hex(),
                    "role_id_hex": role_id.hex(),
                    "eph_pk_hex": eph_pk.hex(),
                    "payload_b64": base64.b64encode(payload).decode(),
                }
                for request_id, role_id, eph_pk, payload in pending
            ],
        }
    if cmd == "issuance_grant":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        blind = base64.b64decode(msg["blind_sig_b64"])
        mgr.grant_credential(rid, req_id, blind)
        save_registry()
        return {"ok": True}
    return {"ok": False, "error": f"unknown cmd {cmd}"}


def handle_mgmt_session(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> None:
    """Serve a single request/response exchange on the management port, then close *conn*.

    Flow:
      1. Load the transport keypair (``transport.json``) and the management issuer
         secret (``mgmt_issuer.json``) from *data_dir*. This happens on every connection.
      2. Run the ZKAC handshake against a registry that knows only ``MGMT_ROLE``.
         A peer that authenticates with any other role is dropped without a reply.
      3. Read one JSON command and send exactly one JSON reply.

    Supported commands are create_registry, update_registry, get_registry,
    list_registries, issuance_peek and issuance_grant. If handling the command fails,
    the client receives ``{"ok": False, "error": "<ExcType>: <message>"}`` instead
    of an abrupt close. Exceptions are printed and never propagate to the caller.

    Args:
        conn: Accepted client socket; always closed on return.
        data_dir: Directory holding the key files.
        mgr: Shared registry manager.
        registry_ids: Shared set of known registry ids; ``create_registry`` adds to it.
        events_path: Append-only registry event log written on create/update.
        save_registry: Persistence hook invoked after state-changing commands.
    """
    from zkac_cli import creds

    try:
        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        issuer = zkac.BbsIssuer.from_secret_key(
            base64.b64decode(creds.load_json(data_dir / "mgmt_issuer.json")["issuer_secret_b64"])
        )
        static_reg = build_static_registry(issuer.public_key())
        session, role_id = server_handshake(conn, node, static_reg)
        if role_id != zkac.role_id(MGMT_ROLE):
            return  # authenticated, but not as a management client
        framed = FramedSession(conn, session)
        raw = framed.recv()
        try:
            msg = json.loads(raw.decode("utf-8"))
            reply = _dispatch_mgmt_command(msg, mgr, registry_ids, events_path, save_registry)
        except Exception as exc:
            # Tell the (authenticated) client what went wrong instead of just hanging up.
            traceback.print_exc()
            reply = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        framed.send(json.dumps(reply).encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
def _dispatch_managed_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    registry_id: bytes,
    role_id: bytes,
) -> dict:
    """Execute one decoded managed-app request and return its JSON-serialisable reply.

    A missing ``cmd`` is treated as ``"ping"``. Any command other than get_registry
    or whoami gets the identity reply plus a ``note`` field.
    """
    cmd = msg.get("cmd", "ping")
    if cmd == "get_registry":
        # NOTE(review): any authenticated managed client may fetch *any* registry id,
        # not only the one it authenticated against — confirm this is intended.
        rid = bytes.fromhex(msg["registry_id_hex"])
        state, cert = mgr.get(rid)
        return {
            "ok": True,
            "state_b64": base64.b64encode(state).decode(),
            "state_cert_b64": base64.b64encode(cert).decode(),
        }
    identity = {
        "ok": True,
        "registry_id_hex": registry_id.hex(),
        "role_id_hex": role_id.hex(),
    }
    if cmd == "whoami":
        return identity
    return {**identity, "note": "authenticated managed session"}


def run_managed_handler(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
) -> None:
    """Serve a single request/response exchange on the managed-app port, then close *conn*.

    Loads the transport keypair from ``data_dir / "transport.json"`` and runs the
    managed handshake against *mgr*, which yields the peer's registry id and role id.
    It then reads one JSON command and sends exactly one JSON reply. Supported commands
    are get_registry and whoami; anything else gets a generic "authenticated" reply.
    If handling the command fails, the client receives
    ``{"ok": False, "error": "<ExcType>: <message>"}`` instead of an abrupt close.
    Exceptions are printed and never propagate to the caller.
    """
    from zkac_cli import creds

    try:
        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        session, registry_id, role_id = server_handshake_managed(conn, node, mgr)
        framed = FramedSession(conn, session)
        raw = framed.recv()
        try:
            msg = json.loads(raw.decode("utf-8"))
            body = _dispatch_managed_command(msg, mgr, registry_id, role_id)
        except Exception as exc:
            traceback.print_exc()
            body = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        framed.send(json.dumps(body).encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
def _read_relay_line(conn: socket.socket, max_line_bytes: int) -> bytes | None:
    """Read from *conn* up to the first newline and return the bytes before it.

    Returns None if the peer closes before sending a newline. Stops reading as soon
    as more than *max_line_bytes* have accumulated, and returns that over-long prefix
    so the caller can reject it. Any bytes after the first newline are discarded.
    """
    buf = bytearray()
    while True:
        chunk = conn.recv(4096)
        if not chunk:
            return None
        # Only the new chunk is scanned for a newline; the whole buffer is never rescanned.
        head, sep, _ = chunk.partition(b"\n")
        buf += head
        if sep or len(buf) > max_line_bytes:
            return bytes(buf)


def _dispatch_relay_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
) -> dict:
    """Execute one decoded relay request (enqueue / poll) and return its reply object."""
    cmd = msg["cmd"]
    if cmd == "enqueue":
        rid = bytes.fromhex(msg["registry_id_hex"])
        role = zkac.role_id(msg["role_name"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        eph = bytes.fromhex(msg["eph_pk_hex"])
        blob = base64.b64decode(msg["payload_b64"])
        mgr.queue_issuance_request(rid, req_id, role, eph, blob)
        save_registry()
        return {"ok": True, "status": "queued"}
    if cmd == "poll":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        granted = mgr.take_granted_credential(rid, req_id)
        save_registry()
        if granted is None:
            return {"ok": True, "status": "pending"}
        return {
            "ok": True,
            "status": "ready",
            "blind_sig_b64": base64.b64encode(granted).decode(),
        }
    return {"ok": False, "error": "unknown relay cmd"}


def handle_relay_session(
    conn: socket.socket,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
    max_line_bytes: int = 1 << 20,
) -> None:
    """Serve one newline-terminated JSON request on the plaintext relay port, then close.

    The relay has no handshake. Requests come from an unauthenticated peer, so the
    request line is capped at *max_line_bytes* (default 1 MiB) to avoid unbounded
    buffering. Commands:
      * ``enqueue``: queue an issuance request on *mgr*; reply ``status: queued``.
      * ``poll``: take a granted credential, if present; reply ``status: pending``
        or ``status: ready`` with ``blind_sig_b64``.

    Exactly one JSON line is sent back, unless the peer closes before completing
    its line. Oversized or failing requests get ``{"ok": False, "error": ...}``.
    Exceptions are printed and never propagate to the caller.
    """
    try:
        line = _read_relay_line(conn, max_line_bytes)
        if line is None:
            return  # peer closed before completing a request; nothing to answer
        if len(line) > max_line_bytes:
            out = {"ok": False, "error": f"relay request exceeds {max_line_bytes} bytes"}
        else:
            try:
                out = _dispatch_relay_command(json.loads(line.decode("utf-8")), mgr, save_registry)
            except Exception as exc:
                traceback.print_exc()
                out = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        conn.sendall((json.dumps(out) + "\n").encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
def _bind_listener(host: str, port: int) -> socket.socket:
    """Return a listening TCP socket bound to (host, port); closes it if setup fails."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(8)
    except BaseException:
        sock.close()
        raise
    return sock


def _accept_forever(
    sock: socket.socket,
    tag: str,
    handler: Callable[..., None],
    handler_args: tuple,
) -> None:
    """Accept connections on *sock* forever, serving each on its own daemon thread.

    Each connection is handled by ``handler(conn, *handler_args)``.
    """
    while True:
        conn, addr = sock.accept()
        print(f"[{tag}] connect {addr}")
        threading.Thread(target=handler, args=(conn, *handler_args), daemon=True).start()


def serve(
    data_dir: Path,
    mgmt_port: int,
    managed_port: int,
    relay_port: int | None,
    relay_bind: str,
) -> None:
    """Run the management, managed-app and optional relay listeners; never returns.

    If ``data_dir / "registry_events.json"`` exists, registry state is rebuilt by
    replaying it; otherwise an empty ``zkac.RegistryManager`` is used. All three
    listeners are bound before any background thread starts, so a busy port or a
    bad bind address raises from ``serve()`` instead of silently killing a worker
    thread. The management and managed ports bind on 0.0.0.0. The relay binds on
    *relay_bind* and is skipped when *relay_port* is None.
    """
    events_path = data_dir / "registry_events.json"
    registry_ids: set[bytes] = set()
    if events_path.is_file():
        # Load once and materialise: the events are both replayed and scanned for ids.
        events = list(registry_log.load_events(events_path))
        mgr = registry_log.replay_manager(events)
        for event in events:
            state = base64.b64decode(event["state_b64"])
            registry_ids.add(zkac.RegistryState.deserialize(state).registry_id())
    else:
        mgr = zkac.RegistryManager()

    def save_registry() -> None:
        # NOTE(review): deliberately a no-op. Only create/update events reach disk (via
        # the mgmt handler's event log). Queued issuance requests and grants appear to
        # be in-memory only — confirm this is intended.
        return

    bound: list[socket.socket] = []

    def bind(host: str, port: int) -> socket.socket:
        sock = _bind_listener(host, port)
        bound.append(sock)
        return sock

    try:
        mgmt_sock = bind("0.0.0.0", mgmt_port)
        managed_sock = bind("0.0.0.0", managed_port)
        relay_sock = bind(relay_bind, relay_port) if relay_port is not None else None
    except BaseException:
        for sock in bound:
            sock.close()
        raise

    workers = [
        (
            mgmt_sock,
            "mgmt",
            f"[mgmt] ZKAC listening on 0.0.0.0:{mgmt_port}",
            handle_mgmt_session,
            (data_dir, mgr, registry_ids, events_path, save_registry),
        ),
        (
            managed_sock,
            "managed",
            f"[managed] ZKAC listening on 0.0.0.0:{managed_port}",
            run_managed_handler,
            (data_dir, mgr),
        ),
    ]
    if relay_sock is not None:
        workers.append(
            (
                relay_sock,
                "relay",
                f"[relay] plaintext JSON lines on {relay_bind}:{relay_port}",
                handle_relay_session,
                (mgr, save_registry),
            )
        )

    for sock, tag, banner, handler, handler_args in workers:
        print(banner)
        threading.Thread(
            target=_accept_forever,
            args=(sock, tag, handler, handler_args),
            daemon=True,
        ).start()
    # Listener threads are daemons; park the main thread forever to keep them alive.
    threading.Event().wait()