273 lines
9.2 KiB
Python
273 lines
9.2 KiB
Python
"""TCP servers: ZKAC management, ZKAC managed app, localhost relay (credential queue)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import socket
|
|
import threading
|
|
import traceback
|
|
from pathlib import Path
|
|
from typing import Callable
|
|
|
|
import zkac
|
|
from zkac.tcp import FramedSession, server_handshake, server_handshake_managed
|
|
|
|
from zkac_cli import issuance_util
|
|
from zkac_cli import registry_log
|
|
|
|
|
|
# Role name required on the management port: handle_mgmt_session only serves
# sessions whose handshake role id equals zkac.role_id(MGMT_ROLE).
MGMT_ROLE: str = "zkac.mgmt"
|
|
|
|
|
|
def build_static_registry(issuer_pk: zkac.BbsPublicKey) -> zkac.RoleRegistry:
    """Return a fresh RoleRegistry that knows only the management role.

    The role id derived from ``MGMT_ROLE`` is registered against ``issuer_pk``.
    The trailing ``1`` is passed through to ``register_role`` unchanged; its
    meaning is defined by zkac (NOTE(review): presumably a version/threshold
    value — confirm).
    """
    registry = zkac.RoleRegistry()
    mgmt_role = zkac.role_id(MGMT_ROLE)
    registry.register_role(mgmt_role, issuer_pk, 1)
    return registry
|
|
|
|
|
|
def _dispatch_mgmt_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> dict:
    """Execute one decoded management request and return its JSON-serialisable reply.

    Supported ``cmd`` values: ``create_registry``, ``update_registry``,
    ``get_registry``, ``list_registries``, ``issuance_peek``, ``issuance_grant``.
    Anything else yields ``{"ok": False, "error": "unknown cmd ..."}``.

    Raises whatever the decoding / registry calls raise (e.g. ``KeyError`` for a
    missing field, ``ValueError`` for malformed hex); the caller converts that
    into an error reply.
    """
    cmd = msg["cmd"]

    if cmd == "create_registry":
        st = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        rid = mgr.create(st, cert)
        registry_ids.add(rid)
        # Append to the on-disk event log so serve() can replay it on restart.
        registry_log.append_event(events_path, "create", st, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}

    if cmd == "update_registry":
        st = base64.b64decode(msg["state_b64"])
        cert = base64.b64decode(msg["state_cert_b64"])
        # The target registry is identified by the id embedded in the new state.
        rid = zkac.RegistryState.deserialize(st).registry_id()
        mgr.update(rid, st, cert)
        registry_log.append_event(events_path, "update", st, cert)
        save_registry()
        return {"ok": True, "registry_id_hex": rid.hex()}

    if cmd == "get_registry":
        st, cert = mgr.get(bytes.fromhex(msg["registry_id_hex"]))
        return {
            "ok": True,
            "state_b64": base64.b64encode(st).decode(),
            "state_cert_b64": base64.b64encode(cert).decode(),
        }

    if cmd == "list_registries":
        return {"ok": True, "registry_ids_hex": [h.hex() for h in sorted(registry_ids)]}

    if cmd == "issuance_peek":
        rid = bytes.fromhex(msg["registry_id_hex"])
        pending = issuance_util.peek_pending_requests(mgr, rid)
        return {
            "ok": True,
            "pending": [
                {
                    "request_id_hex": request_id.hex(),
                    "role_id_hex": role.hex(),
                    "eph_pk_hex": eph_pk.hex(),
                    "payload_b64": base64.b64encode(payload).decode(),
                }
                for request_id, role, eph_pk, payload in pending
            ],
        }

    if cmd == "issuance_grant":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        blind = base64.b64decode(msg["blind_sig_b64"])
        mgr.grant_credential(rid, req_id, blind)
        save_registry()
        return {"ok": True}

    return {"ok": False, "error": f"unknown cmd {cmd}"}


def handle_mgmt_session(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
    registry_ids: set[bytes],
    events_path: Path,
    save_registry: Callable[[], None],
) -> None:
    """Serve a single request on the management port, then close ``conn``.

    Loads the transport keypair (``transport.json``) and the management issuer
    secret (``mgmt_issuer.json``) from ``data_dir``, and performs a ZKAC server
    handshake against a registry containing only ``MGMT_ROLE``. It then reads one
    framed JSON command and sends back one framed JSON reply.

    If the command fails, the client receives ``{"ok": False, "error": ...}``
    instead of a bare connection close. Handshake failures and transport errors
    are printed and the connection is closed. This function never raises.

    Args:
        conn: accepted client socket; always closed on return.
        data_dir: directory holding the key files named above.
        mgr: shared registry manager mutated by create/update/grant commands.
        registry_ids: shared set of known registry ids (updated on create).
        events_path: event-log file appended to on create/update.
        save_registry: persistence hook called after mutating commands.
    """
    try:
        # Imported inside try so that conn is still closed if the import fails.
        from zkac_cli import creds

        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        issuer_secret_b64 = creds.load_json(data_dir / "mgmt_issuer.json")["issuer_secret_b64"]
        issuer = zkac.BbsIssuer.from_secret_key(base64.b64decode(issuer_secret_b64))
        static_reg = build_static_registry(issuer.public_key())

        session, role_id = server_handshake(conn, node, static_reg)
        if role_id != zkac.role_id(MGMT_ROLE):
            # static_reg only registers MGMT_ROLE, so this is a defensive check:
            # drop the connection without replying.
            return

        framed = FramedSession(conn, session)
        try:
            msg = json.loads(framed.recv().decode("utf-8"))
            out = _dispatch_mgmt_command(msg, mgr, registry_ids, events_path, save_registry)
        except Exception as exc:
            # Report the failure to the authenticated client instead of
            # silently closing the connection.
            traceback.print_exc()
            out = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        framed.send(json.dumps(out).encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
|
|
|
|
|
|
def run_managed_handler(
    conn: socket.socket,
    data_dir: Path,
    mgr: zkac.RegistryManager,
) -> None:
    """Serve a single request on the managed-app port, then close ``conn``.

    Performs ``server_handshake_managed`` against ``mgr`` using the transport
    keypair from ``data_dir / "transport.json"``. This yields the caller's
    registry id and role id. It then reads one framed JSON message and sends
    one framed JSON reply:

    * ``get_registry``: state and certificate of the registry named by
      ``registry_id_hex``. NOTE(review): this is not restricted to the caller's
      own registry; confirm that is intended.
    * ``whoami``: the authenticated registry id and role id.
    * anything else, including a missing ``cmd`` (treated as ``"ping"``): the
      identity plus a ``note`` field.

    If request handling fails, the client receives ``{"ok": False, "error": ...}``.
    Handshake and transport errors are printed. This function never raises.
    """
    try:
        # Imported inside try so that conn is still closed if the import fails.
        from zkac_cli import creds

        sk = creds.load_server_keypair(data_dir / "transport.json")
        node = zkac.Node(sk)
        session, registry_id, role_id = server_handshake_managed(conn, node, mgr)

        framed = FramedSession(conn, session)
        try:
            msg = json.loads(framed.recv().decode("utf-8"))
            cmd = msg.get("cmd", "ping")
            if cmd == "get_registry":
                st, cert = mgr.get(bytes.fromhex(msg["registry_id_hex"]))
                body = {
                    "ok": True,
                    "state_b64": base64.b64encode(st).decode(),
                    "state_cert_b64": base64.b64encode(cert).decode(),
                }
            else:
                body = {
                    "ok": True,
                    "registry_id_hex": registry_id.hex(),
                    "role_id_hex": role_id.hex(),
                }
                if cmd != "whoami":
                    body["note"] = "authenticated managed session"
        except Exception as exc:
            # Report the failure to the authenticated client instead of
            # silently closing the connection.
            traceback.print_exc()
            body = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        framed.send(json.dumps(body).encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
|
|
|
|
|
|
def _read_relay_line(conn: socket.socket, max_line_bytes: int) -> bytes | None:
    """Read from ``conn`` up to the first newline and return the bytes before it.

    Returns ``None`` if the peer closes the connection before sending a newline.
    Raises ``ValueError`` once the line exceeds ``max_line_bytes``, so an
    unauthenticated client cannot grow the buffer without bound. Bytes after
    the newline are discarded, because the protocol is one request per connection.
    """
    buf = bytearray()
    while True:
        chunk = conn.recv(4096)
        if not chunk:
            return None
        start = len(buf)
        buf += chunk
        # Only scan the newly received bytes; everything before `start` is
        # already known to be newline-free.
        nl = buf.find(b"\n", start)
        line_len = len(buf) if nl < 0 else nl
        if line_len > max_line_bytes:
            raise ValueError("request line too long")
        if nl >= 0:
            return bytes(buf[:nl])


def _dispatch_relay_command(
    msg: dict,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
) -> dict:
    """Execute one decoded relay request (``enqueue`` or ``poll``) and return its reply."""
    cmd = msg["cmd"]

    if cmd == "enqueue":
        rid = bytes.fromhex(msg["registry_id_hex"])
        role = zkac.role_id(msg["role_name"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        eph = bytes.fromhex(msg["eph_pk_hex"])
        blob = base64.b64decode(msg["payload_b64"])
        mgr.queue_issuance_request(rid, req_id, role, eph, blob)
        save_registry()
        return {"ok": True, "status": "queued"}

    if cmd == "poll":
        rid = bytes.fromhex(msg["registry_id_hex"])
        req_id = bytes.fromhex(msg["request_id_hex"])
        # take_* is called on every poll; a None result means nothing is granted yet.
        granted = mgr.take_granted_credential(rid, req_id)
        save_registry()
        if granted is None:
            return {"ok": True, "status": "pending"}
        return {
            "ok": True,
            "status": "ready",
            "blind_sig_b64": base64.b64encode(granted).decode(),
        }

    return {"ok": False, "error": "unknown relay cmd"}


def handle_relay_session(
    conn: socket.socket,
    mgr: zkac.RegistryManager,
    save_registry: Callable[[], None],
    max_line_bytes: int = 1 << 20,
) -> None:
    """Serve one plaintext JSON-lines request on the relay port, then close ``conn``.

    Reads one newline-terminated JSON object and writes one newline-terminated
    JSON reply. The commands are ``enqueue`` (queue an issuance request) and
    ``poll`` (collect a granted blind signature). No authentication is
    performed on this socket.

    Args:
        conn: accepted client socket; always closed on return.
        mgr: shared registry manager holding the issuance queue.
        save_registry: persistence hook called after each enqueue/poll.
        max_line_bytes: maximum accepted request-line length. Longer requests
            are rejected with an error reply.

    If the peer disconnects before sending a full line, nothing is sent.
    Malformed or failing requests receive ``{"ok": False, "error": ...}``.
    This function never raises.
    """
    try:
        try:
            line = _read_relay_line(conn, max_line_bytes)
            if line is None:
                return
            out = _dispatch_relay_command(json.loads(line.decode("utf-8")), mgr, save_registry)
        except Exception as exc:
            traceback.print_exc()
            out = {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
        conn.sendall((json.dumps(out) + "\n").encode("utf-8"))
    except Exception:
        traceback.print_exc()
    finally:
        conn.close()
|
|
|
|
|
|
def serve(
    data_dir: Path,
    mgmt_port: int,
    managed_port: int,
    relay_port: int | None,
    relay_bind: str,
) -> None:
    """Start the ZKAC TCP servers and block forever.

    Registry state is rebuilt from ``data_dir / "registry_events.json"`` when it
    exists; otherwise the server starts with an empty ``RegistryManager``. The
    following listeners are then started:

    * management port on ``0.0.0.0:mgmt_port`` -> ``handle_mgmt_session``
    * managed-app port on ``0.0.0.0:managed_port`` -> ``run_managed_handler``
    * if ``relay_port`` is not None, a plaintext relay on
      ``relay_bind:relay_port`` -> ``handle_relay_session``

    Every accepted connection is handled on its own daemon thread. NOTE(review):
    those threads share ``mgr`` and ``registry_ids`` without a lock.

    Raises:
        OSError: if any listening socket cannot be created or bound (e.g. the
            port is already in use). All sockets are bound before any accept
            thread starts, so the failure reaches the caller. Previously it
            killed a background thread while the main thread kept waiting.
    """
    events_path = data_dir / "registry_events.json"
    registry_ids: set[bytes] = set()
    if events_path.is_file():
        # Read the event log once. It is materialised as a list so it can be
        # iterated twice (replay + id recovery) even if load_events returns a
        # one-shot iterator.
        events = list(registry_log.load_events(events_path))
        mgr = registry_log.replay_manager(events)
        for event in events:
            state = base64.b64decode(event["state_b64"])
            registry_ids.add(zkac.RegistryState.deserialize(state).registry_id())
    else:
        mgr = zkac.RegistryManager()

    def save_registry() -> None:
        # Persistence hook handed to the handlers; currently a no-op. Only
        # create/update events reach disk (via registry_log.append_event in
        # the management handler). NOTE(review): issuance queue entries and
        # grants are presumably lost on restart unless RegistryManager persists
        # them itself; confirm.
        return

    def _bind_listener(host: str, port: int) -> socket.socket:
        """Create a bound, listening TCP socket, closing it if setup fails."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((host, port))
            sock.listen(8)
        except OSError:
            sock.close()
            raise
        return sock

    def _accept_forever(sock: socket.socket, tag: str, handler, handler_args: tuple) -> None:
        """Accept connections forever, handing each to ``handler`` on a daemon thread."""
        while True:
            conn, addr = sock.accept()
            print(f"[{tag}] connect {addr}")
            threading.Thread(
                target=handler,
                args=(conn, *handler_args),
                daemon=True,
            ).start()

    # Each spec is (log tag, host, port, banner text, handler, extra handler args).
    specs = [
        (
            "mgmt", "0.0.0.0", mgmt_port, "ZKAC listening on",
            handle_mgmt_session, (data_dir, mgr, registry_ids, events_path, save_registry),
        ),
        (
            "managed", "0.0.0.0", managed_port, "ZKAC listening on",
            run_managed_handler, (data_dir, mgr),
        ),
    ]
    if relay_port is not None:
        specs.append(
            (
                "relay", relay_bind, relay_port, "plaintext JSON lines on",
                handle_relay_session, (mgr, save_registry),
            )
        )

    listeners = []
    try:
        for tag, host, port, banner, handler, handler_args in specs:
            sock = _bind_listener(host, port)
            listeners.append((sock, tag, handler, handler_args))
            print(f"[{tag}] {banner} {host}:{port}")
    except OSError:
        # Do not leak the listeners that were already bound.
        for bound_sock, *_ in listeners:
            bound_sock.close()
        raise

    for sock, tag, handler, handler_args in listeners:
        threading.Thread(
            target=_accept_forever,
            args=(sock, tag, handler, handler_args),
            daemon=True,
        ).start()

    # All accept loops run on daemon threads; park the main thread forever.
    threading.Event().wait()
|