ZKAC/cli/zkac_cli/server.py
2026-04-19 14:26:47 +02:00

392 lines
14 KiB
Python

"""ZKAC server: all traffic over a single encrypted, server-authenticated channel.
Every connection performs an anonymous handshake (X25519 + Schnorr server
identity proof). The first encrypted frame selects the mode:
{"op": "mgmt"} management commands (create_registry, post_grant, ...)
{"op": "auth", "registry_id": hex, "bbs_auth_b64": ...} role authentication
The server stores only cryptographically verified opaque blobs:
<data_dir>/server_key.json Schnorr keypair
<data_dir>/registries/<rid>.state raw RegistryState bytes
<data_dir>/registries/<rid>.cert raw state cert bytes
<data_dir>/mailbox/grants_pool.json anonymous append-only grant pool
Recipients discover matching grants via a cheap detection-tag index
(``pool_tags``) and retrieve individual rows via single-server
LWE-based PIR (``pir_query`` with precomputed ``pir_hints``).
"""
from __future__ import annotations
import base64
import json
import os
import socket
import tempfile
import threading
import traceback
from pathlib import Path

import zkac
from zkac.tcp import FramedSession, server_handshake_anon
def _b64(data: bytes) -> str:
return base64.b64encode(data).decode()
def _unb64(s: str) -> bytes:
return base64.b64decode(s)
def _pad_grant_record(entry: dict) -> bytes:
    """Encode one pool row as a fixed-size PIR record.

    The row is serialized as compact JSON with sorted keys, so equal rows
    always yield equal bytes. It is then right-padded with NUL bytes up to
    ``zkac.PIR_RECORD_BYTES``.

    Raises:
        ValueError: if the encoded row is longer than one PIR record.
    """
    limit = zkac.PIR_RECORD_BYTES
    encoded = json.dumps(
        entry, sort_keys=True, separators=(",", ":")
    ).encode("utf-8")
    slack = limit - len(encoded)
    if slack < 0:
        raise ValueError(
            f"grant record exceeds PIR_RECORD_BYTES ({zkac.PIR_RECORD_BYTES})"
        )
    # bytes(n) is n zero bytes, i.e. the NUL padding.
    return encoded + bytes(slack)
# ── Opaque server storage ─────────────────────────────────────────────
class _ServerStore:
"""Thread-safe, opaque persistence for registries and anonymous grant pool."""
def __init__(self, data_dir: Path):
self._dir = data_dir
self._reg_dir = data_dir / "registries"
self._mbox_dir = data_dir / "mailbox"
self._pool_path = self._mbox_dir / "grants_pool.json"
self._reg_dir.mkdir(parents=True, exist_ok=True)
self._mbox_dir.mkdir(parents=True, exist_ok=True)
self._lock = threading.Lock()
self._pir_server: zkac.PirServer | None = None
self._pir_dirty = True
self._migrate_legacy_mailbox()
# ── server key ────────────────────────────────────────────────────
def load_or_create_keypair(self) -> zkac.Keypair:
kf = self._dir / "server_key.json"
if kf.exists():
data = json.loads(kf.read_text())
return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
kp = zkac.Keypair()
kf.write_text(json.dumps({
"secret_b64": _b64(kp.secret_key_bytes()),
"public_b64": _b64(kp.public_key().to_bytes()),
}, indent=2))
return kp
# ── registries ────────────────────────────────────────────────────
def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes):
with self._lock:
(self._reg_dir / f"{rid_hex}.state").write_bytes(state_bytes)
(self._reg_dir / f"{rid_hex}.cert").write_bytes(cert_bytes)
def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
count = 0
for p in sorted(self._reg_dir.glob("*.state")):
rid_hex = p.stem
cert_path = self._reg_dir / f"{rid_hex}.cert"
if not cert_path.exists():
continue
try:
mgr.create(p.read_bytes(), cert_path.read_bytes())
count += 1
except Exception as exc:
print(f"[server] skip registry {rid_hex}: {exc}")
return count
# ── legacy per-recipient mailbox → anonymous pool ─────────────────
def _migrate_legacy_mailbox(self):
if self._pool_path.exists():
return
records: list[dict] = []
for p in sorted(self._mbox_dir.glob("*.json")):
if p.name == "grants_pool.json":
continue
try:
entries = json.loads(p.read_text())
for e in entries:
if "claimed" not in e:
e["claimed"] = False
records.append(e)
except Exception as exc:
print(f"[server] skip legacy mailbox {p.name}: {exc}")
try:
p.unlink()
except OSError:
pass
if records:
self._pool_path.write_text(
json.dumps({"version": 1, "records": records}, indent=2)
)
def _load_pool(self) -> list[dict]:
if not self._pool_path.exists():
return []
data = json.loads(self._pool_path.read_text())
return data.get("records", [])
def _save_pool(self, records: list[dict]):
self._pool_path.write_text(
json.dumps({"version": 1, "records": records}, indent=2)
)
# ── PIR server rebuild ────────────────────────────────────────────
def _ensure_pir(self) -> zkac.PirServer:
if self._pir_server is not None and not self._pir_dirty:
return self._pir_server
records = self._load_pool()
packed = [_pad_grant_record(r) for r in records]
db = zkac.PirDatabase(packed, zkac.PIR_RECORD_BYTES)
self._pir_server = zkac.PirServer(db)
self._pir_dirty = False
return self._pir_server
# ── anonymous grant pool ─────────────────────────────────────────
def post_grant(self, entry: dict) -> tuple[str, int]:
grant_id = os.urandom(16).hex()
row = {"grant_id": grant_id, "claimed": False, **entry}
_pad_grant_record(row)
with self._lock:
records = self._load_pool()
pool_index = len(records)
records.append(row)
self._save_pool(records)
self._pir_dirty = True
return grant_id, pool_index
def pool_info(self) -> dict:
with self._lock:
pir = self._ensure_pir()
return {
"n": pir.n_records,
"record_bytes": pir.record_bytes,
"pool_version": bytes(pir.version()).hex(),
}
def pool_tags(self) -> tuple[list[tuple[str, str]], str]:
with self._lock:
records = self._load_pool()
pir = self._ensure_pir()
version = bytes(pir.version()).hex()
tags = []
for r in records:
tags.append((
r.get("eph_pk_b64", ""),
r.get("to_tag_b64", ""),
))
return tags, version
def pir_hints(self) -> tuple[bytes, str]:
with self._lock:
pir = self._ensure_pir()
return bytes(pir.hints()), bytes(pir.version()).hex()
def pir_answer(self, query_b64: str, pool_version: str) -> tuple[str, str] | None:
with self._lock:
pir = self._ensure_pir()
current = bytes(pir.version()).hex()
if current != pool_version:
return None
ans = pir.answer(_unb64(query_b64))
return _b64(ans), current
def claim_grant(self, grant_id: str) -> dict | None:
with self._lock:
records = self._load_pool()
for i, e in enumerate(records):
if e["grant_id"] == grant_id and not e.get("claimed", False):
e["claimed"] = True
records[i] = e
self._save_pool(records)
self._pir_dirty = True
return dict(e)
return None
# ── Command dispatch (inside encrypted session) ──────────────────────
def _dispatch(cmd: dict, mgr: zkac.RegistryManager, store: _ServerStore,
              server_pk_b64: str, transcript_hash: bytes) -> dict:
    """Execute one management command and return its JSON-ready response.

    ``cmd["cmd"]`` names the action; the other keys are its arguments.
    Successful responses carry ``"ok": True``; failures are returned as
    ``{"error": ...}``. This function never raises: any exception from
    argument decoding, the registry manager or the store is converted into
    an error response, so the caller can always reply.

    ``transcript_hash`` is the current session's handshake transcript hash.
    It is passed to ``mgr.verify_admin`` when checking a ``post_grant``
    admin proof.
    """
    try:
        action = cmd.get("cmd")
        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}
        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}
        if action == "get_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes, state_cert = mgr.get(rid)
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }
        if action == "update_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Persist under the canonical lowercase hex, as create_registry
            # does. bytes.fromhex() also accepts uppercase and embedded
            # whitespace, so the raw client string would otherwise create a
            # second, differently named file for the same registry.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}
        if action == "post_grant":
            rid = bytes.fromhex(cmd["registry_id"])
            proof = _unb64(cmd["admin_proof_b64"])
            if not mgr.verify_admin(rid, proof, transcript_hash):
                return {"error": "admin proof failed"}
            entry = {
                "eph_pk_b64": cmd["eph_pk_b64"],
                "ciphertext_b64": cmd["ciphertext_b64"],
                "to_tag_b64": cmd.get("to_tag_b64", ""),
            }
            gid, pool_index = store.post_grant(entry)
            return {"ok": True, "grant_id": gid, "pool_index": pool_index}
        if action == "pool_info":
            info = store.pool_info()
            return {"ok": True, **info}
        if action == "pool_tags":
            tags, version = store.pool_tags()
            entries = [
                {"eph_pk_b64": epk, "to_tag_b64": tag}
                for epk, tag in tags
            ]
            return {"ok": True, "tags": entries, "pool_version": version}
        if action == "pir_hints":
            hints_bytes, version = store.pir_hints()
            return {"ok": True, "hints_b64": _b64(hints_bytes), "pool_version": version}
        if action == "pir_query":
            result = store.pir_answer(cmd["query_b64"], cmd["pool_version"])
            if result is None:
                return {"error": "stale_version"}
            ans_b64, version = result
            return {"ok": True, "answer_b64": ans_b64, "pool_version": version}
        if action == "claim_grant":
            entry = store.claim_grant(cmd["grant_id"])
            if entry is None:
                return {"error": "grant not found"}
            return {"ok": True, "grant": entry}
        return {"error": f"unknown command: {action}"}
    except Exception as exc:
        return {"error": str(exc)}
# ── Connection handler ────────────────────────────────────────────────
def _handle_conn(conn: socket.socket, addr: tuple, node: zkac.Node,
                 mgr: zkac.RegistryManager, store: _ServerStore,
                 server_pk_b64: str):
    """Serve one client connection until it closes.

    Runs the anonymous handshake, then reads a JSON ``hello`` frame whose
    ``op`` selects the mode:

    * ``"mgmt"`` -- loop: decode a JSON command, run it through
      :func:`_dispatch`, send back the JSON response. A frame that is not
      valid JSON gets an ``{"error": ...}`` reply instead of ending the
      session.
    * ``"auth"`` -- verify ``bbs_auth_b64`` for (``registry_id``,
      ``role_id``) against this session's transcript hash. On success,
      every later frame is echoed back unchanged.
    * anything else (including a non-object ``hello``) -- reply with an
      ``unknown op`` error.

    Socket/transport errors end the session quietly. Any other exception is
    logged with a traceback. The socket is always closed.
    """
    peer = f"{addr[0]}:{addr[1]}"
    try:
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        # Passed to the admin/presentation verifiers for this session.
        transcript_hash = bytes(session.transcript_hash())
        hello = json.loads(framed.recv())
        op = hello.get("op") if isinstance(hello, dict) else None
        if op == "mgmt":
            while True:
                try:
                    data = framed.recv()
                except OSError:  # includes ConnectionError
                    break
                try:
                    cmd = json.loads(data)
                except ValueError as exc:  # JSONDecodeError or bad UTF-8
                    resp = {"error": f"malformed command: {exc}"}
                else:
                    resp = _dispatch(cmd, mgr, store, server_pk_b64, transcript_hash)
                framed.send(json.dumps(resp).encode())
        elif op == "auth":
            registry_id = bytes.fromhex(hello["registry_id"])
            role_id = bytes.fromhex(hello["role_id"])
            proof_bytes = _unb64(hello["bbs_auth_b64"])
            ok = mgr.verify_presentation(
                registry_id, role_id, proof_bytes, transcript_hash,
            )
            if not ok:
                framed.send(json.dumps({"error": "auth failed"}).encode())
                return
            resp = {
                "status": "authenticated",
                "registry_id": registry_id.hex(),
                "role_id": role_id.hex(),
            }
            framed.send(json.dumps(resp).encode())
            # Authenticated sessions echo every frame back to the client.
            while True:
                try:
                    data = framed.recv()
                except OSError:
                    break
                framed.send(data)
        else:
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())
    except OSError:
        # ConnectionError and BrokenPipeError are OSError subclasses.
        pass
    except Exception as exc:
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        conn.close()
# ── Public entry point ────────────────────────────────────────────────
def serve(data_dir: str, host: str = "127.0.0.1", port: int = 9800,
          backlog: int = 8):
    """Run the ZKAC server in the foreground until interrupted (Ctrl-C).

    Prepares ``data_dir`` (keypair, stored registries, grant pool), binds an
    IPv4 TCP listener on ``host:port`` and handles each accepted connection
    on its own daemon thread via :func:`_handle_conn`.

    Args:
        data_dir: Directory for the server's persistent state; created if
            missing.
        host: IPv4 address to bind.
        port: TCP port to bind.
        backlog: Listen backlog passed to ``socket.listen``.
    """
    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)
    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    server_pk_b64 = _b64(kp.public_key().to_bytes())
    node = zkac.Node(kp)
    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)
    print(f"server public key: {_unb64(server_pk_b64).hex()}")
    print(f"loaded {n} registries")
    # The context manager closes the listener on every exit path, including
    # a failing bind()/listen(), which previously leaked the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(backlog)
        # Announce only once the port is actually bound.
        print(f"listening on {host}:{port}")
        try:
            while True:
                conn, addr = sock.accept()
                threading.Thread(
                    target=_handle_conn,
                    args=(conn, addr, node, mgr, store, server_pk_b64),
                    daemon=True,
                ).start()
        except KeyboardInterrupt:
            print("\nshutdown")