499 lines
19 KiB
Python
499 lines
19 KiB
Python
"""ZKAC server: all traffic over a single encrypted, server-authenticated channel.

Every connection performs an anonymous handshake (X25519 + Schnorr server
identity proof). The first encrypted frame selects the mode:

    {"op": "mgmt"}
        management commands (create_registry, post_grant, ...)
    {"op": "auth", "registry_id": hex, "role_id": hex, "bbs_auth_b64": ...}
        role authentication

The server stores only cryptographically verified opaque blobs:

    <data_dir>/server_key.json              Schnorr keypair
    <data_dir>/registries/<rid>.state       raw RegistryState bytes
    <data_dir>/registries/<rid>.cert        raw state cert bytes
    <data_dir>/mailbox/grants_pool.json     anonymous append-only grant pool

Recipients discover matching grants via a cheap detection-tag index
(``pool_tags``). PIR (``pir_query``) returns a **small handle** per row;
bulk ciphertext is fetched with ``get_grant_blob`` (split payload).
Large PIR answers are streamed in slices (same pattern as ``pir_hints``).
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import hashlib
|
||
import json
|
||
import os
|
||
import socket
|
||
import threading
|
||
import traceback
|
||
from pathlib import Path
|
||
|
||
import zkac
|
||
from zkac.tcp import MAX_TCP_FRAME_BYTES, FramedSession, server_handshake_anon
|
||
|
||
# Maximum number of raw bytes carried by one slice of a streamed mgmt payload
# (used for both ``pir_hints`` and large ``pir_query`` answers).
#
# Serialized PIR hints can be huge (64 KiB records × LWE width). Each mgmt reply
# must stay under :data:`MAX_TCP_FRAME_BYTES` after encryption; base64 expands
# the slice by ~4/3, so taking 3/5 of the frame limit puts roughly 4/5 of it on
# the wire and leaves the rest for the JSON envelope. The result is clamped to
# the range [16 KiB, 128 KiB].
# NOTE(review): the 16 KiB floor presumes MAX_TCP_FRAME_BYTES is comfortably
# larger than ~22 KiB — confirm against zkac.tcp.
_PIR_HINT_CHUNK = min(131_072, max(16_384, (MAX_TCP_FRAME_BYTES * 3) // 5))
|
||
|
||
|
||
def _b64(data: bytes) -> str:
    """Return ``data`` encoded as standard (padded) Base64 text."""
    # Base64 output is pure ASCII, so the decode can never fail.
    return base64.b64encode(data).decode("ascii")
|
||
|
||
|
||
def _unb64(s: str) -> bytes:
    """Decode standard Base64 text ``s`` into raw bytes.

    Decoding is lenient: ``base64.b64decode`` is called without
    ``validate=True``, so characters outside the Base64 alphabet are
    discarded rather than rejected.

    Raises:
        binascii.Error: if ``s`` is incorrectly padded (a ``ValueError``
            subclass).
    """
    return base64.b64decode(s)
|
||
|
||
|
||
def _pir_handle_bytes(grant_id: str, ciphertext_b64: str) -> bytes:
    """Build the fixed-size PIR row for one grant.

    The row is a compact JSON handle ``{"g": grant_id, "h": sha256-hex, "v": 1}``
    (keys sorted) binding ``grant_id`` to the SHA-256 of the decoded
    ciphertext, right-padded with NUL bytes to ``zkac.PIR_RECORD_BYTES``.

    Raises:
        ValueError: if the encoded handle is longer than
            ``zkac.PIR_RECORD_BYTES``.
    """
    ct_digest = hashlib.sha256(_unb64(ciphertext_b64)).hexdigest()
    handle = {"v": 1, "g": grant_id, "h": ct_digest}
    # Compact separators + sorted keys give a deterministic encoding.
    raw = json.dumps(handle, separators=(",", ":"), sort_keys=True).encode("utf-8")
    if len(raw) > zkac.PIR_RECORD_BYTES:
        raise ValueError(
            f"PIR handle exceeds PIR_RECORD_BYTES ({zkac.PIR_RECORD_BYTES})"
        )
    return raw.ljust(zkac.PIR_RECORD_BYTES, b"\x00")
|
||
|
||
|
||
# ── Opaque server storage ─────────────────────────────────────────────
|
||
|
||
class _ServerStore:
|
||
"""Thread-safe, opaque persistence for registries and anonymous grant pool."""
|
||
|
||
def __init__(self, data_dir: Path):
|
||
self._dir = data_dir
|
||
self._reg_dir = data_dir / "registries"
|
||
self._mbox_dir = data_dir / "mailbox"
|
||
self._pool_path = self._mbox_dir / "grants_pool.json"
|
||
self._reg_dir.mkdir(parents=True, exist_ok=True)
|
||
self._mbox_dir.mkdir(parents=True, exist_ok=True)
|
||
self._lock = threading.Lock()
|
||
self._pir_server: zkac.PirServer | None = None
|
||
self._pir_dirty = True
|
||
self._migrate_legacy_mailbox()
|
||
|
||
# ── server key ────────────────────────────────────────────────────
|
||
|
||
def load_or_create_keypair(self) -> zkac.Keypair:
|
||
kf = self._dir / "server_key.json"
|
||
if kf.exists():
|
||
data = json.loads(kf.read_text())
|
||
return zkac.Keypair.from_secret_key(_unb64(data["secret_b64"]))
|
||
kp = zkac.Keypair()
|
||
kf.write_text(json.dumps({
|
||
"secret_b64": _b64(kp.secret_key_bytes()),
|
||
"public_b64": _b64(kp.public_key().to_bytes()),
|
||
}, indent=2))
|
||
return kp
|
||
|
||
# ── registries ────────────────────────────────────────────────────
|
||
|
||
def save_registry(self, rid_hex: str, state_bytes: bytes, cert_bytes: bytes):
|
||
with self._lock:
|
||
(self._reg_dir / f"{rid_hex}.state").write_bytes(state_bytes)
|
||
(self._reg_dir / f"{rid_hex}.cert").write_bytes(cert_bytes)
|
||
|
||
def load_all_registries(self, mgr: zkac.RegistryManager) -> int:
|
||
count = 0
|
||
for p in sorted(self._reg_dir.glob("*.state")):
|
||
rid_hex = p.stem
|
||
cert_path = self._reg_dir / f"{rid_hex}.cert"
|
||
if not cert_path.exists():
|
||
continue
|
||
try:
|
||
mgr.restore(p.read_bytes(), cert_path.read_bytes())
|
||
count += 1
|
||
except Exception as exc:
|
||
print(f"[server] skip registry {rid_hex}: {exc}")
|
||
return count
|
||
|
||
# ── legacy per-recipient mailbox → anonymous pool ─────────────────
|
||
|
||
def _migrate_legacy_mailbox(self):
|
||
if self._pool_path.exists():
|
||
return
|
||
records: list[dict] = []
|
||
for p in sorted(self._mbox_dir.glob("*.json")):
|
||
if p.name == "grants_pool.json":
|
||
continue
|
||
try:
|
||
entries = json.loads(p.read_text())
|
||
for e in entries:
|
||
if "claimed" not in e:
|
||
e["claimed"] = False
|
||
records.append(e)
|
||
except Exception as exc:
|
||
print(f"[server] skip legacy mailbox {p.name}: {exc}")
|
||
try:
|
||
p.unlink()
|
||
except OSError:
|
||
pass
|
||
if records:
|
||
self._pool_path.write_text(
|
||
json.dumps({"version": 1, "records": records}, indent=2)
|
||
)
|
||
|
||
def _load_pool(self) -> list[dict]:
|
||
if not self._pool_path.exists():
|
||
return []
|
||
data = json.loads(self._pool_path.read_text())
|
||
return data.get("records", [])
|
||
|
||
def _save_pool(self, records: list[dict]):
|
||
self._pool_path.write_text(
|
||
json.dumps({"version": 1, "records": records}, indent=2)
|
||
)
|
||
|
||
# ── PIR server rebuild ────────────────────────────────────────────
|
||
|
||
def _ensure_pir(self) -> zkac.PirServer:
|
||
if self._pir_server is not None and not self._pir_dirty:
|
||
return self._pir_server
|
||
records = self._load_pool()
|
||
packed = [_pir_handle_bytes(r["grant_id"], r["ciphertext_b64"]) for r in records]
|
||
db = zkac.PirDatabase(packed, zkac.PIR_RECORD_BYTES)
|
||
self._pir_server = zkac.PirServer(db)
|
||
self._pir_dirty = False
|
||
return self._pir_server
|
||
|
||
# ── anonymous grant pool ─────────────────────────────────────────
|
||
|
||
def post_grant(self, entry: dict) -> tuple[str, int]:
|
||
grant_id = os.urandom(16).hex()
|
||
_pir_handle_bytes(grant_id, entry["ciphertext_b64"])
|
||
row = {"grant_id": grant_id, "claimed": False, **entry}
|
||
with self._lock:
|
||
records = self._load_pool()
|
||
pool_index = len(records)
|
||
records.append(row)
|
||
self._save_pool(records)
|
||
self._pir_dirty = True
|
||
return grant_id, pool_index
|
||
|
||
def pool_info(self) -> dict:
|
||
with self._lock:
|
||
pir = self._ensure_pir()
|
||
return {
|
||
"n": pir.n_records,
|
||
"record_bytes": pir.record_bytes,
|
||
"pool_version": bytes(pir.version()).hex(),
|
||
}
|
||
|
||
def pool_tags(self) -> tuple[list[tuple[str, str]], str]:
|
||
with self._lock:
|
||
records = self._load_pool()
|
||
pir = self._ensure_pir()
|
||
version = bytes(pir.version()).hex()
|
||
tags = []
|
||
for r in records:
|
||
tags.append((
|
||
r.get("eph_pk_b64", ""),
|
||
r.get("to_tag_b64", ""),
|
||
))
|
||
return tags, version
|
||
|
||
def pir_hints(self) -> tuple[bytes, str]:
|
||
with self._lock:
|
||
pir = self._ensure_pir()
|
||
return bytes(pir.hints()), bytes(pir.version()).hex()
|
||
|
||
def pir_answer_bytes(self, query_b64: str, pool_version: str) -> bytes | None:
|
||
"""Return raw PIR answer bytes, or ``None`` if ``pool_version`` is stale."""
|
||
with self._lock:
|
||
pir = self._ensure_pir()
|
||
current = bytes(pir.version()).hex()
|
||
if current != pool_version:
|
||
return None
|
||
ans = pir.answer(_unb64(query_b64))
|
||
return bytes(ans)
|
||
|
||
def claim_grant(self, grant_id: str) -> dict | None:
|
||
with self._lock:
|
||
records = self._load_pool()
|
||
for i, e in enumerate(records):
|
||
if e["grant_id"] == grant_id and not e.get("claimed", False):
|
||
e["claimed"] = True
|
||
records[i] = e
|
||
self._save_pool(records)
|
||
self._pir_dirty = True
|
||
return dict(e)
|
||
return None
|
||
|
||
def get_grant_blob(self, grant_id: str) -> dict | None:
|
||
"""Return public grant fields for second-phase fetch (after PIR handle)."""
|
||
with self._lock:
|
||
for e in self._load_pool():
|
||
if e.get("grant_id") == grant_id:
|
||
return {
|
||
"eph_pk_b64": e.get("eph_pk_b64", ""),
|
||
"ciphertext_b64": e.get("ciphertext_b64", ""),
|
||
"to_tag_b64": e.get("to_tag_b64", ""),
|
||
"claimed": bool(e.get("claimed", False)),
|
||
}
|
||
return None
|
||
|
||
|
||
# ── Command dispatch (inside encrypted session) ──────────────────────
|
||
|
||
def _slice_reply(buf: bytes, offset: object, pool_version: str) -> dict:
    """Build the reply for one ``offset``-addressed slice of a streamed payload.

    Returns ``{"error": "bad offset"}`` when ``offset`` is not an integer in
    ``[0, len(buf)]``; otherwise a slice reply of at most ``_PIR_HINT_CHUNK``
    raw bytes whose ``"done"`` flag is true once the end of ``buf`` is reached.
    """
    try:
        off = int(offset)
    except (TypeError, ValueError):
        return {"error": "bad offset"}
    if off < 0 or off > len(buf):
        return {"error": "bad offset"}
    piece = buf[off : off + _PIR_HINT_CHUNK]
    return {
        "ok": True,
        "pool_version": pool_version,
        "slice_b64": _b64(piece),
        "offset": off,
        "returned": len(piece),
        "done": off + len(piece) >= len(buf),
    }


def _dispatch(
    cmd: dict,
    mgr: zkac.RegistryManager,
    store: _ServerStore,
    server_pk_b64: str,
    transcript_hash: bytes,
    conn_ctx: dict,
) -> dict:
    """Execute one management command and return its JSON-serializable reply.

    Args:
        cmd: Decoded command; ``cmd["cmd"]`` selects the action.
        mgr: Registry manager that verifies and holds registry state.
        store: Persistent store for registries and the grant pool.
        server_pk_b64: Base64 server public key, returned by ``server_info``.
        transcript_hash: Handshake transcript hash the admin proof of
            ``post_grant`` is verified against.
        conn_ctx: Per-connection scratch dict; holds the buffered PIR answer
            (``pir_answer_buffer`` / ``pir_answer_pv``) while it is streamed.

    Returns:
        ``{"ok": True, ...}`` on success, otherwise ``{"error": message}``.
        This function never raises: any exception is converted into an error
        reply so one bad command cannot tear down the connection.
    """
    try:
        action = cmd.get("cmd")

        if action == "server_info":
            return {"ok": True, "server_public_key_b64": server_pk_b64}

        if action == "create_registry":
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            rid = mgr.create(state_bytes, state_cert)
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True, "registry_id": rid.hex()}

        if action == "get_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes, state_cert = mgr.get(rid)
            return {
                "ok": True,
                "state_bytes_b64": _b64(state_bytes),
                "state_cert_b64": _b64(state_cert),
            }

        if action == "update_registry":
            rid = bytes.fromhex(cmd["registry_id"])
            state_bytes = _unb64(cmd["state_bytes_b64"])
            state_cert = _unb64(cmd["state_cert_b64"])
            mgr.update(rid, state_bytes, state_cert)
            # Persist under the canonical hex id, not the client's spelling:
            # bytes.fromhex accepts uppercase and embedded whitespace, which
            # would otherwise create a second set of files for the same id.
            store.save_registry(rid.hex(), state_bytes, state_cert)
            return {"ok": True}

        if action == "post_grant":
            rid = bytes.fromhex(cmd["registry_id"])
            proof = _unb64(cmd["admin_proof_b64"])
            if not mgr.verify_admin(rid, proof, transcript_hash):
                return {"error": "admin proof failed"}
            entry = {
                "eph_pk_b64": cmd["eph_pk_b64"],
                "ciphertext_b64": cmd["ciphertext_b64"],
                "to_tag_b64": cmd.get("to_tag_b64", ""),
            }
            gid, pool_index = store.post_grant(entry)
            return {"ok": True, "grant_id": gid, "pool_index": pool_index}

        if action == "pool_info":
            return {"ok": True, **store.pool_info()}

        if action == "pool_tags":
            tags, version = store.pool_tags()
            entries = [
                {"eph_pk_b64": epk, "to_tag_b64": tag}
                for epk, tag in tags
            ]
            return {"ok": True, "tags": entries, "pool_version": version}

        if action == "pir_hints":
            raw, version = store.pir_hints()
            offset = cmd.get("offset")
            if offset is not None:
                # Slice request: only valid against the version it started on.
                if cmd.get("pool_version") != version:
                    return {"error": "stale_version"}
                return _slice_reply(raw, offset, version)
            if len(raw) <= _PIR_HINT_CHUNK:
                return {"ok": True, "hints_b64": _b64(raw), "pool_version": version}
            # Too large for one frame: announce the size, client pulls slices.
            return {
                "ok": True,
                "pool_version": version,
                "hints_total": len(raw),
                "chunk": _PIR_HINT_CHUNK,
            }

        if action == "pir_query":
            q_b64 = cmd.get("query_b64", "")
            pv = cmd.get("pool_version", "")
            offset = cmd.get("offset")

            if offset is None:
                # New query: compute the answer and drop any earlier buffer.
                ans_bytes = store.pir_answer_bytes(q_b64, pv)
                if ans_bytes is None:
                    return {"error": "stale_version"}
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
                if len(ans_bytes) <= _PIR_HINT_CHUNK:
                    return {"ok": True, "answer_b64": _b64(ans_bytes), "pool_version": pv}
                conn_ctx["pir_answer_buffer"] = ans_bytes
                conn_ctx["pir_answer_pv"] = pv
                return {
                    "ok": True,
                    "pool_version": pv,
                    "answer_total": len(ans_bytes),
                    "chunk": _PIR_HINT_CHUNK,
                }

            # Follow-up slice of the answer buffered on this connection.
            # NOTE(review): with no buffered answer the version check below
            # fires first, so such a request is reported as "stale_version";
            # order kept as-is to preserve the wire behavior clients see.
            if conn_ctx.get("pir_answer_pv") != pv:
                return {"error": "stale_version"}
            buf = conn_ctx.get("pir_answer_buffer")
            if buf is None:
                return {"error": "no PIR answer in progress"}
            reply = _slice_reply(buf, offset, pv)
            if reply.get("done"):
                # Last slice delivered: release the buffered answer.
                conn_ctx.pop("pir_answer_buffer", None)
                conn_ctx.pop("pir_answer_pv", None)
            return reply

        if action == "claim_grant":
            gid = cmd.get("grant_id", "")
            if not gid:
                return {"error": "missing grant_id"}
            entry = store.claim_grant(gid)
            if entry is None:
                return {"error": "grant not found"}
            return {"ok": True, "grant": entry}

        if action == "get_grant_blob":
            gid = cmd.get("grant_id", "")
            if not gid:
                return {"error": "missing grant_id"}
            blob = store.get_grant_blob(gid)
            if blob is None:
                return {"error": "grant not found"}
            return {"ok": True, **blob}

        return {"error": f"unknown command: {action}"}

    except Exception as exc:
        # Top-level boundary for untrusted commands: report, never raise.
        return {"error": str(exc)}
|
||
|
||
|
||
# ── Connection handler ────────────────────────────────────────────────
|
||
|
||
def _serve_mgmt(framed: FramedSession, mgr: zkac.RegistryManager,
                store: _ServerStore, server_pk_b64: str,
                transcript_hash: bytes) -> None:
    """Answer management commands until the peer disconnects."""
    conn_ctx: dict = {}
    while True:
        try:
            data = framed.recv()
        except OSError:  # includes ConnectionError: peer went away
            return
        try:
            cmd = json.loads(data)
        except ValueError:
            # A malformed frame gets an error reply, like any other bad
            # command, instead of tearing down the whole session.
            resp = {"error": "invalid JSON"}
        else:
            resp = _dispatch(cmd, mgr, store, server_pk_b64, transcript_hash, conn_ctx)
        framed.send(json.dumps(resp).encode())


def _serve_auth(framed: FramedSession, hello: dict,
                mgr: zkac.RegistryManager, transcript_hash: bytes) -> None:
    """Verify a role presentation, then echo frames until the peer disconnects."""
    try:
        registry_id = bytes.fromhex(hello["registry_id"])
        role_id = bytes.fromhex(hello["role_id"])
        proof_bytes = _unb64(hello["bbs_auth_b64"])
    except (KeyError, TypeError, ValueError):
        framed.send(json.dumps({"error": "bad auth request"}).encode())
        return

    ok = mgr.verify_presentation(
        registry_id, role_id, proof_bytes, transcript_hash,
    )
    if not ok:
        framed.send(json.dumps({"error": "auth failed"}).encode())
        return

    resp = {
        "status": "authenticated",
        "registry_id": registry_id.hex(),
        "role_id": role_id.hex(),
    }
    framed.send(json.dumps(resp).encode())

    # Authenticated session: every received frame is echoed back.
    while True:
        try:
            data = framed.recv()
        except OSError:
            return
        framed.send(data)


def _handle_conn(conn: socket.socket, addr: tuple, node: zkac.Node,
                 mgr: zkac.RegistryManager, store: _ServerStore,
                 server_pk_b64: str):
    """Serve one client connection from handshake to close.

    Performs the anonymous server handshake, reads the first encrypted frame
    and runs the mode it selects (``"mgmt"`` or ``"auth"``); any other ``op``
    gets an error reply. Network errors end the session silently, other
    exceptions are logged with a traceback, and ``conn`` is always closed.
    """
    peer = f"{addr[0]}:{addr[1]}"
    try:
        session = server_handshake_anon(conn, node)
        framed = FramedSession(conn, session)
        transcript_hash = bytes(session.transcript_hash())

        hello = json.loads(framed.recv())
        op = hello.get("op")

        if op == "mgmt":
            _serve_mgmt(framed, mgr, store, server_pk_b64, transcript_hash)
        elif op == "auth":
            _serve_auth(framed, hello, mgr, transcript_hash)
        else:
            framed.send(json.dumps({"error": f"unknown op: {op}"}).encode())

    except OSError:
        # ConnectionError / BrokenPipeError are OSError subclasses: the peer
        # dropped the connection, nothing to report.
        pass
    except Exception as exc:
        print(f"[server] {peer} error: {exc}")
        traceback.print_exc()
    finally:
        conn.close()
|
||
|
||
|
||
# ── Public entry point ────────────────────────────────────────────────
|
||
|
||
def serve(data_dir: str, host: str = "127.0.0.1", port: int = 9800):
    """Run the ZKAC server until interrupted.

    Creates ``data_dir`` if needed, loads (or creates) the server keypair,
    restores all stored registries, then accepts TCP connections on
    ``host:port`` and serves each one on its own daemon thread. Returns after
    ``KeyboardInterrupt``; the listening socket is closed on every exit path.

    Raises:
        OSError: if the address cannot be bound.
    """
    dd = Path(data_dir)
    dd.mkdir(parents=True, exist_ok=True)

    store = _ServerStore(dd)
    kp = store.load_or_create_keypair()
    pk_bytes = bytes(kp.public_key().to_bytes())
    server_pk_b64 = _b64(pk_bytes)
    node = zkac.Node(kp)

    mgr = zkac.RegistryManager()
    n = store.load_all_registries(mgr)

    print(f"server public key: {pk_bytes.hex()}")
    print(f"loaded {n} registries")

    # The context manager also closes the socket when bind()/listen() fail.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        sock.listen(8)
        print(f"listening on {host}:{port}")
        try:
            while True:
                conn, addr = sock.accept()
                threading.Thread(
                    target=_handle_conn,
                    args=(conn, addr, node, mgr, store, server_pk_b64),
                    daemon=True,
                ).start()
        except KeyboardInterrupt:
            print("\nshutdown")
|