"""Client-side operations over a unified encrypted channel (per local user id).""" from __future__ import annotations import base64 import json import socket import zkac from zkac.tcp import FramedSession, client_handshake_anon from . import pir, store def _b64(data: bytes) -> str: return base64.b64encode(data).decode() def _unb64(s: str) -> bytes: return base64.b64decode(s) def _parse_server(server: str) -> tuple[str, int]: host, _, port = server.rpartition(":") return host or "127.0.0.1", int(port) def parse_spec(spec: str) -> tuple[str, str, str]: """Parse 'host:port:registry_id:role' into (server, registry_id, role).""" parts = spec.rsplit(":", 2) if len(parts) != 3: raise ValueError(f"invalid spec {spec!r}, expected host:port:registry_id:role") return parts[0], parts[1], parts[2] def _resolve_server_pk(userid: str, server: str) -> zkac.PublicKey: pin = store.load_server_pin(userid, server) if pin is None: raise RuntimeError( f"no pinned key for {server}; run: zkac-node server pin {userid} {server} --key " ) return zkac.PublicKey.from_bytes(_unb64(pin["server_public_key_b64"])) def _mgmt_connect(userid: str, server: str) -> tuple[socket.socket, FramedSession]: host, port = _parse_server(server) sock = socket.create_connection((host, port)) server_pk = _resolve_server_pk(userid, server) node = zkac.Node(zkac.Keypair()) session = client_handshake_anon(sock, node, server_pk) framed = FramedSession(sock, session) framed.send(json.dumps({"op": "mgmt"}).encode()) return sock, framed def _mgmt_cmd(framed: FramedSession, cmd: dict) -> dict: framed.send(json.dumps(cmd).encode()) return json.loads(framed.recv()) def _mgmt_single(userid: str, server: str, cmd: dict) -> dict: sock, framed = _mgmt_connect(userid, server) try: return _ok(_mgmt_cmd(framed, cmd)) finally: sock.close() def _ok(resp: dict) -> dict: if resp.get("error"): raise RuntimeError(resp["error"]) return resp def create_registry(userid: str, server: str, role_names: list[str]) -> str: identity = store.load_identity(userid) admin_mat = store.new_admin_material() bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_mat) role_entries = [(zkac.role_id(name), bbs_pk, 1) for name in role_names] state = zkac.RegistryState.build( bbs_pk, identity["issuance_pk"], 1, b"\x00" * 32, role_entries, ) state_bytes = state.serialize() state_cert = state.certify(admin_cred) registry_id = state.registry_id() resp = _mgmt_single(userid, server, { "cmd": "create_registry", "state_bytes_b64": _b64(state_bytes), "state_cert_b64": _b64(bytes(state_cert)), }) rid_hex = resp["registry_id"] store.save_admin(userid, rid_hex, { "server": server, "roles": role_names, **admin_mat, }) return rid_hex def update_registry(userid: str, server: str, registry_id_hex: str, add_roles: list[str]): admin_data = store.load_admin(userid, registry_id_hex) bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_data) identity = store.load_identity(userid) cur = _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) old_state = zkac.RegistryState.deserialize(_unb64(cur["state_bytes_b64"])) prev_hash = old_state.state_hash() new_version = old_state.version() + 1 old_roles = admin_data.get("roles", []) all_roles = list(old_roles) + [r for r in add_roles if r not in old_roles] role_entries = [(zkac.role_id(name), bbs_pk, 1) for name in all_roles] new_state = zkac.RegistryState.build( bbs_pk, identity["issuance_pk"], new_version, bytes(prev_hash), role_entries, ) new_cert = new_state.certify(admin_cred) _mgmt_single(userid, server, { "cmd": "update_registry", "registry_id": registry_id_hex, "state_bytes_b64": _b64(new_state.serialize()), "state_cert_b64": _b64(bytes(new_cert)), }) admin_data["roles"] = all_roles store.save_admin(userid, registry_id_hex, admin_data) def get_registry(userid: str, server: str, registry_id_hex: str) -> dict: return _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) def list_own_registries(userid: str) -> list[dict]: result = [] for rid in store.list_admin_registries(userid): data = store.load_admin(userid, rid) result.append({ "registry_id": rid, "server": data.get("server", "?"), "roles": data.get("roles", []), }) return result def grant(userid: str, server: str, registry_id_hex: str, role_name: str, recipient_pk_hex: str) -> tuple[str, int]: admin_data = store.load_admin(userid, registry_id_hex) roles = admin_data.get("roles", []) if role_name not in roles: raise RuntimeError(f"role {role_name!r} not in registry (have: {roles})") bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_data) role_rid = zkac.role_id(role_name) epoch = 1 req = zkac.prepare_blind_request() blind_sig = bbs_issuer.issue_blind(req.commitment_with_proof(), role_rid, epoch) payload = json.dumps({ "registry_id": registry_id_hex, "role_name": role_name, "epoch": epoch, "issuer_pk_b64": _b64(bbs_pk.to_bytes()), "blind_sig_b64": _b64(blind_sig), "member_secret_b64": _b64(req.member_secret()), "prover_blind_b64": _b64(req.prover_blind()), }).encode() recipient_pk = bytes.fromhex(recipient_pk_hex) eph_kp = zkac.IssuanceKeypair() ciphertext = eph_kp.encrypt(recipient_pk, payload) sock, framed = _mgmt_connect(userid, server) try: transcript_hash = bytes(framed.session.transcript_hash()) admin_proof = admin_cred.present(transcript_hash) resp = _ok(_mgmt_cmd(framed, { "cmd": "post_grant", "registry_id": registry_id_hex, "eph_pk_b64": _b64(eph_kp.public_key_bytes()), "ciphertext_b64": _b64(ciphertext), "admin_proof_b64": _b64(admin_proof), })) finally: sock.close() return resp["grant_id"], resp.get("pool_index", -1) def _pir_recover_row( framed_a: FramedSession, framed_b: FramedSession, n: int, pool_index: int, ) -> dict: idx_a, idx_b = pir.pir_query_indices(n, pool_index) xa = _ok(_mgmt_cmd(framed_a, {"cmd": "pir_fold", "indices": idx_a})) xb = _ok(_mgmt_cmd(framed_b, {"cmd": "pir_fold", "indices": idx_b})) raw = pir.pir_recover(_unb64(xa["xor_b64"]), _unb64(xb["xor_b64"])) return pir.unpad_grant_record(raw) def _fetch_grant_entry_pir( userid: str, server_a: str, server_b: str, pool_index: int, ) -> dict: sock_a, fa = _mgmt_connect(userid, server_a) sock_b, fb = _mgmt_connect(userid, server_b) try: a = _ok(_mgmt_cmd(fa, {"cmd": "mail_pool_len"})) b = _ok(_mgmt_cmd(fb, {"cmd": "mail_pool_len"})) if a["n"] != b["n"] or a["record_bytes"] != b["record_bytes"]: raise RuntimeError("PIR peers disagree on pool length or record size") n = a["n"] if not (0 <= pool_index < n): raise RuntimeError("pool_index out of range for current pool") return _pir_recover_row(fa, fb, n, pool_index) finally: sock_a.close() sock_b.close() def list_pending(userid: str, server: str, pir_peer: str) -> list[dict]: """Scan the grant pool via two-server XOR PIR (one Chor query per row).""" identity = store.load_identity(userid) receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"]) info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) info_b = _mgmt_single(userid, pir_peer, {"cmd": "server_info"}) store.pin_server(userid, pir_peer, info_b["server_public_key_b64"]) sock_a, fa = _mgmt_connect(userid, server) sock_b, fb = _mgmt_connect(userid, pir_peer) try: a = _ok(_mgmt_cmd(fa, {"cmd": "mail_pool_len"})) b = _ok(_mgmt_cmd(fb, {"cmd": "mail_pool_len"})) if a["n"] != b["n"] or a["record_bytes"] != b["record_bytes"]: raise RuntimeError("PIR peers disagree on pool length or record size") n = a["n"] results = [] for i in range(n): row = _pir_recover_row(fa, fb, n, i) if row.get("claimed"): continue try: eph_pk = _unb64(row["eph_pk_b64"]) ct = _unb64(row["ciphertext_b64"]) plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) results.append({ "grant_id": row["grant_id"], "pool_index": i, "registry_id": plaintext.get("registry_id", "?"), "role_name": plaintext.get("role_name", "?"), }) except Exception: results.append({ "grant_id": row.get("grant_id", "?"), "pool_index": i, "registry_id": "?", "role_name": "(undecryptable)", }) return results finally: sock_a.close() sock_b.close() def collect( userid: str, spec: str, *, pir_peer: str, pool_index: int, ) -> dict: server, registry_id_hex, role_name = parse_spec(spec) identity = store.load_identity(userid) receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"]) info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) info_b = _mgmt_single(userid, pir_peer, {"cmd": "server_info"}) store.pin_server(userid, pir_peer, info_b["server_public_key_b64"]) row = _fetch_grant_entry_pir(userid, server, pir_peer, pool_index) if row.get("claimed"): raise RuntimeError("grant row is already claimed") try: eph_pk = _unb64(row["eph_pk_b64"]) ct = _unb64(row["ciphertext_b64"]) plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) except Exception as exc: raise RuntimeError( "PIR row did not decrypt for this user (wrong pool_index or desynced peers)" ) from exc if (plaintext.get("registry_id") != registry_id_hex or plaintext.get("role_name") != role_name): raise RuntimeError( "PIR row does not match this collect spec (check pool_index and peers)" ) target_grant_id = row["grant_id"] target_payload = plaintext _mgmt_single(userid, server, { "cmd": "claim_grant", "grant_id": target_grant_id, }) _ = _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) cred_data = { "blind_sig_b64": target_payload["blind_sig_b64"], "member_secret_b64": target_payload["member_secret_b64"], "prover_blind_b64": target_payload["prover_blind_b64"], "role_name": role_name, "epoch": target_payload["epoch"], "issuer_pk_b64": target_payload["issuer_pk_b64"], } cred = store.reconstruct_credential(cred_data) cred.present(b"self-test") store.save_credential(userid, registry_id_hex, role_name, cred_data) return {"registry_id": registry_id_hex, "role": role_name, "server": server} def authenticate(userid: str, registry_id_hex: str, role_name: str, server: str | None = None) -> dict: admin_data = None try: admin_data = store.load_admin(userid, registry_id_hex) except FileNotFoundError: pass if server is None: if admin_data and admin_data.get("server"): server = admin_data["server"] else: raise RuntimeError("server address required (--server host:port)") cred_data = store.load_credential_data(userid, registry_id_hex, role_name) cred = store.reconstruct_credential(cred_data) server_pk = _resolve_server_pk(userid, server) node = zkac.Node(zkac.Keypair()) host, port = _parse_server(server) sock = socket.create_connection((host, port)) try: session = client_handshake_anon(sock, node, server_pk) framed = FramedSession(sock, session) transcript_hash = bytes(session.transcript_hash()) auth_proof = cred.present(transcript_hash) role_rid = zkac.role_id(role_name) framed.send(json.dumps({ "op": "auth", "registry_id": registry_id_hex, "role_id": role_rid.hex(), "bbs_auth_b64": _b64(auth_proof), }).encode()) return json.loads(framed.recv()) finally: sock.close()