From 4b2543fad2a350f8cc4ddcd8c6156c5e05c998aa Mon Sep 17 00:00:00 2001 From: everbarry Date: Wed, 15 Apr 2026 11:32:01 +0200 Subject: [PATCH] TCP/UDP demo --- demo/client_cli.py | 89 +++++++++++++++++ demo/creds/.gitignore | 2 + demo/server.py | 196 ++++++++++++++++++++++++++++++++++++++ demo/setup_demo.py | 79 +++++++++++++++ demo/static/index.html | 33 +++++++ python/zkac/tcp.py | 128 +++++++++++++++++++++++++ python/zkac/udp.py | 177 ++++++++++++++++++++++++++++++++++ src/credential/schnorr.rs | 25 +++++ src/python.rs | 19 ++++ tests/test_zkac.py | 7 ++ tests/test_zkac_tcp.py | 113 ++++++++++++++++++++++ tests/test_zkac_udp.py | 83 ++++++++++++++++ 12 files changed, 951 insertions(+) create mode 100644 demo/client_cli.py create mode 100644 demo/creds/.gitignore create mode 100644 demo/server.py create mode 100644 demo/setup_demo.py create mode 100644 demo/static/index.html create mode 100644 python/zkac/tcp.py create mode 100644 python/zkac/udp.py create mode 100644 tests/test_zkac_tcp.py create mode 100644 tests/test_zkac_udp.py diff --git a/demo/client_cli.py b/demo/client_cli.py new file mode 100644 index 0000000..5c46b85 --- /dev/null +++ b/demo/client_cli.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +ZKAC TCP client: load a member credential, complete handshake, request /api (encrypted JSON). + +The "API" is not HTTP on port 8765 — it is one JSON request inside the ZKAC session on --port (default 9876). +""" + +from __future__ import annotations + +import argparse +import base64 +import json +import socket +from pathlib import Path + +import zkac +from zkac.tcp import FramedSession, client_handshake + + +def load_credential(member_json: Path) -> zkac.Credential: + """Rebuild Credential from setup_demo.py output (same fields as zkac.Credential.finalize).""" + m = json.loads(member_json.read_text(encoding="utf-8")) + pk = zkac.BbsPublicKey.from_bytes(base64.b64decode(m["issuer_public_key_b64"])) + rid = bytes.fromhex(m["role_id_hex"]) + return zkac.Credential.finalize( + base64.b64decode(m["blind_sig_b64"]), + base64.b64decode(m["member_secret_b64"]), + base64.b64decode(m["prover_blind_b64"]), + rid, + int(m["epoch"]), + pk, + ) + + +def load_server_pk(creds_dir: Path) -> zkac.PublicKey: + """Pinned server identity: must match the Keypair used by server.py (from transport.json).""" + t = json.loads((creds_dir / "transport.json").read_text(encoding="utf-8")) + raw = base64.b64decode(t["server_public_key_b64"]) + return zkac.PublicKey.from_bytes(raw) + + +def main() -> None: + ap = argparse.ArgumentParser(description="ZKAC demo client (TCP + credential)") + ap.add_argument( + "--creds-dir", + type=Path, + default=Path(__file__).resolve().parent / "creds", + help="Directory with transport.json and member_*.json", + ) + ap.add_argument( + "--member", + type=Path, + help="Path to member_*.json (default: creds-dir/member_analyst.json)", + ) + ap.add_argument("--host", default="127.0.0.1") + ap.add_argument("--port", type=int, default=9876) + args = ap.parse_args() + + creds_dir: Path = args.creds_dir + member_path = args.member or (creds_dir / "member_analyst.json") + if not member_path.is_file(): + raise SystemExit(f"Missing member file: {member_path}") + + credential = load_credential(member_path) + server_pk = load_server_pk(creds_dir) + + # Ephemeral client transport identity (not the BBS+ member secret — that is inside credential). + client_kp = zkac.Keypair() + node = zkac.Node(client_kp) + + sock = socket.create_connection((args.host, args.port)) + try: + # X25519 + server Schnorr + BBS+ auth; returns symmetric Session. + session = client_handshake(sock, node, server_pk, credential) + framed = FramedSession(sock, session) + + # Logical GET /api: path is checked by server after decrypt. + request_obj = {"path": "/api"} + payload = json.dumps(request_obj).encode("utf-8") + framed.send(payload) + + reply = framed.recv().decode("utf-8") + print(json.dumps(json.loads(reply), indent=2)) + finally: + sock.close() + + +if __name__ == "__main__": + main() diff --git a/demo/creds/.gitignore b/demo/creds/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/demo/creds/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/demo/server.py b/demo/server.py new file mode 100644 index 0000000..08a355c --- /dev/null +++ b/demo/server.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +""" +HTTP static site + ZKAC TCP service. Authenticated /api is accessed over TCP with zkac.tcp +after setup_demo.py has created creds/. +""" + +from __future__ import annotations + +import argparse +import base64 +import json +import os +import threading +import http.server +import socket +import traceback +from pathlib import Path + +import zkac +from zkac.tcp import FramedSession, server_handshake + + +def load_registry(creds_dir: Path, epoch: int) -> zkac.RoleRegistry: + """Load issuer public key and register every demo role at the same epoch.""" + iss = json.loads((creds_dir / "issuer.json").read_text(encoding="utf-8")) + issuer_pk = zkac.BbsPublicKey.from_bytes( + base64.b64decode(iss["issuer_public_key_b64"]) + ) + reg = zkac.RoleRegistry() + for name in ("analyst", "operator"): + reg.register_role(zkac.role_id(name), issuer_pk, epoch) + return reg + + +def _role_debug_label(role_id: bytes) -> str: + """Map verified role_id bytes to a short label for logs (demo only).""" + for name in ("analyst", "operator"): + if role_id == zkac.role_id(name): + return name + return "unknown" + + +def api_body_for_role(role_id: bytes) -> dict: + """JSON returned for the logical /api resource after ZKAC auth; varies by credential role.""" + analyst = zkac.role_id("analyst") + operator = zkac.role_id("operator") + if role_id == analyst: + return { + "path": "/api", + "role": "analyst", + "datasets": ["summary", "aggregated_metrics"], + "note": "Analyst tier: aggregated data only.", + } + if role_id == operator: + return { + "path": "/api", + "role": "operator", + "datasets": ["summary", "aggregated_metrics", "raw_logs", "pii"], + "note": "Operator tier: full API slice including raw logs.", + } + return {"error": "unknown role", "path": "/api"} + + +def handle_zkac_client( + conn: socket.socket, + client_addr: tuple, + creds_dir: Path, + registry: zkac.RoleRegistry, +) -> None: + """ + One TCP connection: ZKAC handshake + BBS+ auth, then one framed JSON request and response. + Each handler rebuilds the server Node from persisted secret (Keypair is consumed by Node). + """ + peer = f"{client_addr[0]}:{client_addr[1]}" + print(f"[zkac] connect peer={peer}") + + try: + # Same long-term server identity every time; from_secret_key because Node consumes Keypair. + t = json.loads((creds_dir / "transport.json").read_text(encoding="utf-8")) + sk = base64.b64decode(t["server_secret_key_b64"]) + server_kp = zkac.Keypair.from_secret_key(sk) + node = zkac.Node(server_kp) + + session, role_id = server_handshake(conn, node, registry) + label = _role_debug_label(role_id) + print( + f"[zkac] handshake_ok peer={peer} role_id={role_id.hex()} role={label!r}" + ) + + framed = FramedSession(conn, session) + raw = framed.recv() + print( + f"[zkac] request peer={peer} plaintext_bytes={len(raw)} raw={raw!r}" + ) + + req = json.loads(raw.decode("utf-8")) + print(f"[zkac] request_json peer={peer} parsed={req!r}") + + path = req.get("path") + if path != "/api": + err_body = {"error": "unsupported path", "allowed": ["/api"], "got": path} + out = json.dumps(err_body).encode() + framed.send(out) + print( + f"[zkac] response peer={peer} status=reject path={path!r} response_bytes={len(out)}" + ) + return + + body = api_body_for_role(role_id) + out_bytes = json.dumps(body).encode() + framed.send(out_bytes) + print( + f"[zkac] response peer={peer} status=ok path=/api role={label!r} " + f"response_bytes={len(out_bytes)} body_keys={list(body.keys())}" + ) + except (ConnectionError, BrokenPipeError, OSError) as e: + print(f"[zkac] peer={peer} connection_error: {e!r}") + except (json.JSONDecodeError, ValueError) as e: + print(f"[zkac] peer={peer} protocol_error: {e!r}") + except Exception as e: + print(f"[zkac] peer={peer} unexpected_error: {e!r}") + traceback.print_exc() + finally: + conn.close() + print(f"[zkac] closed peer={peer}") + + +def run_http(host: str, port: int, static_root: Path) -> None: + # Process-wide CWD: only this thread should rely on relative paths after chdir. + os.chdir(static_root) + + class Handler(http.server.SimpleHTTPRequestHandler): + def log_message(self, fmt: str, *args) -> None: + # Default fmt is like '%s - - [%s] %s' — include client address for debugging. + try: + line = fmt % args if args else fmt + except (TypeError, ValueError): + line = f"{fmt} {args}" + peer_ip = self.client_address[0] if self.client_address else "?" + peer_port = self.client_address[1] if len(self.client_address) > 1 else "?" + print(f"[http] peer={peer_ip}:{peer_port} | {line.strip()}") + + http.server.HTTPServer((host, port), Handler).serve_forever() + + +def main() -> None: + ap = argparse.ArgumentParser(description="ZKAC demo HTTP + TCP server") + ap.add_argument( + "--creds-dir", + type=Path, + default=Path(__file__).resolve().parent / "creds", + ) + ap.add_argument("--http-host", default="127.0.0.1") + ap.add_argument("--http-port", type=int, default=8765) + ap.add_argument("--zkac-host", default="127.0.0.1") + ap.add_argument("--zkac-port", type=int, default=9876) + args = ap.parse_args() + creds_dir: Path = args.creds_dir + if not (creds_dir / "transport.json").is_file(): + raise SystemExit(f"Missing {creds_dir}/transport.json — run setup_demo.py first.") + + # Epoch must match the member files issued at setup (any member file is enough). + member = json.loads((creds_dir / "member_analyst.json").read_text(encoding="utf-8")) + epoch = int(member["epoch"]) + registry = load_registry(creds_dir, epoch) + + static_root = Path(__file__).resolve().parent / "static" + if not static_root.is_dir(): + raise SystemExit(f"Missing static directory: {static_root}") + + http_thread = threading.Thread( + target=run_http, + args=(args.http_host, args.http_port, static_root), + daemon=True, + ) + http_thread.start() + print( + f"HTTP http://{args.http_host}:{args.http_port}/ (static demo page)\n" + f"ZKAC {args.zkac_host}:{args.zkac_port} (authenticated /api over TCP)" + ) + + zkac_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + zkac_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + zkac_sock.bind((args.zkac_host, args.zkac_port)) + zkac_sock.listen(8) + while True: + conn, addr = zkac_sock.accept() + threading.Thread( + target=handle_zkac_client, + args=(conn, addr, creds_dir, registry), + daemon=True, + ).start() + + +if __name__ == "__main__": + main() diff --git a/demo/setup_demo.py b/demo/setup_demo.py new file mode 100644 index 0000000..95cae44 --- /dev/null +++ b/demo/setup_demo.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +""" +Generate demo credentials under creds/: issuer, server transport key, two member credentials. +Run once before starting the server. +""" + +from __future__ import annotations + +import argparse +import base64 +import json +from pathlib import Path + +import zkac + +# Human-readable role names; each becomes a 32-byte opaque role_id via zkac.role_id(). +# Must stay in sync with server.py (registry + api_body_for_role). +ROLES = ("analyst", "operator") + + +def main() -> None: + ap = argparse.ArgumentParser(description="Generate ZKAC demo credential files.") + ap.add_argument( + "--output-dir", + type=Path, + default=Path(__file__).resolve().parent / "creds", + help="Directory to write files (default: demo/creds)", + ) + args = ap.parse_args() + out: Path = args.output_dir + out.mkdir(parents=True, exist_ok=True) + + # BBS+ issuer: signs blind credentials; server only needs the public key in RoleRegistry. + issuer = zkac.BbsIssuer() + issuer_pk = issuer.public_key() + epoch = 1 + + # Long-term Ristretto identity for the TCP server (X25519 handshake + Schnorr identity proof). + server_kp = zkac.Keypair() + server_pk = server_kp.public_key() + + issuer_payload = { + "issuer_secret_key_b64": base64.b64encode(issuer.secret_key_bytes()).decode(), + "issuer_public_key_b64": base64.b64encode(issuer_pk.to_bytes()).decode(), + } + (out / "issuer.json").write_text(json.dumps(issuer_payload, indent=2), encoding="utf-8") + + transport_payload = { + "server_secret_key_b64": base64.b64encode(server_kp.secret_key_bytes()).decode(), + "server_public_key_b64": base64.b64encode(server_pk.to_bytes()).decode(), + } + (out / "transport.json").write_text(json.dumps(transport_payload, indent=2), encoding="utf-8") + + # One blind issuance per role: issuer never learns member_secret. + for role_name in ROLES: + rid = zkac.role_id(role_name) + req = zkac.prepare_blind_request() + blind_sig = issuer.issue_blind(req.commitment_with_proof(), rid, epoch) + member = { + "role_name": role_name, + "role_id_hex": rid.hex(), + "epoch": epoch, + "blind_sig_b64": base64.b64encode(blind_sig).decode(), + "member_secret_b64": base64.b64encode(req.member_secret()).decode(), + "prover_blind_b64": base64.b64encode(req.prover_blind()).decode(), + "issuer_public_key_b64": base64.b64encode(issuer_pk.to_bytes()).decode(), + } + (out / f"member_{role_name}.json").write_text( + json.dumps(member, indent=2), encoding="utf-8" + ) + + print(f"Wrote issuer, transport, and member files to {out}") + print( + f"Roles: {', '.join(ROLES)} — use member_{ROLES[0]}.json / member_{ROLES[1]}.json with client_cli.py" + ) + + +if __name__ == "__main__": + main() diff --git a/demo/static/index.html b/demo/static/index.html new file mode 100644 index 0000000..694f349 --- /dev/null +++ b/demo/static/index.html @@ -0,0 +1,33 @@ + + + + + + ZKAC demo + + + +

ZKAC demo

+

+ This page is served over normal HTTP. Role-based /api data is not on this port: + it is exposed only after a ZKAC session on the separate TCP port (BBS+ credential + encrypted transport). +

+ +

1. Generate credentials

+
python setup_demo.py
+

Creates creds/ with issuer keys, server transport keys, and two members: analyst and operator.

+ +

2. Start the server

+
python server.py
+

HTTP (this page) defaults to 127.0.0.1:8765. ZKAC TCP defaults to 127.0.0.1:9876.

+ +

3. CLI client

+
python client_cli.py --member creds/member_analyst.json
+python client_cli.py --member creds/member_operator.json
+

Each command runs a full handshake and requests {"path":"/api"}. The JSON response lists datasets allowed for that role.

+ + diff --git a/python/zkac/tcp.py b/python/zkac/tcp.py new file mode 100644 index 0000000..e01ded2 --- /dev/null +++ b/python/zkac/tcp.py @@ -0,0 +1,128 @@ +""" +Length-prefixed TCP framing for ZKAC handshakes and encrypted sessions. + +Wire format: each message is ``uint32_le(length) || payload`` with ``length`` +counting only ``payload`` bytes. Handshake payloads match the in-memory protocol +(32-byte init; server reply is ``response_msg || identity_proof``; then auth). +""" + +from __future__ import annotations + +import socket +import struct +from typing import TYPE_CHECKING, Tuple + +from zkac import MAX_BBS_AUTH_PROOF_BYTES + +if TYPE_CHECKING: + from zkac import Credential, Node, PublicKey, RoleRegistry, Session + +# Largest frame: BBS+ auth ciphertext (bound by library) plus handshake/AEAD slack. +MAX_TCP_FRAME_BYTES: int = MAX_BBS_AUTH_PROOF_BYTES + 4096 + +_HANDSHAKE_MSG_LEN = 32 + + +def _read_exact(sock: socket.socket, n: int) -> bytes: + buf = bytearray() + while len(buf) < n: + chunk = sock.recv(n - len(buf)) + if not chunk: + raise ConnectionError("connection closed before read completed") + buf.extend(chunk) + return bytes(buf) + + +def read_frame(sock: socket.socket) -> bytes: + """Read one length-prefixed frame from *sock*.""" + (length,) = struct.unpack(" MAX_TCP_FRAME_BYTES: + raise ValueError(f"frame length {length} exceeds maximum ({MAX_TCP_FRAME_BYTES})") + if length == 0: + return b"" + return _read_exact(sock, length) + + +def write_frame(sock: socket.socket, payload: bytes) -> None: + """Write one length-prefixed frame to *sock*.""" + if len(payload) > MAX_TCP_FRAME_BYTES: + raise ValueError(f"payload length {len(payload)} exceeds maximum ({MAX_TCP_FRAME_BYTES})") + sock.sendall(struct.pack(" Session: + """ + Run the ZKAC client side over *sock* (TCP connected to the server). + + Returns the authenticated :class:`Session` for ``encrypt`` / ``decrypt``. + """ + pending, init_msg = node.connect() + if len(init_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("internal error: init_msg must be 32 bytes") + write_frame(sock, init_msg) + + bundle = read_frame(sock) + if len(bundle) < _HANDSHAKE_MSG_LEN: + raise ValueError("server handshake bundle too short") + response_msg = bundle[:_HANDSHAKE_MSG_LEN] + identity_proof = bundle[_HANDSHAKE_MSG_LEN:] + + session, auth_packet = node.complete_connect( + pending, response_msg, identity_proof, expected_server_pk, credential + ) + write_frame(sock, auth_packet) + return session + + +def server_handshake( + sock: socket.socket, + node: Node, + registry: RoleRegistry, +) -> Tuple[Session, bytes]: + """ + Run the ZKAC server side over *sock* (accepted TCP connection). + + Returns ``(session, role_id)`` where ``role_id`` is 32 bytes after successful + BBS+ verification. + """ + init_msg = read_frame(sock) + if len(init_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("init_msg must be 32 bytes") + + session, response_msg = node.accept(init_msg) + if len(response_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("internal error: response_msg must be 32 bytes") + + identity_proof = node.prove_identity(session) + bundle = response_msg + identity_proof + write_frame(sock, bundle) + + auth_packet = read_frame(sock) + role_id = node.verify_auth(session, auth_packet, registry) + return session, role_id + + +class FramedSession: + """ + One ZKAC ciphertext per TCP frame: encrypt before send, decrypt after recv. + """ + + def __init__(self, sock: socket.socket, session: Session) -> None: + self._sock = sock + self._session = session + + @property + def session(self) -> Session: + return self._session + + def send(self, plaintext: bytes) -> None: + packet = self._session.encrypt(plaintext) + write_frame(self._sock, packet) + + def recv(self) -> bytes: + return self._session.decrypt(read_frame(self._sock)) diff --git a/python/zkac/udp.py b/python/zkac/udp.py new file mode 100644 index 0000000..0a50ca6 --- /dev/null +++ b/python/zkac/udp.py @@ -0,0 +1,177 @@ +""" +Length-prefixed UDP datagram framing for ZKAC handshakes and encrypted sessions. + +Wire format matches :mod:`zkac.tcp`: each datagram is ``uint32_le(length) || payload`` +with *length* counting only *payload* bytes. **One datagram = one frame** (do not +split a frame across packets). + +**Reliability:** UDP is unordered and lossy. This module does not add ACKs or +retransmits. Use TCP (``zkac.tcp``) if you need a reliable stream without +building your own reliability layer. + +**Size:** Large BBS+ auth packets can exceed typical path MTUs (~1500 B). If +``send`` raises ``OSError`` (e.g. ``EMSGSIZE``), use TCP or reduce proof size / +raise MTU on controlled networks. +""" + +from __future__ import annotations + +import socket +import struct +from typing import TYPE_CHECKING, Optional, Tuple + +from zkac import MAX_BBS_AUTH_PROOF_BYTES + +if TYPE_CHECKING: + from zkac import Credential, Node, PublicKey, RoleRegistry, Session + +# Same logical cap as tcp framing; note UDP + large proofs may hit EMSGSIZE on send. +MAX_UDP_FRAME_BYTES: int = MAX_BBS_AUTH_PROOF_BYTES + 4096 + +# IPv4 max UDP payload (theoretical); recv buffer size hint. +MAX_UDP_DATAGRAM_BYTES: int = 65507 + +_HANDSHAKE_MSG_LEN = 32 + + +def _build_framed_datagram(payload: bytes) -> bytes: + if len(payload) > MAX_UDP_FRAME_BYTES: + raise ValueError( + f"payload length {len(payload)} exceeds maximum ({MAX_UDP_FRAME_BYTES})" + ) + return struct.pack(" bytes: + if len(data) < 4: + raise ValueError("datagram too short for length prefix") + (length,) = struct.unpack(" MAX_UDP_FRAME_BYTES: + raise ValueError(f"frame length {length} exceeds maximum ({MAX_UDP_FRAME_BYTES})") + if len(data) != 4 + length: + raise ValueError( + f"datagram size mismatch: expected {4 + length} bytes, got {len(data)}" + ) + return data[4:] if length else b"" + + +def write_datagram(sock: socket.socket, payload: bytes, addr: Optional[tuple] = None) -> None: + """ + Send one framed datagram. If *addr* is ``None``, *sock* must be connected + (e.g. after :meth:`socket.socket.connect`). + """ + packet = _build_framed_datagram(payload) + if len(packet) > MAX_UDP_DATAGRAM_BYTES: + raise ValueError("framed datagram exceeds maximum UDP payload size") + if addr is not None: + sock.sendto(packet, addr) + else: + sock.send(packet) + + +def read_datagram(sock: socket.socket, bufsize: int = MAX_UDP_DATAGRAM_BYTES) -> bytes: + """ + Receive one framed datagram on a **connected** UDP socket (``recv``). + """ + data = sock.recv(bufsize) + if not data: + raise ConnectionError("received empty datagram (peer closed?)") + return _parse_framed_datagram(data) + + +def read_datagram_from( + sock: socket.socket, bufsize: int = MAX_UDP_DATAGRAM_BYTES +) -> Tuple[bytes, tuple]: + """ + Receive one framed datagram on an **unconnected** UDP socket (``recvfrom``). + Returns ``(payload, addr)``. + """ + data, addr = sock.recvfrom(bufsize) + if not data: + raise ConnectionError("received empty datagram") + return _parse_framed_datagram(data), addr + + +def client_handshake( + sock: socket.socket, + server_addr: tuple, + node: Node, + expected_server_pk: PublicKey, + credential: Credential, +) -> Session: + """ + Run the ZKAC client side over UDP. Connects *sock* to *server_addr* and + exchanges three framed datagrams (init → server bundle → auth). + + *server_addr* is ``(host, port)`` for :meth:`socket.socket.connect`. + """ + sock.connect(server_addr) + + pending, init_msg = node.connect() + if len(init_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("internal error: init_msg must be 32 bytes") + write_datagram(sock, init_msg) + + bundle = read_datagram(sock) + if len(bundle) < _HANDSHAKE_MSG_LEN: + raise ValueError("server handshake bundle too short") + response_msg = bundle[:_HANDSHAKE_MSG_LEN] + identity_proof = bundle[_HANDSHAKE_MSG_LEN:] + + session, auth_packet = node.complete_connect( + pending, response_msg, identity_proof, expected_server_pk, credential + ) + write_datagram(sock, auth_packet) + return session + + +def server_handshake( + sock: socket.socket, + node: Node, + registry: RoleRegistry, +) -> Tuple[Session, bytes, tuple]: + """ + Run the ZKAC server side over UDP. Waits for the first datagram, then + :meth:`socket.socket.connect` to that peer so the rest of the handshake + uses the same path. + + Returns ``(session, role_id, client_addr)``. + """ + init_msg, client_addr = read_datagram_from(sock) + if len(init_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("init_msg must be 32 bytes") + + sock.connect(client_addr) + + session, response_msg = node.accept(init_msg) + if len(response_msg) != _HANDSHAKE_MSG_LEN: + raise ValueError("internal error: response_msg must be 32 bytes") + + identity_proof = node.prove_identity(session) + bundle = response_msg + identity_proof + write_datagram(sock, bundle) + + auth_packet = read_datagram(sock) + role_id = node.verify_auth(session, auth_packet, registry) + return session, role_id, client_addr + + +class FramedSession: + """ + One ZKAC ciphertext per UDP datagram; *sock* must be connected. + """ + + def __init__(self, sock: socket.socket, session: Session) -> None: + self._sock = sock + self._session = session + + @property + def session(self) -> Session: + return self._session + + def send(self, plaintext: bytes) -> None: + packet = self._session.encrypt(plaintext) + write_datagram(self._sock, packet) + + def recv(self) -> bytes: + return self._session.decrypt(read_datagram(self._sock)) diff --git a/src/credential/schnorr.rs b/src/credential/schnorr.rs index f827d4f..523bdee 100644 --- a/src/credential/schnorr.rs +++ b/src/credential/schnorr.rs @@ -50,6 +50,22 @@ fn challenge(r: &CompressedRistretto, pk: &CompressedRistretto, msg: &[u8]) -> S impl Keypair { pub fn generate(rng: &mut R) -> Self { let scalar = Scalar::random(rng); + Self::from_scalar(scalar) + } + + /// 32-byte canonical encoding of the secret scalar (for persistence). + pub fn secret_key_bytes(&self) -> [u8; 32] { + self.secret.scalar.to_bytes() + } + + /// Restore from [`secret_key_bytes`](Self::secret_key_bytes). + pub fn from_secret_key_bytes(bytes: &[u8; 32]) -> Result { + let scalar = Option::from(Scalar::from_canonical_bytes(*bytes)) + .ok_or_else(|| Error::DeserializationError("invalid secret key scalar"))?; + Ok(Self::from_scalar(scalar)) + } + + fn from_scalar(scalar: Scalar) -> Self { let point = &scalar * RISTRETTO_BASEPOINT_TABLE; Keypair { secret: SecretKey { scalar }, @@ -186,4 +202,13 @@ mod tests { let s2 = kp.sign(b"same msg"); assert_eq!(s1.to_bytes(), s2.to_bytes()); } + + #[test] + fn keypair_secret_roundtrip() { + let kp = Keypair::generate(&mut OsRng); + let bytes = kp.secret_key_bytes(); + let kp2 = Keypair::from_secret_key_bytes(&bytes).unwrap(); + assert_eq!(kp.public().to_bytes(), kp2.public().to_bytes()); + assert_eq!(kp.sign(b"m").to_bytes(), kp2.sign(b"m").to_bytes()); + } } diff --git a/src/python.rs b/src/python.rs index cdbbf05..4157b15 100644 --- a/src/python.rs +++ b/src/python.rs @@ -46,6 +46,25 @@ impl PyKeypair { let sig = kp.sign(msg); Ok(PyBytes::new(py, &sig.to_bytes())) } + + fn secret_key_bytes<'py>(&self, py: Python<'py>) -> PyResult> { + let kp = self.inner.as_ref().ok_or_else(|| { + PyValueError::new_err("keypair was consumed by Node") + })?; + Ok(PyBytes::new(py, &kp.secret_key_bytes())) + } + + #[staticmethod] + fn from_secret_key(bytes: &[u8]) -> PyResult { + if bytes.len() != 32 { + return Err(PyValueError::new_err("secret key must be 32 bytes")); + } + let arr: [u8; 32] = bytes.try_into().unwrap(); + let inner = credential::Keypair::from_secret_key_bytes(&arr).map_err(to_py_err)?; + Ok(PyKeypair { + inner: Some(inner), + }) + } } // ── Ristretto PublicKey ────────────────────────────────────────────── diff --git a/tests/test_zkac.py b/tests/test_zkac.py index a486bf1..fb5951a 100644 --- a/tests/test_zkac.py +++ b/tests/test_zkac.py @@ -22,6 +22,13 @@ class TestKeypairAndPublicKey: assert r.startswith("PublicKey(") assert len(r) == len("PublicKey()") + 64 + def test_secret_key_roundtrip(self): + kp = zkac.Keypair() + sk = kp.secret_key_bytes() + assert len(sk) == 32 + kp2 = zkac.Keypair.from_secret_key(sk) + assert kp.public_key().to_bytes() == kp2.public_key().to_bytes() + def test_different_keypairs_different_pubkeys(self): pk1 = zkac.Keypair().public_key() pk2 = zkac.Keypair().public_key() diff --git a/tests/test_zkac_tcp.py b/tests/test_zkac_tcp.py new file mode 100644 index 0000000..1ca7900 --- /dev/null +++ b/tests/test_zkac_tcp.py @@ -0,0 +1,113 @@ +import socket +import threading + +import pytest + +import zkac +from zkac.tcp import ( + FramedSession, + MAX_TCP_FRAME_BYTES, + client_handshake, + read_frame, + server_handshake, + write_frame, +) + + +def _make_credential(): + issuer = zkac.BbsIssuer() + pk = issuer.public_key() + rid = zkac.role_id("admin") + req = zkac.prepare_blind_request() + sig = issuer.issue_blind(req.commitment_with_proof(), rid, 1) + cred = zkac.Credential.finalize( + sig, req.member_secret(), req.prover_blind(), rid, 1, pk + ) + return issuer, pk, rid, cred + + +class TestFraming: + def test_read_write_roundtrip(self): + a, b = socket.socketpair() + try: + payload = b"hello" * 400 + write_frame(a, payload) + assert read_frame(b) == payload + finally: + a.close() + b.close() + + def test_oversized_length_rejected(self): + a, b = socket.socketpair() + try: + a.sendall((MAX_TCP_FRAME_BYTES + 1).to_bytes(4, "little")) + with pytest.raises(ValueError, match="exceeds maximum"): + read_frame(b) + finally: + a.close() + b.close() + + +class TestHandshakeOverTcp: + def test_full_handshake_matching_keys(self): + _, pk, rid, cred = _make_credential() + reg = zkac.RoleRegistry() + reg.register_role(rid, pk, 1) + + client_sock, server_sock = socket.socketpair() + server_kp = zkac.Keypair() + server_pk = server_kp.public_key() + + def run_server(): + try: + srv = zkac.Node(server_kp) + s, verified = server_handshake(server_sock, srv, reg) + assert verified == rid + pkt = s.encrypt(b"admin command") + write_frame(server_sock, pkt) + finally: + server_sock.close() + + t = threading.Thread(target=run_server) + t.start() + try: + cli = zkac.Node(zkac.Keypair()) + session = client_handshake(client_sock, cli, server_pk, cred) + wire = read_frame(client_sock) + assert session.decrypt(wire) == b"admin command" + finally: + client_sock.close() + t.join(timeout=5) + assert not t.is_alive() + + +class TestFramedSession: + def test_framed_encrypt_roundtrip(self): + _, pk, rid, cred = _make_credential() + reg = zkac.RoleRegistry() + reg.register_role(rid, pk, 1) + + client_sock, server_sock = socket.socketpair() + server_kp = zkac.Keypair() + server_pk = server_kp.public_key() + + def run_server(): + try: + srv = zkac.Node(server_kp) + session, _ = server_handshake(server_sock, srv, reg) + framed = FramedSession(server_sock, session) + framed.send(b"reply") + finally: + server_sock.close() + + t = threading.Thread(target=run_server) + t.start() + try: + cli = zkac.Node(zkac.Keypair()) + session = client_handshake(client_sock, cli, server_pk, cred) + framed = FramedSession(client_sock, session) + assert framed.recv() == b"reply" + finally: + client_sock.close() + t.join(timeout=5) + assert not t.is_alive() diff --git a/tests/test_zkac_udp.py b/tests/test_zkac_udp.py new file mode 100644 index 0000000..345b655 --- /dev/null +++ b/tests/test_zkac_udp.py @@ -0,0 +1,83 @@ +import socket +import threading + +import zkac +from zkac.udp import ( + FramedSession, + client_handshake, + read_datagram, + server_handshake, + write_datagram, +) + + +def _make_credential(): + issuer = zkac.BbsIssuer() + pk = issuer.public_key() + rid = zkac.role_id("admin") + req = zkac.prepare_blind_request() + sig = issuer.issue_blind(req.commitment_with_proof(), rid, 1) + cred = zkac.Credential.finalize( + sig, req.member_secret(), req.prover_blind(), rid, 1, pk + ) + return pk, rid, cred + + +class TestFraming: + def test_write_read_connected(self): + a, b = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM) + try: + payload = b"udp-framed" + write_datagram(a, payload) + assert read_datagram(b) == payload + finally: + a.close() + b.close() + + +class TestHandshakeOverUdp: + def test_full_handshake_localhost(self): + pk, rid, cred = _make_credential() + reg = zkac.RoleRegistry() + reg.register_role(rid, pk, 1) + + server_kp = zkac.Keypair() + server_pk = server_kp.public_key() + client_kp = zkac.Keypair() + + ready = threading.Event() + port_holder: list[int] = [] + err: list[BaseException] = [] + + def run_server(): + srv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + srv.bind(("127.0.0.1", 0)) + port_holder.append(srv.getsockname()[1]) + ready.set() + node = zkac.Node(server_kp) + session, verified, _addr = server_handshake(srv, node, reg) + assert verified == rid + framed = FramedSession(srv, session) + framed.send(b"pong: " + framed.recv()) + except BaseException as e: + err.append(e) + finally: + srv.close() + + t = threading.Thread(target=run_server, daemon=True) + t.start() + ready.wait() + cli = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + sess = client_handshake( + cli, ("127.0.0.1", port_holder[0]), zkac.Node(client_kp), server_pk, cred + ) + cf = FramedSession(cli, sess) + cf.send(b"ping") + assert cf.recv() == b"pong: ping" + finally: + cli.close() + t.join(timeout=5.0) + assert not err, err