464 lines
16 KiB
Python
464 lines
16 KiB
Python
"""Client-side operations over a unified encrypted channel (per local user id)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import json
|
|
import socket
|
|
from pathlib import Path
|
|
|
|
import zkac
|
|
from zkac.tcp import FramedSession, client_handshake_anon
|
|
|
|
from . import store
|
|
|
|
|
|
def _b64(data: bytes) -> str:
    """Encode *data* as standard (RFC 4648, padded) base64 text."""
    encoded = base64.standard_b64encode(data)
    return str(encoded, "ascii")
|
|
|
|
|
|
def _unb64(s: str) -> bytes:
    """Decode standard base64 text *s* back into raw bytes."""
    decoded = base64.standard_b64decode(s)
    return decoded
|
|
|
|
|
|
def _parse_server(server: str) -> tuple[str, int]:
    """Split a ``host:port`` address into ``(host, port)``.

    The split happens at the last ``:``, so unbracketed IPv6 hosts such as
    ``::1:9000`` still work. A bracketed IPv6 literal (``[::1]:9000``) has its
    brackets removed, because the socket layer expects the bare address. An
    empty host (``:9000`` or just ``9000``) defaults to ``127.0.0.1``.

    Raises:
        ValueError: if the port part is not an integer in 0..65535.
    """
    host, _, port_text = server.rpartition(":")
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    try:
        port = int(port_text)
    except ValueError as exc:
        raise ValueError(
            f"invalid server address {server!r}, expected host:port"
        ) from exc
    if not 0 <= port <= 65535:
        raise ValueError(f"invalid port in server address {server!r}")
    return host or "127.0.0.1", port
|
|
|
|
|
|
def parse_spec(spec: str) -> tuple[str, str, str]:
    """Parse 'host:port:registry_id:role' into (server, registry_id, role).

    The last two ``:``-separated fields are the registry id and the role.
    Everything before them is returned untouched as the server address.

    Raises:
        ValueError: if there are fewer than three fields or any field is empty.
    """
    parts = spec.rsplit(":", 2)
    # Reject empty fields up front (e.g. "host:port::role"); otherwise they
    # only surface later as confusing lookup or connection errors.
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"invalid spec {spec!r}, expected host:port:registry_id:role")
    server, registry_id, role = parts
    return server, registry_id, role
|
|
|
|
|
|
def _resolve_server_pk(userid: str, server: str) -> zkac.PublicKey:
    """Return the public key *userid* has pinned for *server*.

    Raises:
        RuntimeError: if no key is pinned. The message tells the user how to
            pin one.
    """
    pinned = store.load_server_pin(userid, server)
    if pinned is not None:
        raw_key = _unb64(pinned["server_public_key_b64"])
        return zkac.PublicKey.from_bytes(raw_key)
    raise RuntimeError(
        f"no pinned key for {server}; run: zkac-node server pin {userid} {server} --key <hex>"
    )
|
|
|
|
|
|
def _mgmt_connect(userid: str, server: str) -> tuple[socket.socket, FramedSession]:
    """Open an anonymous encrypted management session to *server*.

    The handshake uses a freshly generated node keypair and the server key
    pinned for *userid*. An ``{"op": "mgmt"}`` opener is sent before returning.
    The caller owns the returned socket and must close it.

    Raises:
        RuntimeError: if no server key is pinned for *server*.
        OSError: if the connection cannot be established.
    """
    host, port = _parse_server(server)
    # Resolve the pin before dialling, so a missing pin never opens a socket.
    server_pk = _resolve_server_pk(userid, server)
    sock = socket.create_connection((host, port))
    try:
        node = zkac.Node(zkac.Keypair())
        session = client_handshake_anon(sock, node, server_pk)
        framed = FramedSession(sock, session)
        framed.send(json.dumps({"op": "mgmt"}).encode())
    except BaseException:
        # The caller never receives the socket on failure, so close it here.
        sock.close()
        raise
    return sock, framed
|
|
|
|
|
|
def _mgmt_cmd(framed: FramedSession, cmd: dict) -> dict:
    """Send *cmd* as one JSON frame over *framed* and return the decoded JSON reply.

    The reply is returned unchecked; wrap the call in ``_ok`` to raise on
    server-reported errors.
    """
    request_frame = json.dumps(cmd).encode()
    framed.send(request_frame)
    reply_frame = framed.recv()
    return json.loads(reply_frame)
|
|
|
|
|
|
def _mgmt_single(userid: str, server: str, cmd: dict) -> dict:
    """Run a single management command on a fresh connection.

    The connection is closed afterwards whether or not the command succeeds.

    Raises:
        RuntimeError: if the server reply carries an ``error``.
    """
    sock, framed = _mgmt_connect(userid, server)
    with sock:
        reply = _mgmt_cmd(framed, cmd)
        return _ok(reply)
|
|
|
|
|
|
def _ok(resp: dict) -> dict:
    """Return *resp* unchanged unless it has a truthy ``error`` field.

    Raises:
        RuntimeError: carrying the server-supplied error value.
    """
    error = resp.get("error")
    if not error:
        return resp
    raise RuntimeError(error)
|
|
|
|
|
|
# ── PIR hint cache ───────────────────────────────────────────────────
|
|
|
|
def _cache_dir(userid: str) -> Path:
    """Return *userid*'s ``pir_cache`` directory, creating it (and parents) if missing."""
    cache_path = store.user_dir(userid) / "pir_cache"
    cache_path.mkdir(exist_ok=True, parents=True)
    return cache_path
|
|
|
|
|
|
def _server_cache_key(server: str) -> str:
    """Turn a ``host:port`` address into a file-name stem by swapping ``:`` for ``_``."""
    return "_".join(server.split(":"))
|
|
|
|
|
|
def _load_cached_hints(userid: str, server: str, pool_version: str) -> bytes | None:
    """Return cached PIR hints for *server* if they were saved for *pool_version*.

    Returns:
        The raw hint bytes, or None on a cache miss. A miss covers: either
        cache file missing or unreadable, metadata that is not a JSON object,
        or metadata recorded for a different pool version.
    """
    cache = _cache_dir(userid)
    key = _server_cache_key(server)
    meta_path = cache / f"{key}.json"
    bin_path = cache / f"{key}.bin"
    try:
        meta = json.loads(meta_path.read_text())
        if not isinstance(meta, dict) or meta.get("pool_version") != pool_version:
            return None
        return bin_path.read_bytes()
    except (OSError, ValueError):
        # Missing or corrupt cache files (ValueError covers bad JSON/UTF-8)
        # are just a miss; the caller re-downloads the hints.
        return None
|
|
|
|
|
|
def _save_cached_hints(userid: str, server: str, pool_version: str,
                       n_records: int, record_bytes: int, hints_bytes: bytes):
    """Cache PIR hints for *server*, tagged with the pool version they belong to.

    The ``.json`` metadata holds the pool version that _load_cached_hints
    checks, so it acts as a commit marker. The steps are:

    1. Remove the old metadata.
    2. Write the ``.bin`` hints.
    3. Write fresh metadata.

    Each file is written to a ``.tmp`` sibling and renamed into place
    (``Path.replace``; atomic on POSIX). An interrupted save therefore leaves
    either no metadata (a cache miss) or a matching pair. It never leaves new
    metadata next to old hints, or the reverse.
    """
    cache = _cache_dir(userid)
    key = _server_cache_key(server)
    meta_path = cache / f"{key}.json"
    bin_path = cache / f"{key}.bin"
    meta = {"pool_version": pool_version, "n_records": n_records, "record_bytes": record_bytes}

    # Invalidate first, so a crash below reads as a miss rather than a mismatch.
    try:
        meta_path.unlink()
    except FileNotFoundError:
        pass

    tmp_bin = cache / f"{key}.bin.tmp"
    tmp_bin.write_bytes(hints_bytes)
    tmp_bin.replace(bin_path)

    tmp_meta = cache / f"{key}.json.tmp"
    tmp_meta.write_text(json.dumps(meta))
    tmp_meta.replace(meta_path)
|
|
|
|
|
|
def _pir_client(userid: str, framed: FramedSession, server: str) -> tuple[zkac.PirClient, str]:
    """Fetch pool_info, load or refresh hints, return (PirClient, pool_version).

    Cached hints are reused when they were saved for the pool version the
    server currently reports. Otherwise fresh hints are downloaded over
    *framed* and cached.

    Raises:
        RuntimeError: if the server reports an error, or if the pool changes
            again while fresh hints are being fetched (the caller may retry).
    """
    info = _ok(_mgmt_cmd(framed, {"cmd": "pool_info"}))
    n, rb, pv = info["n"], info["record_bytes"], info["pool_version"]

    cached = _load_cached_hints(userid, server, pv)
    if cached is not None:
        return zkac.PirClient(cached, n, rb), pv

    resp = _ok(_mgmt_cmd(framed, {"cmd": "pir_hints"}))
    hints_bytes = _unb64(resp["hints_b64"])
    hints_pv = resp["pool_version"]
    if hints_pv != pv:
        # The pool changed between pool_info and pir_hints, so n/record_bytes
        # describe a different version than these hints. Re-read them, and
        # refuse to build (and cache) a client from mismatched parameters.
        info = _ok(_mgmt_cmd(framed, {"cmd": "pool_info"}))
        if info["pool_version"] != hints_pv:
            raise RuntimeError("pool changed while fetching PIR hints; retry")
        n, rb = info["n"], info["record_bytes"]

    _save_cached_hints(userid, server, hints_pv, n, rb, hints_bytes)
    return zkac.PirClient(hints_bytes, n, rb), hints_pv
|
|
|
|
|
|
def _fetch_row(
    userid: str, framed: FramedSession, server: str,
    pir_client: zkac.PirClient, pool_version: str, pool_index: int,
) -> dict:
    """Retrieve pool row *pool_index* via a PIR query and decode it as JSON.

    ``userid`` and ``server`` are accepted but not used by this function.
    Trailing NUL bytes are stripped from the decoded record before JSON
    parsing (presumably fixed-width padding; confirm against the server).

    Raises:
        RuntimeError: if the server reports an error.
    """
    query, query_state = pir_client.query(pool_index)
    request = {
        "cmd": "pir_query",
        "query_b64": _b64(query),
        "pool_version": pool_version,
    }
    reply = _ok(_mgmt_cmd(framed, request))
    answer = _unb64(reply["answer_b64"])
    record = bytes(pir_client.decode(answer, query_state))
    text = record.rstrip(b"\x00").decode("utf-8")
    return json.loads(text)
|
|
|
|
|
|
# ── Public operations ────────────────────────────────────────────────
|
|
|
|
def create_registry(userid: str, server: str, role_names: list[str]) -> str:
    """Create a registry on *server* administered by *userid* with *role_names*.

    Steps:

    1. Generate fresh admin material.
    2. Build and certify a version-1 registry state.
    3. Send it to the server.
    4. Record the admin material locally under the registry id the server
       returns.

    Returns:
        The hex registry id reported by the server.

    Raises:
        RuntimeError: if the server reports an error.
    """
    identity = store.load_identity(userid)
    admin_mat = store.new_admin_material()
    _issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_mat)

    entries = [(zkac.role_id(name), bbs_pk, 1) for name in role_names]
    # Version 1 has no predecessor, so the previous-state hash is all zeros.
    state = zkac.RegistryState.build(
        bbs_pk, identity["issuance_pk"], 1, bytes(32), entries,
    )
    serialized = state.serialize()
    certificate = state.certify(admin_cred)
    # NOTE(review): the locally derived id is computed but never compared
    # with the id the server returns below; consider checking they agree.
    state.registry_id()

    request = {
        "cmd": "create_registry",
        "state_bytes_b64": _b64(serialized),
        "state_cert_b64": _b64(bytes(certificate)),
    }
    rid_hex = _mgmt_single(userid, server, request)["registry_id"]

    admin_record = {"server": server, "roles": role_names}
    admin_record.update(admin_mat)
    store.save_admin(userid, rid_hex, admin_record)
    return rid_hex
|
|
|
|
|
|
def update_registry(userid: str, server: str, registry_id_hex: str, add_roles: list[str]):
    """Publish the next version of a registry *userid* administers, adding *add_roles*.

    The current state is fetched from *server*. The new state:

    - has version ``old + 1``;
    - carries the old state's hash as its previous-state hash;
    - lists the locally recorded roles followed by the new ones. A role name
      appears only once, even if it was already present or is repeated
      within *add_roles*.

    The local admin record's ``roles`` is updated once the server accepts
    the new state.

    Raises:
        RuntimeError: if the server reports an error.
    """
    admin_data = store.load_admin(userid, registry_id_hex)
    _, bbs_pk, admin_cred = store.reconstruct_admin(admin_data)
    identity = store.load_identity(userid)

    cur = _mgmt_single(userid, server, {
        "cmd": "get_registry", "registry_id": registry_id_hex,
    })
    old_state = zkac.RegistryState.deserialize(_unb64(cur["state_bytes_b64"]))
    prev_hash = old_state.state_hash()
    new_version = old_state.version() + 1

    # dict.fromkeys keeps the first occurrence of each name, in order. This
    # drops duplicates both against existing roles and within add_roles.
    all_roles = list(dict.fromkeys([*admin_data.get("roles", []), *add_roles]))
    role_entries = [(zkac.role_id(name), bbs_pk, 1) for name in all_roles]

    new_state = zkac.RegistryState.build(
        bbs_pk, identity["issuance_pk"], new_version, bytes(prev_hash), role_entries,
    )
    new_cert = new_state.certify(admin_cred)

    _mgmt_single(userid, server, {
        "cmd": "update_registry",
        "registry_id": registry_id_hex,
        "state_bytes_b64": _b64(new_state.serialize()),
        "state_cert_b64": _b64(bytes(new_cert)),
    })

    admin_data["roles"] = all_roles
    store.save_admin(userid, registry_id_hex, admin_data)
|
|
|
|
|
|
def get_registry(userid: str, server: str, registry_id_hex: str) -> dict:
    """Fetch the server's current record for registry *registry_id_hex*.

    Raises:
        RuntimeError: if the server reports an error.
    """
    request = {"cmd": "get_registry", "registry_id": registry_id_hex}
    return _mgmt_single(userid, server, request)
|
|
|
|
|
|
def list_own_registries(userid: str) -> list[dict]:
    """Summarise the registries *userid* administers, from local admin records only.

    Each summary has these keys:

    - ``registry_id``
    - ``server`` (``"?"`` if not recorded)
    - ``roles`` (``[]`` if not recorded)
    """
    def summarise(rid):
        record = store.load_admin(userid, rid)
        return {
            "registry_id": rid,
            "server": record.get("server", "?"),
            "roles": record.get("roles", []),
        }

    return [summarise(rid) for rid in store.list_admin_registries(userid)]
|
|
|
|
|
|
def grant(userid: str, server: str, registry_id_hex: str, role_name: str,
          recipient_pk_hex: str) -> tuple[str, int]:
    """Blind-issue a *role_name* credential and post it, encrypted, to the grant pool.

    How the grant is built:

    - The blind signature and the request secrets are bundled into a JSON
      payload.
    - The payload is encrypted to *recipient_pk_hex* under a fresh ephemeral
      issuance keypair.
    - A detection tag is attached so the recipient can find the grant.
    - The post is authorised by an admin presentation made over the
      management session's transcript hash.

    Returns:
        ``(grant_id, pool_index)`` from the server. ``pool_index`` is -1 if
        the server omits it.

    Raises:
        RuntimeError: if *role_name* is not among the locally recorded roles,
            or the server reports an error.
        ValueError: if *recipient_pk_hex* is not valid hex.
    """
    admin_data = store.load_admin(userid, registry_id_hex)
    known_roles = admin_data.get("roles", [])
    if role_name not in known_roles:
        raise RuntimeError(f"role {role_name!r} not in registry (have: {known_roles})")

    issuer, bbs_pk, admin_cred = store.reconstruct_admin(admin_data)
    role_rid = zkac.role_id(role_name)
    epoch = 1

    blind_req = zkac.prepare_blind_request()
    blind_sig = issuer.issue_blind(blind_req.commitment_with_proof(), role_rid, epoch)

    grant_body = {
        "registry_id": registry_id_hex,
        "role_name": role_name,
        "epoch": epoch,
        "issuer_pk_b64": _b64(bbs_pk.to_bytes()),
        "blind_sig_b64": _b64(blind_sig),
        "member_secret_b64": _b64(blind_req.member_secret()),
        "prover_blind_b64": _b64(blind_req.prover_blind()),
    }
    plaintext = json.dumps(grant_body).encode()

    recipient_pk = bytes.fromhex(recipient_pk_hex)
    ephemeral = zkac.IssuanceKeypair()
    ciphertext = ephemeral.encrypt(recipient_pk, plaintext)
    to_tag = zkac.grant_detection_tag(ephemeral.secret_bytes(), recipient_pk)

    sock, framed = _mgmt_connect(userid, server)
    with sock:
        admin_proof = admin_cred.present(bytes(framed.session.transcript_hash()))
        reply = _ok(_mgmt_cmd(framed, {
            "cmd": "post_grant",
            "registry_id": registry_id_hex,
            "eph_pk_b64": _b64(ephemeral.public_key_bytes()),
            "ciphertext_b64": _b64(ciphertext),
            "to_tag_b64": _b64(to_tag),
            "admin_proof_b64": _b64(admin_proof),
        }))
    return reply["grant_id"], reply.get("pool_index", -1)
|
|
|
|
|
|
def _match_tags(userid: str, tags: list[dict]) -> list[int]:
    """Return pool indices whose detection tag matches our issuance key.

    For each entry, the tag is recomputed from our issuance secret and the
    entry's ephemeral public key, then compared with the published tag. The
    pool holds entries posted by other users, so these entries are skipped
    rather than aborting the scan:

    - entries that are not objects;
    - entries missing either field;
    - entries whose fields fail to decode or to produce a tag.
    """
    receiver_sk = store.load_identity(userid)["issuance_sk"]
    matches = []
    for idx, entry in enumerate(tags):
        if not isinstance(entry, dict):
            continue
        eph_pk_b64 = entry.get("eph_pk_b64", "")
        to_tag_b64 = entry.get("to_tag_b64", "")
        if not eph_pk_b64 or not to_tag_b64:
            continue
        try:
            eph_pk = _unb64(eph_pk_b64)
            published = _unb64(to_tag_b64)
            expected = bytes(zkac.grant_detection_tag(receiver_sk, eph_pk))
        except Exception:
            # Broad on purpose: the zkac binding's exception types aren't
            # visible here. One bad entry must not hide every other grant.
            continue
        if published == expected:
            matches.append(idx)
    return matches
|
|
|
|
|
|
def list_pending(userid: str, server: str) -> list[dict]:
    """Discover pending grants via detection tags, then PIR-fetch matches.

    Before scanning, the pinned key for *server* is refreshed from its
    ``server_info`` reply. For each tag match, the row is fetched via PIR:

    - Rows that fail to fetch are skipped.
    - Rows flagged ``claimed`` are skipped.
    - Rows that decrypt yield ``grant_id``, ``pool_index``, ``registry_id``
      and ``role_name``.
    - Rows that cannot be decrypted or parsed are still listed, with
      ``role_name`` set to ``"(undecryptable)"``.
    """
    identity = store.load_identity(userid)
    receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"])

    server_info = _mgmt_single(userid, server, {"cmd": "server_info"})
    store.pin_server(userid, server, server_info["server_public_key_b64"])

    sock, framed = _mgmt_connect(userid, server)
    with sock:
        pool_tags = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"}))["tags"]
        candidates = _match_tags(userid, pool_tags)
        if not candidates:
            return []

        pir, pool_version = _pir_client(userid, framed, server)
        pending = []
        for index in candidates:
            try:
                row = _fetch_row(userid, framed, server, pir, pool_version, index)
            except Exception:
                continue  # best-effort: an unreadable row is simply not listed
            if row.get("claimed"):
                continue
            try:
                payload = json.loads(receiver_kp.decrypt(
                    _unb64(row["eph_pk_b64"]), _unb64(row["ciphertext_b64"]),
                ))
                summary = {
                    "grant_id": row["grant_id"],
                    "pool_index": index,
                    "registry_id": payload.get("registry_id", "?"),
                    "role_name": payload.get("role_name", "?"),
                }
            except Exception:
                summary = {
                    "grant_id": row.get("grant_id", "?"),
                    "pool_index": index,
                    "registry_id": "?",
                    "role_name": "(undecryptable)",
                }
            pending.append(summary)
        return pending
|
|
|
|
|
|
def _decrypt_grant_payload(receiver_kp, row: dict) -> dict:
    """Decrypt a fetched grant row with *receiver_kp* and parse its JSON payload.

    Raises:
        Whatever base64 decoding, decryption or JSON parsing raises.
        ValueError: if the payload is not a JSON object.
    """
    eph_pk = _unb64(row["eph_pk_b64"])
    ct = _unb64(row["ciphertext_b64"])
    payload = json.loads(receiver_kp.decrypt(eph_pk, ct))
    if not isinstance(payload, dict):
        raise ValueError("grant payload is not a JSON object")
    return payload


def collect(
    userid: str,
    spec: str,
    *,
    pool_index: int | None = None,
) -> dict:
    """Claim a pending grant for ``registry_id:role`` and store the credential.

    *spec* is ``host:port:registry_id:role`` (see parse_spec). How the grant
    is located:

    - If *pool_index* is None, the pool's detection tags are scanned. The
      first unclaimed row that decrypts to the requested registry and role
      is used.
    - Otherwise only that row is fetched via PIR and checked.

    Before the grant is claimed on the server, the credential is
    reconstructed and self-tested and the registry is looked up. This way a
    malformed grant or a missing registry fails the call without leaving the
    grant claimed and no usable local credential.

    Returns:
        ``{"registry_id": ..., "role": ..., "server": ...}``.

    Raises:
        ValueError: if *spec* is malformed.
        RuntimeError: if no suitable grant is found, the chosen row is
            claimed, undecryptable or for a different spec, or the server
            reports an error.
    """
    server, registry_id_hex, role_name = parse_spec(spec)
    identity = store.load_identity(userid)
    receiver_kp = zkac.IssuanceKeypair.from_secret(identity["issuance_sk"])

    info = _mgmt_single(userid, server, {"cmd": "server_info"})
    store.pin_server(userid, server, info["server_public_key_b64"])

    sock, framed = _mgmt_connect(userid, server)
    try:
        if pool_index is None:
            tags = _ok(_mgmt_cmd(framed, {"cmd": "pool_tags"}))["tags"]
            matches = _match_tags(userid, tags)
            if not matches:
                raise RuntimeError("no matching grants found in pool")

            pir_cl, pv = _pir_client(userid, framed, server)
            found = None
            for idx in matches:
                try:
                    row = _fetch_row(userid, framed, server, pir_cl, pv, idx)
                except Exception:
                    continue
                if row.get("claimed"):
                    continue
                try:
                    payload = _decrypt_grant_payload(receiver_kp, row)
                except Exception:
                    continue
                if (payload.get("registry_id") == registry_id_hex
                        and payload.get("role_name") == role_name):
                    found = (row, payload)
                    break
            if found is None:
                raise RuntimeError(
                    f"no unclaimed grant for {registry_id_hex}:{role_name} in pool"
                )
            target_row, target_payload = found
        else:
            pir_cl, pv = _pir_client(userid, framed, server)
            target_row = _fetch_row(userid, framed, server, pir_cl, pv, pool_index)
            if target_row.get("claimed"):
                raise RuntimeError("grant row is already claimed")
            try:
                target_payload = _decrypt_grant_payload(receiver_kp, target_row)
            except Exception as exc:
                raise RuntimeError("PIR row did not decrypt for this user") from exc
            if (target_payload.get("registry_id") != registry_id_hex or
                    target_payload.get("role_name") != role_name):
                raise RuntimeError(
                    "PIR row does not match this collect spec"
                )
    finally:
        sock.close()

    target_grant_id = target_row["grant_id"]
    cred_data = {
        "blind_sig_b64": target_payload["blind_sig_b64"],
        "member_secret_b64": target_payload["member_secret_b64"],
        "prover_blind_b64": target_payload["prover_blind_b64"],
        "role_name": role_name,
        "epoch": target_payload["epoch"],
        "issuer_pk_b64": target_payload["issuer_pk_b64"],
    }

    # Validate everything we can before the irreversible claim.
    cred = store.reconstruct_credential(cred_data)
    cred.present(b"self-test")
    # The reply is not otherwise used. _mgmt_single raises if the server
    # reports an error (e.g. the registry lookup fails).
    _mgmt_single(userid, server, {
        "cmd": "get_registry", "registry_id": registry_id_hex,
    })

    _mgmt_single(userid, server, {
        "cmd": "claim_grant",
        "grant_id": target_grant_id,
    })
    store.save_credential(userid, registry_id_hex, role_name, cred_data)

    return {"registry_id": registry_id_hex, "role": role_name, "server": server}
|
|
|
|
|
|
def authenticate(userid: str, registry_id_hex: str, role_name: str,
                 server: str | None = None) -> dict:
    """Authenticate to *server* as a holder of *role_name* in a registry.

    Server selection: if *server* is omitted, the address stored in the local
    admin record for the registry is used, when there is one.

    Protocol: a fresh anonymous session is opened, the stored credential is
    presented over the session's transcript hash, and an ``auth`` request is
    sent. The server's JSON reply is returned as-is; it is not passed through
    ``_ok``.

    Raises:
        RuntimeError: if no server is given and none is recorded locally, or
            no key is pinned for the server.
    """
    try:
        admin_data = store.load_admin(userid, registry_id_hex)
    except FileNotFoundError:
        admin_data = None  # no local admin record for this registry

    if server is None:
        if not (admin_data and admin_data.get("server")):
            raise RuntimeError("server address required (--server host:port)")
        server = admin_data["server"]

    cred = store.reconstruct_credential(
        store.load_credential_data(userid, registry_id_hex, role_name)
    )
    server_pk = _resolve_server_pk(userid, server)
    node = zkac.Node(zkac.Keypair())
    host, port = _parse_server(server)

    with socket.create_connection((host, port)) as sock:
        session = client_handshake_anon(sock, node, server_pk)
        framed = FramedSession(sock, session)
        auth_proof = cred.present(bytes(session.transcript_hash()))
        framed.send(json.dumps({
            "op": "auth",
            "registry_id": registry_id_hex,
            "role_id": zkac.role_id(role_name).hex(),
            "bbs_auth_b64": _b64(auth_proof),
        }).encode())
        return json.loads(framed.recv())
|