223 lines
6.7 KiB
Python
223 lines
6.7 KiB
Python
"""
|
|
Length-prefixed TCP framing for ZKAC handshakes and encrypted sessions.
|
|
|
|
Wire format: each message is ``uint32_le(length) || payload`` with ``length``
|
|
counting only ``payload`` bytes. Handshake payloads match the in-memory protocol
|
|
(32-byte init; server reply is ``response_msg || identity_proof``; then auth).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import socket
|
|
import struct
|
|
from typing import TYPE_CHECKING, Tuple
|
|
|
|
from zkac import MAX_BBS_AUTH_PROOF_BYTES
|
|
|
|
if TYPE_CHECKING:
|
|
from zkac import Credential, Node, PublicKey, RegistryManager, RoleRegistry, Session
|
|
|
|
# Upper bound on any single frame, enforced on both read and write: the largest
# BBS+ auth proof the library can produce (MAX_BBS_AUTH_PROOF_BYTES) plus 4096
# bytes of slack for handshake and AEAD overhead.
MAX_TCP_FRAME_BYTES: int = 4096 + MAX_BBS_AUTH_PROOF_BYTES

# Fixed size of the client's init message and of the server's response_msg.
_HANDSHAKE_MSG_LEN: int = 32
|
|
|
|
|
|
def _read_exact(sock: socket.socket, n: int) -> bytes:
|
|
buf = bytearray()
|
|
while len(buf) < n:
|
|
chunk = sock.recv(n - len(buf))
|
|
if not chunk:
|
|
raise ConnectionError("connection closed before read completed")
|
|
buf.extend(chunk)
|
|
return bytes(buf)
|
|
|
|
|
|
def read_frame(sock: socket.socket) -> bytes:
    """
    Read one length-prefixed frame from *sock*.

    Reads the 4-byte little-endian length header and rejects oversized lengths
    before reading any payload. It then returns exactly that many payload
    bytes, or ``b""`` for an empty frame.

    Raises:
        ValueError: if the announced length exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the peer closes before the full frame arrives.
    """
    header = _read_exact(sock, 4)
    length = struct.unpack("<I", header)[0]
    if length <= MAX_TCP_FRAME_BYTES:
        # Skip the socket entirely for an empty payload.
        return _read_exact(sock, length) if length else b""
    raise ValueError(f"frame length {length} exceeds maximum ({MAX_TCP_FRAME_BYTES})")
|
|
|
|
|
|
def write_frame(sock: socket.socket, payload: bytes) -> None:
    """
    Write one length-prefixed frame to *sock*.

    The 4-byte little-endian length header and the payload are concatenated
    and sent with a single ``sendall`` call.

    Raises:
        ValueError: if ``len(payload)`` exceeds ``MAX_TCP_FRAME_BYTES``
            (nothing is sent in that case).
    """
    size = len(payload)
    if size <= MAX_TCP_FRAME_BYTES:
        header = struct.pack("<I", size)
        sock.sendall(header + payload)
        return
    raise ValueError(f"payload length {size} exceeds maximum ({MAX_TCP_FRAME_BYTES})")
|
|
|
|
|
|
def client_handshake(
    sock: socket.socket,
    node: Node,
    expected_server_pk: PublicKey,
    credential: Credential,
) -> Session:
    """
    Run the ZKAC client side over *sock* (TCP connected to the server).

    Message flow, one frame per step:
      1. send the 32-byte init message from ``node.connect()``;
      2. receive the server bundle ``response_msg || identity_proof``;
      3. send the auth packet produced by ``node.complete_connect(...)``.

    Returns the authenticated :class:`Session` for ``encrypt`` / ``decrypt``.

    Raises:
        ValueError: if the init message is not 32 bytes, the server bundle is
            shorter than 32 bytes, or a frame exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the server disconnects mid-frame.
    """
    pending, init_msg = node.connect()
    if len(init_msg) == _HANDSHAKE_MSG_LEN:
        write_frame(sock, init_msg)
    else:
        raise ValueError("internal error: init_msg must be 32 bytes")

    bundle = read_frame(sock)
    if len(bundle) < _HANDSHAKE_MSG_LEN:
        raise ValueError("server handshake bundle too short")
    # The first 32 bytes are the server's handshake reply; the remainder is
    # its identity proof.
    response_msg, identity_proof = bundle[:_HANDSHAKE_MSG_LEN], bundle[_HANDSHAKE_MSG_LEN:]

    session, auth_packet = node.complete_connect(
        pending, response_msg, identity_proof, expected_server_pk, credential
    )
    write_frame(sock, auth_packet)
    return session
|
|
|
|
|
|
def server_handshake(
    sock: socket.socket,
    node: Node,
    registry: RoleRegistry,
) -> Tuple[Session, bytes]:
    """
    Run the ZKAC server side over *sock* (accepted TCP connection).

    Message flow, one frame per step:
      1. receive the client's 32-byte init message;
      2. send ``response_msg || identity_proof``;
      3. receive the client's auth packet and check it with
         ``node.verify_auth`` against *registry*.

    Returns ``(session, role_id)`` where ``role_id`` is 32 bytes after successful
    BBS+ verification.

    Raises:
        ValueError: if the init message is not 32 bytes, ``node.accept``
            returns a response that is not 32 bytes, or a frame exceeds
            ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the client disconnects mid-frame.
    """
    init_msg = read_frame(sock)
    if len(init_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("init_msg must be 32 bytes")

    session, response_msg = node.accept(init_msg)
    if len(response_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("internal error: response_msg must be 32 bytes")

    # The client splits this bundle at byte 32, so response_msg must come
    # first and be exactly 32 bytes long.
    write_frame(sock, response_msg + node.prove_identity(session))

    role_id = node.verify_auth(session, read_frame(sock), registry)
    return session, role_id
|
|
|
|
|
|
def client_handshake_managed(
    sock: socket.socket,
    node: Node,
    expected_server_pk: PublicKey,
    credential: Credential,
    registry_id: bytes,
) -> Session:
    """
    Client handshake against a client-managed registry.

    Like :func:`client_handshake` but includes ``registry_id`` in the auth packet.

    Sends the 32-byte init message and reads the server bundle
    (``response_msg || identity_proof``). It then completes the handshake via
    ``node.complete_connect_managed`` and sends the resulting auth packet.

    Returns the :class:`Session` for ``encrypt`` / ``decrypt``.

    Raises:
        ValueError: if ``node.connect()`` produced an init message that is not
            32 bytes, the server bundle is shorter than 32 bytes, or a frame
            exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the server disconnects mid-frame.
    """
    pending, init_msg = node.connect()
    # Same sanity check as client_handshake: the server rejects any init frame
    # that is not exactly 32 bytes, so fail locally with a clear message.
    if len(init_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("internal error: init_msg must be 32 bytes")
    write_frame(sock, init_msg)

    bundle = read_frame(sock)
    if len(bundle) < _HANDSHAKE_MSG_LEN:
        raise ValueError("server handshake bundle too short")
    # First 32 bytes: server handshake reply; remainder: its identity proof.
    response_msg = bundle[:_HANDSHAKE_MSG_LEN]
    identity_proof = bundle[_HANDSHAKE_MSG_LEN:]

    session, auth_packet = node.complete_connect_managed(
        pending, response_msg, identity_proof, expected_server_pk, credential, registry_id
    )
    write_frame(sock, auth_packet)
    return session
|
|
|
|
|
|
def server_handshake_managed(
    sock: socket.socket,
    node: Node,
    manager: RegistryManager,
) -> Tuple[Session, bytes, bytes]:
    """
    Server handshake with a :class:`RegistryManager`.

    Uses the same framing as :func:`server_handshake`. The difference is that
    the client's auth packet is checked with ``node.verify_auth_managed``
    against *manager*, which returns the registry id alongside the role id.

    Returns ``(session, registry_id, role_id)``.

    Raises:
        ValueError: if the client's init message is not 32 bytes,
            ``node.accept`` returns a response that is not 32 bytes, or a
            frame exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the client disconnects mid-frame.
    """
    init_msg = read_frame(sock)
    if len(init_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("init_msg must be 32 bytes")

    session, response_msg = node.accept(init_msg)
    # Matches server_handshake: the client splits the bundle at byte 32, so a
    # response of any other length would silently corrupt identity_proof.
    if len(response_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("internal error: response_msg must be 32 bytes")
    identity_proof = node.prove_identity(session)
    write_frame(sock, response_msg + identity_proof)

    auth_packet = read_frame(sock)
    registry_id, role_id = node.verify_auth_managed(session, auth_packet, manager)
    return session, registry_id, role_id
|
|
|
|
|
|
def client_handshake_anon(
    sock: socket.socket,
    node: Node,
    expected_server_pk: PublicKey,
) -> Session:
    """
    Anonymous client handshake: verify server identity only, no BBS+ auth.

    Sends the 32-byte init message and reads the server bundle
    (``response_msg || identity_proof``). It then completes the handshake via
    ``node.complete_connect_anon``. Unlike :func:`client_handshake`, no auth
    packet is sent.

    Returns an encrypted :class:`Session` for management traffic.

    Raises:
        ValueError: if ``node.connect()`` produced an init message that is not
            32 bytes, the server bundle is shorter than 32 bytes, or a frame
            exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the server disconnects mid-frame.
    """
    pending, init_msg = node.connect()
    # Same sanity check as client_handshake: the server rejects any init frame
    # that is not exactly 32 bytes, so fail locally with a clear message.
    if len(init_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("internal error: init_msg must be 32 bytes")
    write_frame(sock, init_msg)

    bundle = read_frame(sock)
    if len(bundle) < _HANDSHAKE_MSG_LEN:
        raise ValueError("server handshake bundle too short")
    # First 32 bytes: server handshake reply; remainder: its identity proof.
    response_msg = bundle[:_HANDSHAKE_MSG_LEN]
    identity_proof = bundle[_HANDSHAKE_MSG_LEN:]

    return node.complete_connect_anon(
        pending, response_msg, identity_proof, expected_server_pk,
    )
|
|
|
|
|
|
def server_handshake_anon(
    sock: socket.socket,
    node: Node,
) -> Session:
    """
    Server-side anonymous handshake: prove identity, no BBS+ verification.

    Reads the client's 32-byte init message and replies with
    ``response_msg || identity_proof``. No auth packet is read, so the client
    is not authenticated.

    Returns an encrypted :class:`Session`.

    Raises:
        ValueError: if the client's init message is not 32 bytes,
            ``node.accept`` returns a response that is not 32 bytes, or a
            frame exceeds ``MAX_TCP_FRAME_BYTES``.
        ConnectionError: if the client disconnects mid-frame.
    """
    init_msg = read_frame(sock)
    if len(init_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("init_msg must be 32 bytes")

    session, response_msg = node.accept(init_msg)
    # Matches server_handshake: the client splits the bundle at byte 32, so a
    # response of any other length would silently corrupt identity_proof.
    if len(response_msg) != _HANDSHAKE_MSG_LEN:
        raise ValueError("internal error: response_msg must be 32 bytes")
    identity_proof = node.prove_identity(session)
    write_frame(sock, response_msg + identity_proof)
    return session
|
|
|
|
|
|
class FramedSession:
    """
    Pair a connected socket with a ZKAC :class:`Session` so that every message
    travels as exactly one encrypted, length-prefixed TCP frame.

    ``send`` encrypts before writing; ``recv`` decrypts after reading.
    """

    def __init__(self, sock: socket.socket, session: Session) -> None:
        self._sock = sock
        self._session = session

    @property
    def session(self) -> Session:
        """The wrapped ZKAC session used for ``encrypt`` / ``decrypt``."""
        return self._session

    def send(self, plaintext: bytes) -> None:
        """Encrypt *plaintext* and write the ciphertext as a single frame."""
        write_frame(self._sock, self._session.encrypt(plaintext))

    def recv(self) -> bytes:
        """Read one frame from the socket and return its decrypted payload."""
        ciphertext = read_frame(self._sock)
        return self._session.decrypt(ciphertext)
|