"""Client-side operations over a unified encrypted channel (per local user id).""" from __future__ import annotations import base64 import hashlib import json import socket from pathlib import Path import zkac from zkac.tcp import FramedSession, client_handshake_anon from . import store def _b64(data: bytes) -> str: return base64.b64encode(data).decode() def _unb64(s: str) -> bytes: return base64.b64decode(s) def _parse_server(server: str) -> tuple[str, int]: host, sep, port_s = server.rpartition(":") if not sep: raise ValueError( f"invalid server {server!r}: expected host:port (e.g. 127.0.0.1:9800). " "That is the TCP address of the running node, not the userid from " "`zkac-node serve `." ) try: port = int(port_s, 10) except ValueError as e: raise ValueError( f"invalid server {server!r}: port after ':' must be a number (host:port)" ) from e if not 1 <= port <= 65535: raise ValueError(f"invalid server port {port}") return (host or "127.0.0.1"), port def parse_spec(spec: str) -> tuple[str, str, str]: """Parse 'host:port:registry_id:role' into (server, registry_id, role).""" parts = spec.rsplit(":", 2) if len(parts) != 3: raise ValueError(f"invalid spec {spec!r}, expected host:port:registry_id:role") return parts[0], parts[1], parts[2] def _resolve_server_pk(userid: str, server: str) -> zkac.PublicKey: pin = store.load_server_pin(userid, server) if pin is None: raise RuntimeError( f"no pinned transport key for {server!r} under client {userid!r} " f"(pins are per client identity in ~/.zkac/{userid}/, not the userid " "passed to `zkac-node serve <…>` on the server). " f"Run: zkac-node server pin {userid} {server} --key " ) return zkac.PublicKey.from_bytes(_unb64(pin["server_public_key_b64"])) def _mgmt_connect(userid: str, server: str) -> tuple[socket.socket, FramedSession]: host, port = _parse_server(server) sock = socket.create_connection((host, port)) server_pk = _resolve_server_pk(userid, server) node = zkac.Node(zkac.Keypair()) session = client_handshake_anon(sock, node, server_pk) framed = FramedSession(sock, session) framed.send(json.dumps({"op": "mgmt"}).encode()) return sock, framed def _mgmt_cmd(framed: FramedSession, cmd: dict) -> dict: framed.send(json.dumps(cmd).encode()) return json.loads(framed.recv()) def _mgmt_single(userid: str, server: str, cmd: dict) -> dict: sock, framed = _mgmt_connect(userid, server) try: return _ok(_mgmt_cmd(framed, cmd)) finally: sock.close() def _ok(resp: dict) -> dict: if resp.get("error"): raise RuntimeError(resp["error"]) return resp # ── PIR hint cache ─────────────────────────────────────────────────── def _cache_dir(userid: str) -> Path: d = store.user_dir(userid) / "pir_cache" d.mkdir(parents=True, exist_ok=True) return d def _server_cache_key(server: str) -> str: return server.replace(":", "_") def _load_cached_hints(userid: str, server: str, pool_version: str) -> bytes | None: meta_path = _cache_dir(userid) / f"{_server_cache_key(server)}.json" bin_path = _cache_dir(userid) / f"{_server_cache_key(server)}.bin" if not meta_path.exists() or not bin_path.exists(): return None meta = json.loads(meta_path.read_text()) if meta.get("pool_version") != pool_version: return None return bin_path.read_bytes() def _save_cached_hints(userid: str, server: str, pool_version: str, n_records: int, record_bytes: int, hints_bytes: bytes): key = _server_cache_key(server) meta = {"pool_version": pool_version, "n_records": n_records, "record_bytes": record_bytes} (_cache_dir(userid) / f"{key}.json").write_text(json.dumps(meta)) (_cache_dir(userid) / f"{key}.bin").write_bytes(hints_bytes) def _fetch_hints_bytes(framed: FramedSession) -> tuple[bytes, str]: """Download PIR hint blob; uses chunked wire protocol when hints exceed one frame.""" resp = _ok(_mgmt_cmd(framed, {"cmd": "pir_hints"})) if "hints_b64" in resp: return _unb64(resp["hints_b64"]), resp["pool_version"] total = int(resp["hints_total"]) pv = resp["pool_version"] buf = bytearray() off = 0 while off < total: resp = _ok( _mgmt_cmd( framed, {"cmd": "pir_hints", "offset": off, "pool_version": pv}, ) ) if resp.get("pool_version") != pv: raise RuntimeError("PIR hints pool_version changed during download") piece = _unb64(resp["slice_b64"]) if not piece and off < total: raise RuntimeError("PIR hints short read") buf.extend(piece) off += len(piece) if len(buf) != total: raise RuntimeError( f"PIR hints size mismatch (got {len(buf)}, expected {total})" ) return bytes(buf), pv def _pir_client(userid: str, framed: FramedSession, server: str) -> tuple[zkac.PirClient, str]: """Fetch pool_info, load or refresh hints, return (PirClient, pool_version).""" info = _ok(_mgmt_cmd(framed, {"cmd": "pool_info"})) n = info["n"] rb = info["record_bytes"] pv = info["pool_version"] cached = _load_cached_hints(userid, server, pv) if cached is not None: return zkac.PirClient(cached, n, rb), pv hints_bytes, pv = _fetch_hints_bytes(framed) _save_cached_hints(userid, server, pv, n, rb, hints_bytes) return zkac.PirClient(hints_bytes, n, rb), pv def _fetch_row( userid: str, framed: FramedSession, server: str, pir_client: zkac.PirClient, pool_version: str, pool_index: int, ) -> dict: q, state = pir_client.query(pool_index) q_b64 = _b64(q) resp = _ok( _mgmt_cmd( framed, {"cmd": "pir_query", "query_b64": q_b64, "pool_version": pool_version}, ) ) if "answer_b64" in resp: ans_bytes = _unb64(resp["answer_b64"]) else: total = int(resp["answer_total"]) buf = bytearray() off = 0 while off < total: chunk = _ok( _mgmt_cmd( framed, { "cmd": "pir_query", "query_b64": q_b64, "pool_version": pool_version, "offset": off, }, ) ) piece = _unb64(chunk["slice_b64"]) if not piece and off < total: raise RuntimeError("PIR answer short read") buf.extend(piece) off += len(piece) if len(buf) != total: raise RuntimeError( f"PIR answer size mismatch (got {len(buf)}, expected {total})" ) ans_bytes = bytes(buf) raw = bytes(pir_client.decode(ans_bytes, state)) row = json.loads(raw.rstrip(b"\x00").decode("utf-8")) if row.get("v") != 2: raise RuntimeError("unsupported PIR row version") ct_b64 = row.get("ciphertext_b64", "") expect_digest = row.get("ciphertext_sha256", "") actual = hashlib.sha256(_unb64(ct_b64)).hexdigest() if ct_b64 else "" if actual != expect_digest: raise RuntimeError("PIR row ciphertext digest mismatch") return { "eph_pk_b64": row.get("eph_pk_b64", ""), "ciphertext_b64": ct_b64, "to_tag_b64": row.get("to_tag_b64", ""), "claimed": row.get("claimed", False), } # ── Public operations ──────────────────────────────────────────────── def create_registry(userid: str, server: str, role_names: list[str]) -> str: identity = store.load_identity(userid) admin_mat = store.new_admin_material() bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_mat) role_entries = [(zkac.role_id(name), bbs_pk, 1) for name in role_names] state = zkac.RegistryState.build( bbs_pk, identity["issuance_pk"], 1, b"\x00" * 32, role_entries, ) state_bytes = state.serialize() state_cert = state.certify(admin_cred) registry_id = state.registry_id() resp = _mgmt_single(userid, server, { "cmd": "create_registry", "state_bytes_b64": _b64(state_bytes), "state_cert_b64": _b64(bytes(state_cert)), }) rid_hex = resp["registry_id"] store.save_admin(userid, rid_hex, { "server": server, "roles": role_names, **admin_mat, }) return rid_hex def update_registry(userid: str, server: str, registry_id_hex: str, add_roles: list[str]): admin_data = store.load_admin(userid, registry_id_hex) bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_data) identity = store.load_identity(userid) cur = _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) old_state = zkac.RegistryState.deserialize(_unb64(cur["state_bytes_b64"])) prev_hash = old_state.state_hash() new_version = old_state.version() + 1 old_roles = admin_data.get("roles", []) all_roles = list(old_roles) + [r for r in add_roles if r not in old_roles] role_entries = [(zkac.role_id(name), bbs_pk, 1) for name in all_roles] new_state = zkac.RegistryState.build( bbs_pk, identity["issuance_pk"], new_version, bytes(prev_hash), role_entries, ) new_cert = new_state.certify(admin_cred) _mgmt_single(userid, server, { "cmd": "update_registry", "registry_id": registry_id_hex, "state_bytes_b64": _b64(new_state.serialize()), "state_cert_b64": _b64(bytes(new_cert)), }) admin_data["roles"] = all_roles store.save_admin(userid, registry_id_hex, admin_data) def get_registry(userid: str, server: str, registry_id_hex: str) -> dict: return _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) def list_own_registries(userid: str) -> list[dict]: result = [] for rid in store.list_admin_registries(userid): data = store.load_admin(userid, rid) result.append({ "registry_id": rid, "server": data.get("server", "?"), "roles": data.get("roles", []), }) return result def grant(userid: str, server: str, registry_id_hex: str, role_name: str, recipient_pk_hex: str) -> tuple[str, int]: admin_data = store.load_admin(userid, registry_id_hex) roles = admin_data.get("roles", []) if role_name not in roles: raise RuntimeError(f"role {role_name!r} not in registry (have: {roles})") bbs_issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_data) role_rid = zkac.role_id(role_name) epoch = 1 req = zkac.prepare_blind_request() blind_sig = bbs_issuer.issue_blind(req.commitment_with_proof(), role_rid, epoch) payload = json.dumps({ "registry_id": registry_id_hex, "role_name": role_name, "epoch": epoch, "issuer_pk_b64": _b64(bbs_pk.to_bytes()), "blind_sig_b64": _b64(blind_sig), "member_secret_b64": _b64(req.member_secret()), "prover_blind_b64": _b64(req.prover_blind()), }).encode() recipient_pk = bytes.fromhex(recipient_pk_hex) eph_kp = zkac.IssuanceKeypair() ciphertext = eph_kp.encrypt(recipient_pk, payload) to_tag = zkac.grant_detection_tag(eph_kp.secret_bytes(), recipient_pk) sock, framed = _mgmt_connect(userid, server) try: transcript_hash = bytes(framed.session.transcript_hash()) admin_proof = admin_cred.present(transcript_hash) resp = _ok(_mgmt_cmd(framed, { "cmd": "post_grant", "registry_id": registry_id_hex, "eph_pk_b64": _b64(eph_kp.public_key_bytes()), "ciphertext_b64": _b64(ciphertext), "to_tag_b64": _b64(to_tag), "admin_proof_b64": _b64(admin_proof), })) finally: sock.close() return "inline-pir-row", resp.get("pool_index", -1) def _match_tags(userid: str, tags: list[dict]) -> list[int]: """Return pool indices whose detection tag matches our issuance key.""" identity = store.load_identity(userid) receiver_sk = identity["issuance_sk"] matches = [] for idx, entry in enumerate(tags): eph_pk_b64 = entry.get("eph_pk_b64", "") to_tag_b64 = entry.get("to_tag_b64", "") if not eph_pk_b64 or not to_tag_b64: continue eph_pk = _unb64(eph_pk_b64) expected = zkac.grant_detection_tag(receiver_sk, eph_pk) if _unb64(to_tag_b64) == bytes(expected): matches.append(idx) return matches def list_pending(userid: str, server: str) -> list[dict]: """Discover pending grants via detection tags, then PIR-fetch full rows.""" identity = store.load_identity(userid) receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"]) info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) sock, framed = _mgmt_connect(userid, server) try: tags_resp = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"})) tags = tags_resp["tags"] matches = _match_tags(userid, tags) if not matches: return [] pir_cl, pv = _pir_client(userid, framed, server) results = [] for idx in matches: try: row = _fetch_row(userid, framed, server, pir_cl, pv, idx) except Exception: continue if row.get("claimed"): continue try: eph_pk = _unb64(row["eph_pk_b64"]) ct = _unb64(row["ciphertext_b64"]) plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) results.append({ "pool_index": idx, "registry_id": plaintext.get("registry_id", "?"), "role_name": plaintext.get("role_name", "?"), }) except Exception: results.append({ "pool_index": idx, "registry_id": "?", "role_name": "(undecryptable)", }) return results finally: sock.close() def collect( userid: str, spec: str, *, pool_index: int | None = None, ) -> dict: server, registry_id_hex, role_name = parse_spec(spec) identity = store.load_identity(userid) receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"]) info = _mgmt_single(userid, server, {"cmd": "server_info"}) store.pin_server(userid, server, info["server_public_key_b64"]) sock, framed = _mgmt_connect(userid, server) try: if pool_index is None: tags_resp = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"})) tags = tags_resp["tags"] matches = _match_tags(userid, tags) if not matches: raise RuntimeError("no matching grants found in pool") pir_cl, pv = _pir_client(userid, framed, server) found = None for idx in matches: try: row = _fetch_row(userid, framed, server, pir_cl, pv, idx) except Exception: continue if row.get("claimed"): continue try: eph_pk = _unb64(row["eph_pk_b64"]) ct = _unb64(row["ciphertext_b64"]) plaintext = json.loads(receiver_kp.decrypt(eph_pk, ct)) except Exception: continue if (plaintext.get("registry_id") == registry_id_hex and plaintext.get("role_name") == role_name): found = (idx, row, plaintext) break if found is None: raise RuntimeError( f"no unclaimed grant for {registry_id_hex}:{role_name} in pool" ) pool_index, target_row, target_payload = found else: pir_cl, pv = _pir_client(userid, framed, server) target_row = _fetch_row(userid, framed, server, pir_cl, pv, pool_index) if target_row.get("claimed"): raise RuntimeError("grant row is already claimed") try: eph_pk = _unb64(target_row["eph_pk_b64"]) ct = _unb64(target_row["ciphertext_b64"]) target_payload = json.loads(receiver_kp.decrypt(eph_pk, ct)) except Exception as exc: raise RuntimeError("PIR row did not decrypt for this user") from exc if (target_payload.get("registry_id") != registry_id_hex or target_payload.get("role_name") != role_name): raise RuntimeError( "PIR row does not match this collect spec" ) finally: sock.close() _ = _mgmt_single(userid, server, { "cmd": "get_registry", "registry_id": registry_id_hex, }) cred_data = { "blind_sig_b64": target_payload["blind_sig_b64"], "member_secret_b64": target_payload["member_secret_b64"], "prover_blind_b64": target_payload["prover_blind_b64"], "role_name": role_name, "epoch": target_payload["epoch"], "issuer_pk_b64": target_payload["issuer_pk_b64"], } cred = store.reconstruct_credential(cred_data) cred.present(b"self-test") store.save_credential(userid, registry_id_hex, role_name, cred_data) return {"registry_id": registry_id_hex, "role": role_name, "server": server} def authenticate(userid: str, registry_id_hex: str, role_name: str, server: str | None = None) -> dict: admin_data = None try: admin_data = store.load_admin(userid, registry_id_hex) except FileNotFoundError: pass if server is None: if admin_data and admin_data.get("server"): server = admin_data["server"] else: raise RuntimeError("server address required (--server host:port)") cred_data = store.load_credential_data(userid, registry_id_hex, role_name) cred = store.reconstruct_credential(cred_data) server_pk = _resolve_server_pk(userid, server) node = zkac.Node(zkac.Keypair()) host, port = _parse_server(server) sock = socket.create_connection((host, port)) try: session = client_handshake_anon(sock, node, server_pk) framed = FramedSession(sock, session) transcript_hash = bytes(session.transcript_hash()) auth_proof = cred.present(transcript_hash) role_rid = zkac.role_id(role_name) framed.send(json.dumps({ "op": "auth", "registry_id": registry_id_hex, "role_id": role_rid.hex(), "bbs_auth_b64": _b64(auth_proof), }).encode()) return json.loads(framed.recv()) finally: sock.close()