ZKAC/cli/zkac_cli/server.py
2026-05-04 13:33:40 +02:00

465 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ZKAC server: all traffic over a single encrypted, server-authenticated channel.
Every connection performs an anonymous handshake (X25519 + Schnorr server
identity proof). The first encrypted frame selects the mode:
{"op": "mgmt"} management commands (create_registry, post_grant, ...)
{"op": "auth", "registry_id": hex, "role_id": hex, "bbs_auth_b64": ...} role authentication
The server stores only cryptographically verified opaque blobs:
<data_dir>/server_key.json Schnorr keypair
<data_dir>/registries/<rid>.state raw RegistryState bytes
<data_dir>/registries/<rid>.cert raw state cert bytes
<data_dir>/mailbox/grants_pool.json anonymous append-only grant pool
Recipients discover matching grants via a cheap detection-tag index
(``pool_tags``). PIR (``pir_query``) returns a full encrypted mailbox row,
so no follow-up ``grant_id`` fetch is required. This avoids leaking a stable
row identifier during retrieval.
Large PIR answers are streamed in slices (same pattern as ``pir_hints``).
"""
from __future__ import annotations
import base64
import hashlib
import json
import socket
import threading
import traceback
from pathlib import Path
import zkac
from zkac.tcp import MAX_TCP_FRAME_BYTES, FramedSession, server_handshake_anon
# Serialized PIR hints can be huge (64 KiB records × LWE width). Each mgmt reply must
# stay under :data:`MAX_TCP_FRAME_BYTES` after encryption; base64 expands ~4/3.
# The raw (pre-base64) slice size is therefore 3/5 of the frame limit, so the
# base64 text takes about 4/5 of a frame when the clamp does not apply. The value
# is clamped to the range 16 KiB .. 128 KiB. The same chunk size is reused when
# streaming large ``pir_query`` answers.
_PIR_HINT_CHUNK = min(131_072, max(16_384, (MAX_TCP_FRAME_BYTES * 3) // 5))
def _b64(data: bytes) -> str:
    """Return the standard base64 encoding of ``data`` as text."""
    encoded = base64.b64encode(data)
    # base64 output is pure ASCII, so this matches the default UTF-8 decode.
    return encoded.decode("ascii")
def _unb64(s: str) -> bytes:
    """Decode standard-alphabet base64 text ``s`` back into raw bytes."""
    raw = base64.standard_b64decode(s)
    return raw
def _pir_row_bytes(entry: dict) -> bytes:
    """Serialize a pool entry into one fixed-width, NUL-padded PIR record.

    The record is compact JSON with sorted keys. It carries the ephemeral
    public key, the detection tag, the ciphertext, a SHA-256 hex digest of the
    decoded ciphertext (empty string when there is no ciphertext) and the
    ``claimed`` flag. Missing string fields default to ``""``.

    Raises:
        ValueError: if the encoded row is longer than ``zkac.PIR_RECORD_BYTES``.
    """
    ciphertext = entry.get("ciphertext_b64", "")
    if ciphertext:
        digest = hashlib.sha256(_unb64(ciphertext)).hexdigest()
    else:
        digest = ""
    payload = json.dumps(
        {
            "v": 2,
            "eph_pk_b64": entry.get("eph_pk_b64", ""),
            "to_tag_b64": entry.get("to_tag_b64", ""),
            "ciphertext_b64": ciphertext,
            "ciphertext_sha256": digest,
            "claimed": bool(entry.get("claimed", False)),
        },
        separators=(",", ":"),
        sort_keys=True,
    ).encode("utf-8")
    width = zkac.PIR_RECORD_BYTES
    if len(payload) > width:
        raise ValueError(
            f"PIR row exceeds PIR_RECORD_BYTES ({zkac.PIR_RECORD_BYTES})"
        )
    # Right-pad with NUL bytes so every record has exactly ``width`` bytes.
    return payload.ljust(width, b"\x00")
# ── Opaque server storage ─────────────────────────────────────────────
class _ServerStore:
    """Thread-safe, opaque persistence for registries and anonymous grant pool.

    Files live under ``data_dir``: ``server_key.json``,
    ``registries/<rid>.state`` / ``registries/<rid>.cert`` and
    ``mailbox/grants_pool.json``. Registry writes and every pool/PIR operation
    are serialized by one internal lock. The PIR server object is rebuilt
    lazily from the pool file whenever a grant was posted since the last build.
    """

    def __init__(self, data_dir: Path):
        """Create the directory layout and migrate any legacy mailbox files."""
        self._dir = data_dir
        self._reg_dir = data_dir / "registries"
        self._mbox_dir = data_dir / "mailbox"
        self._pool_path = self._mbox_dir / "grants_pool.json"
        self._reg_dir.mkdir(parents=True, exist_ok=True)
        self._mbox_dir.mkdir(parents=True, exist_ok=True)
        self._lock = threading.Lock()
        self._pir_server: zkac.PirServer | None = None
        # Start dirty so the first PIR request builds the database.
        self._pir_dirty = True
        self._migrate_legacy_mailbox()

    @staticmethod
    def _write_text_atomic(path: Path, text: str) -> None:
        """Write ``text`` to ``path`` via a sibling temp file and a rename.

        A crash mid-write then leaves the previous file intact instead of a
        truncated JSON document that would make every later load fail.
        """
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(text)
        tmp.replace(path)

    # ── server key ────────────────────────────────────────────────────
    def load_or_create_keypair(self) -> zkac.Keypair:
        """Load the server keypair from ``server_key.json`` or create one."""
        kf = self._dir / "server_key.json"
        if kf.exists():
            data = json.loads(kf.read_text())
            return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
        kp = zkac.Keypair()
        # The file holds the secret key: create it owner-only (subject to the
        # platform honouring the mode) before any secret material is written.
        kf.touch(mode=0o600)
        kf.write_text(json.dumps({
            "secret_b64": _b64(kp.secret_key_bytes()),
            "public_b64": _b64(kp.public_key().to_bytes()),
        }, indent=2))
        return kp

    # ── registries ────────────────────────────────────────────────────
    def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes):
        """Persist a registry's state and certificate blobs under ``rid_hex``."""
        with self._lock:
            (self._reg_dir / f"{rid_hex}.state").write_bytes(state_bytes)
            (self._reg_dir / f"{rid_hex}.cert").write_bytes(cert_bytes)

    def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
        """Restore every stored ``.state``/``.cert`` pair into ``mgr``.

        Pairs with a missing cert are ignored; pairs that ``mgr.restore``
        rejects are logged and skipped. Returns the number restored.
        """
        count = 0
        for p in sorted(self._reg_dir.glob("*.state")):
            rid_hex = p.stem
            cert_path = self._reg_dir / f"{rid_hex}.cert"
            if not cert_path.exists():
                continue
            try:
                mgr.restore(p.read_bytes(), cert_path.read_bytes())
                count += 1
            except Exception as exc:
                print(f"[server] skip registry {rid_hex}: {exc}")
        return count

    # ── legacy per-recipient mailbox → anonymous pool ─────────────────
    def _migrate_legacy_mailbox(self):
        """Fold legacy ``mailbox/*.json`` files into ``grants_pool.json``.

        Does nothing when the pool file already exists. Legacy files are
        removed only after the merged pool has been written, and only when
        they were read without error, so a failed migration never destroys
        the source data.
        """
        if self._pool_path.exists():
            return
        records: list[dict] = []
        migrated: list[Path] = []
        for p in sorted(self._mbox_dir.glob("*.json")):
            if p.name == "grants_pool.json":
                continue
            try:
                entries = json.loads(p.read_text())
                for e in entries:
                    if "claimed" not in e:
                        e["claimed"] = False
                    records.append(e)
            except Exception as exc:
                print(f"[server] skip legacy mailbox {p.name}: {exc}")
                continue
            migrated.append(p)
        if records:
            self._write_text_atomic(
                self._pool_path,
                json.dumps({"version": 1, "records": records}, indent=2),
            )
        for p in migrated:
            try:
                p.unlink()
            except OSError:
                pass

    def _load_pool(self) -> list[dict]:
        """Return the pool records from disk, or ``[]`` if no pool file exists."""
        if not self._pool_path.exists():
            return []
        data = json.loads(self._pool_path.read_text())
        return data.get("records", [])

    def _save_pool(self, records: list[dict]):
        """Atomically replace the pool file with ``records``."""
        self._write_text_atomic(
            self._pool_path,
            json.dumps({"version": 1, "records": records}, indent=2),
        )

    # ── PIR server rebuild ────────────────────────────────────────────
    def _ensure_pir(self) -> zkac.PirServer:
        """Return the cached PIR server, rebuilding it if the pool changed.

        Every caller in this class invokes it while holding ``self._lock``.
        """
        if self._pir_server is not None and not self._pir_dirty:
            return self._pir_server
        records = self._load_pool()
        packed = [_pir_row_bytes(r) for r in records]
        db = zkac.PirDatabase(packed, zkac.PIR_RECORD_BYTES)
        self._pir_server = zkac.PirServer(db)
        self._pir_dirty = False
        return self._pir_server

    # ── anonymous grant pool ─────────────────────────────────────────
    def post_grant(self, entry: dict) -> int:
        """Append ``entry`` to the pool and return its index.

        Raises:
            ValueError: if the entry does not fit in one PIR record.
        """
        # Validate the row size up front so an oversized entry is never stored.
        _pir_row_bytes(entry)
        row = {"claimed": False, **entry}
        with self._lock:
            records = self._load_pool()
            pool_index = len(records)
            records.append(row)
            self._save_pool(records)
            self._pir_dirty = True
        return pool_index

    def pool_info(self) -> dict:
        """Return record count, record size and hex pool version."""
        with self._lock:
            pir = self._ensure_pir()
            return {
                "n": pir.n_records,
                "record_bytes": pir.record_bytes,
                "pool_version": bytes(pir.version()).hex(),
            }

    def pool_tags(self) -> tuple[list[tuple[str, str]], str]:
        """Return ``(eph_pk_b64, to_tag_b64)`` per record plus the pool version."""
        with self._lock:
            records = self._load_pool()
            pir = self._ensure_pir()
            version = bytes(pir.version()).hex()
        tags = [
            (r.get("eph_pk_b64", ""), r.get("to_tag_b64", ""))
            for r in records
        ]
        return tags, version

    def pir_hints(self) -> tuple[bytes, str]:
        """Return the serialized PIR hints and the hex pool version."""
        with self._lock:
            pir = self._ensure_pir()
            return bytes(pir.hints()), bytes(pir.version()).hex()

    def pir_answer_bytes(self, query_b64: str, pool_version: str) -> bytes | None:
        """Return raw PIR answer bytes, or ``None`` if ``pool_version`` is stale."""
        with self._lock:
            pir = self._ensure_pir()
            current = bytes(pir.version()).hex()
            if current != pool_version:
                return None
            ans = pir.answer(_unb64(query_b64))
            return bytes(ans)
# ── Command dispatch (inside encrypted session) ──────────────────────
def _pir_slice_reply(buf: bytes, offset, version: str) -> dict:
    """Build the reply for one ``_PIR_HINT_CHUNK``-sized slice of ``buf``.

    Returns ``{"error": "bad offset"}`` when ``offset`` is not an integer in
    ``0..len(buf)``; an offset equal to ``len(buf)`` yields an empty final slice.
    """
    try:
        off = int(offset)
    except (TypeError, ValueError):
        return {"error": "bad offset"}
    if off < 0 or off > len(buf):
        return {"error": "bad offset"}
    piece = buf[off : off + _PIR_HINT_CHUNK]
    return {
        "ok": True,
        "pool_version": version,
        "slice_b64": _b64(piece),
        "offset": off,
        "returned": len(piece),
        "done": off + len(piece) >= len(buf),
    }


def _dispatch(
    cmd: dict,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    conn_ctx: dict,
) -> dict:
    """Execute one management command and return its JSON-serializable reply.

    Args:
        cmd: decoded request; ``cmd["cmd"]`` selects the action.
        mgr: in-memory registry manager used for verification and lookups.
        store: on-disk persistence for registries and the grant pool.
        server_pk_b64: base64 server public key, echoed by ``server_info``.
        transcript_hash: handshake transcript hash that admin proofs are
            verified against.
        conn_ctx: per-connection scratch dict; holds a large PIR answer
            between the initial ``pir_query`` and its follow-up slice requests.

    Returns:
        A dict with ``"ok": True`` on success, or ``{"error": ...}``. Any
        exception raised while handling the command is reported as an error
        reply rather than propagated.
    """
    try:
        action = cmd.get("cmd")

        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}

        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}

        if action == "get_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes, state_cert = mgr.get(rid)
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }

        if action == "update_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Persist under the canonical lowercase hex id, as create_registry
            # does. bytes.fromhex() tolerates uppercase and whitespace, so the
            # raw client string could name a different file for the same id.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}

        if action == "post_grant":
            rid = bytes.fromhex(cmd["registry_id"])
            proof = _unb64(cmd["admin_proof_b64"])
            if not mgr.verify_admin(rid, proof, transcript_hash):
                return {"error": "admin proof failed"}
            entry = {
                "eph_pk_b64": cmd["eph_pk_b64"],
                "ciphertext_b64": cmd["ciphertext_b64"],
                "to_tag_b64": cmd.get("to_tag_b64", ""),
            }
            pool_index = store.post_grant(entry)
            return {"ok": True, "pool_index": pool_index}

        if action == "pool_info":
            info = store.pool_info()
            return {"ok": True, **info}

        if action == "pool_tags":
            tags, version = store.pool_tags()
            entries = [
                {"eph_pk_b64": epk, "to_tag_b64": tag}
                for epk, tag in tags
            ]
            return {"ok": True, "tags": entries, "pool_version": version}

        if action == "pir_hints":
            raw, version = store.pir_hints()
            offset = cmd.get("offset")
            if offset is not None:
                # Slice request: only valid against the version it started on.
                if cmd.get("pool_version") != version:
                    return {"error": "stale_version"}
                return _pir_slice_reply(raw, offset, version)
            if len(raw) <= _PIR_HINT_CHUNK:
                return {"ok": True, "hints_b64": _b64(raw), "pool_version": version}
            # Too large for one frame: announce the size, client fetches slices.
            return {
                "ok": True,
                "pool_version": version,
                "hints_total": len(raw),
                "chunk": _PIR_HINT_CHUNK,
            }

        if action == "pir_query":
            q_b64 = cmd.get("query_b64", "")
            pv = cmd.get("pool_version", "")
            offset = cmd.get("offset")
            if offset is None:
                ans_bytes = store.pir_answer_bytes(q_b64, pv)
                if ans_bytes is None:
                    return {"error": "stale_version"}
                # A fresh query discards any partially streamed earlier answer.
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
                if len(ans_bytes) <= _PIR_HINT_CHUNK:
                    return {"ok": True, "answer_b64": _b64(ans_bytes), "pool_version": pv}
                conn_ctx["pir_answer_buffer"] = ans_bytes
                conn_ctx["pir_answer_pv"] = pv
                return {
                    "ok": True,
                    "pool_version": pv,
                    "answer_total": len(ans_bytes),
                    "chunk": _PIR_HINT_CHUNK,
                }
            if conn_ctx.get("pir_answer_pv") != pv:
                return {"error": "stale_version"}
            buf = conn_ctx.get("pir_answer_buffer")
            if buf is None:
                return {"error": "no PIR answer in progress"}
            reply = _pir_slice_reply(buf, offset, pv)
            if reply.get("done"):
                # Last slice delivered: release the buffered answer.
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
            return reply

        return {"error": f"unknown command: {action}"}
    except Exception as exc:
        return {"error": str(exc)}
# ── Connection handler ────────────────────────────────────────────────
def _handle_conn(conn: socket.socket, addr: tuple, node: zkac.Node,
                 mgr: zkac.RegistryManager, store: _ServerStore,
                 server_pk_b64: str):
    """Serve one client connection until it closes.

    Performs the anonymous handshake, reads the first encrypted frame and
    branches on its ``op``: ``mgmt`` runs a request/reply loop through
    ``_dispatch``; ``auth`` verifies a role presentation bound to the handshake
    transcript and then echoes every received frame; any other value gets an
    error reply. Socket errors end the connection silently, other exceptions
    are logged with a traceback, and the socket is always closed.
    """
    peer = f"{addr[0]}:{addr[1]}"
    try:
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        transcript_hash = bytes(session.transcript_hash())
        hello = json.loads(framed.recv())
        op = hello.get("op")

        if op == "mgmt":
            conn_ctx: dict = {}
            while True:
                try:
                    frame = framed.recv()
                except OSError:  # ConnectionError is a subclass of OSError
                    break
                request = json.loads(frame)
                reply = _dispatch(
                    request, mgr, store, server_pk_b64, transcript_hash, conn_ctx
                )
                framed.send(json.dumps(reply).encode())
            return

        if op != "auth":
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())
            return

        registry_id = bytes.fromhex(hello["registry_id"])
        role_id = bytes.fromhex(hello["role_id"])
        proof_bytes = _unb64(hello["bbs_auth_b64"])
        verified = mgr.verify_presentation(
            registry_id, role_id, proof_bytes, transcript_hash,
        )
        if not verified:
            framed.send(json.dumps({"error": "auth failed"}).encode())
            return
        welcome = {
            "status": "authenticated",
            "registry_id": registry_id.hex(),
            "role_id": role_id.hex(),
        }
        framed.send(json.dumps(welcome).encode())
        # After authentication the session simply echoes each frame back.
        while True:
            try:
                frame = framed.recv()
            except OSError:
                break
            framed.send(frame)
    except OSError:  # also covers ConnectionError and BrokenPipeError
        pass
    except Exception as exc:
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        conn.close()
# ── Public entry point ────────────────────────────────────────────────
def serve(data_dir: str, host: str = "127.0.0.1", port: int = 9800):
    """Run the ZKAC server until interrupted.

    Creates ``data_dir`` if needed, loads (or creates) the server keypair,
    restores the stored registries, then accepts TCP connections on
    ``host:port`` and handles each one in its own daemon thread.

    Args:
        data_dir: directory holding the server key, registries and mailbox.
        host: interface address to bind.
        port: TCP port to bind.
    """
    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)
    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    server_pk_b64 = _b64(kp.public_key().to_bytes())
    node = zkac.Node(kp)
    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)
    print(f"server public key: {_unb64(server_pk_b64).hex()}")
    print(f"loaded {n} registries")

    # The context manager closes the socket even when bind()/listen() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(8)
        # Announce only once the socket is really accepting connections.
        print(f"listening on {host}:{port}")
        try:
            while True:
                conn, addr = sock.accept()
                threading.Thread(
                    target=_handle_conn,
                    args=(conn, addr, node, mgr, store, server_pk_b64),
                    daemon=True,
                ).start()
        except KeyboardInterrupt:
            print("\nshutdown")