ZKAC/src/pir/doublepir.rs
2026-04-19 14:26:47 +02:00

288 lines
9.7 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! SimplePIR-based single-server Private Information Retrieval.
//!
//! Implements the first layer of DoublePIR (Henzinger, Hong, Corrigan-Gibbs,
//! Meiklejohn, Vaikuntanathan, USENIX Security '23). For full-record retrieval
//! the second compression layer is unnecessary — the client needs the entire
//! column — so this is equivalent to SimplePIR. The second layer can be added
//! as an optimisation for very large pools without API changes.
//!
//! Security: decisional LWE with n=1024, q=2^32, σ=6.4 (128-bit classical).
use blake2::Blake2b512;
use digest::Digest;
use rand::Rng;
use rand::rngs::OsRng;
use super::params::*;
use super::lwe;
use super::db::Database;
// ── Hints (public matrix seed + precomputed H = D · A^T) ────────────
/// Preprocessing data produced by `Server::new` and consumed by `Client`:
/// the seed of the public matrix `A` together with the precomputed hint
/// matrix and the database dimensions it was computed for.
pub struct Hints {
/// Seed handed to `lwe::gen_matrix` to regenerate the public matrix `A`.
seed: [u8; 32],
/// Number of records in the database the hint was computed for.
n_records: usize,
/// Number of cells in each record (the row count of `hint`).
cells_per_record: usize,
/// Row-major `cells_per_record × N_LWE` matrix (mod 2^32).
hint: Vec<u32>,
}
impl Hints {
    /// Number of records in the database these hints were built for.
    pub fn n_records(&self) -> usize {
        self.n_records
    }

    /// Number of cells per record, i.e. the number of rows of the hint matrix.
    pub fn cells_per_record(&self) -> usize {
        self.cells_per_record
    }

    /// Seed from which the public matrix `A` is regenerated.
    pub fn seed(&self) -> &[u8; 32] {
        &self.seed
    }

    /// The hint matrix as a flat, row-major slice.
    pub fn hint_matrix(&self) -> &[u32] {
        &self.hint
    }

    /// Encodes the hints as
    /// `seed (32 bytes) || n_records (u64 LE) || cells_per_record (u64 LE) ||
    /// hint words (u32 LE each)`.
    pub fn serialize(&self) -> Vec<u8> {
        // Size the buffer from the words actually stored rather than from the
        // nominal `cells_per_record * N_LWE` shape, so the reservation always
        // matches what is written.
        let mut buf = Vec::with_capacity(32 + 8 + 8 + self.hint.len() * 4);
        buf.extend_from_slice(&self.seed);
        buf.extend_from_slice(&(self.n_records as u64).to_le_bytes());
        buf.extend_from_slice(&(self.cells_per_record as u64).to_le_bytes());
        buf.extend(self.hint.iter().flat_map(|word| word.to_le_bytes()));
        buf
    }

    /// Decodes hints previously produced by [`Hints::serialize`].
    ///
    /// # Errors
    ///
    /// * `"hint data too short"` – fewer than 48 header bytes.
    /// * `"overflow"` – a length field does not fit in `usize`, or the
    ///   expected payload size overflows `usize`.
    /// * `"hint data length mismatch"` – the payload is not exactly
    ///   `cells_per_record * N_LWE` little-endian `u32` words.
    ///
    /// NOTE(review): `n_records` is taken from the input as-is; nothing here
    /// bounds it, so callers handling untrusted hints may want their own cap.
    pub fn deserialize(data: &[u8]) -> Result<Self, &'static str> {
        const HEADER_LEN: usize = 32 + 8 + 8;

        /// Reads a little-endian `u64` length field and converts it to
        /// `usize` without silent truncation on narrower targets.
        fn read_len(field: &[u8]) -> Result<usize, &'static str> {
            let raw: [u8; 8] = field.try_into().map_err(|_| "hint data too short")?;
            let value: usize = u64::from_le_bytes(raw)
                .try_into()
                .map_err(|_| "overflow")?;
            Ok(value)
        }

        if data.len() < HEADER_LEN {
            return Err("hint data too short");
        }
        let (header, body) = data.split_at(HEADER_LEN);

        let mut seed = [0u8; 32];
        seed.copy_from_slice(&header[..32]);
        let n_records = read_len(&header[32..40])?;
        let cells_per_record = read_len(&header[40..48])?;

        // Every step is checked: the lengths come straight from the input, so
        // an unchecked `48 + n * 4` could wrap (or panic in debug builds).
        let expected_body_len = cells_per_record
            .checked_mul(N_LWE)
            .and_then(|words| words.checked_mul(4))
            .ok_or("overflow")?;

        // Hints built for an empty database may carry no hint words at all,
        // in which case `serialize` emits only the header; accept that form so
        // serialize/deserialize round-trips.
        let empty_db_encoding = n_records == 0 && body.is_empty();
        if body.len() != expected_body_len && !empty_db_encoding {
            return Err("hint data length mismatch");
        }

        let hint: Vec<u32> = body
            .chunks_exact(4)
            .map(|word| u32::from_le_bytes([word[0], word[1], word[2], word[3]]))
            .collect();
        Ok(Hints { seed, n_records, cells_per_record, hint })
    }

    /// Fingerprint of these hints: the first 32 bytes of a BLAKE2b-512 digest
    /// over a domain tag, the seed, both dimensions and every hint word.
    pub fn version(&self) -> [u8; 32] {
        let mut h = Blake2b512::new();
        h.update(b"zkac-pir-hints-v1");
        h.update(self.seed);
        h.update((self.n_records as u64).to_le_bytes());
        h.update((self.cells_per_record as u64).to_le_bytes());
        for &v in &self.hint {
            h.update(v.to_le_bytes());
        }
        let digest = h.finalize();
        let mut out = [0u8; 32];
        out.copy_from_slice(&digest[..32]);
        out
    }
}
// ── Server ──────────────────────────────────────────────────────────
/// Server-side PIR state: the database together with the hints derived from it.
pub struct Server {
/// Database that queries are answered against.
db: Database,
/// Matrix seed, dimensions and precomputed hint built in `Server::new`.
hints: Hints,
}
impl Server {
    /// Build server state: draws a fresh random seed for the public matrix A
    /// and precomputes the hint H = D · A^T. This is the expensive offline
    /// phase.
    pub fn new(db: Database) -> Self {
        let m = db.n_records();
        let ell = db.cells_per_record();
        let mut seed = [0u8; 32];
        OsRng.fill(&mut seed);
        let hint = if m == 0 {
            // With no records, D · A^T is the all-zero `ell × N_LWE` matrix.
            // Store it at full size rather than as an empty vector so the hint
            // always has `cells_per_record * N_LWE` entries and the serialized
            // form passes the length check in `Hints::deserialize`.
            vec![0u32; ell * N_LWE]
        } else {
            let a = lwe::gen_matrix(&seed, N_LWE, m);
            lwe::mat_mul_bt(db.data(), &a, ell, m, N_LWE)
        };
        Server {
            hints: Hints { seed, n_records: m, cells_per_record: ell, hint },
            db,
        }
    }

    /// Hints to hand to clients.
    pub fn hints(&self) -> &Hints {
        &self.hints
    }

    /// Fingerprint of the current hints (see `Hints::version`).
    pub fn version(&self) -> [u8; 32] {
        self.hints.version()
    }

    /// Compute answer = D · query (mod 2^32). The query vector has `n_records`
    /// entries; the answer has `cells_per_record` entries.
    ///
    /// NOTE(review): the length of `query` is not checked here; what happens
    /// on a mismatch depends on `lwe::mat_vec_mul`. Callers receiving queries
    /// from the network should verify `query.len() == self.n_records()` first.
    pub fn answer(&self, query: &[u32]) -> Vec<u32> {
        lwe::mat_vec_mul(
            self.db.data(),
            query,
            self.db.cells_per_record(),
            self.db.n_records(),
        )
    }

    /// Number of records in the underlying database.
    pub fn n_records(&self) -> usize {
        self.db.n_records()
    }

    /// Size in bytes of one record, as reported by the database.
    pub fn record_bytes(&self) -> usize {
        self.db.record_bytes()
    }
}
// ── Client ──────────────────────────────────────────────────────────
/// Opaque per-query client state returned by `Client::query` and consumed by
/// `Client::decode`.
pub struct ClientState {
/// Secret vector `s` sampled in `Client::query`; it is kept by the client and
/// is not part of the query sent to the server.
secret: Vec<u32>,
}
/// PIR client holding the hints needed to build queries and decode answers.
pub struct Client {
/// Hints supplied at construction (seed, dimensions and hint matrix).
hints: Hints,
}
impl Client {
    /// Creates a client around the given hints.
    pub fn new(hints: Hints) -> Self {
        Self { hints }
    }

    /// Fingerprint of the hints this client holds (see `Hints::version`).
    pub fn version(&self) -> [u8; 32] {
        self.hints.version()
    }

    /// Number of records the hints describe.
    pub fn n_records(&self) -> usize {
        self.hints.n_records
    }

    /// Length of a decoded record. This is `cells_per_record`, since `decode`
    /// produces one output byte per cell.
    pub fn record_bytes(&self) -> usize {
        self.hints.cells_per_record
    }

    /// Builds a PIR query for the record at `index`.
    ///
    /// Returns the query vector to send to the server and the client state
    /// that must be passed to [`Client::decode`] with the matching answer.
    ///
    /// # Panics
    ///
    /// Panics if `index >= n_records`.
    pub fn query(&self, index: usize) -> (Vec<u32>, ClientState) {
        let n = self.hints.n_records;
        assert!(index < n, "PIR query index out of range");

        let mut rng = OsRng;
        let secret = lwe::sample_uniform_vec(&mut rng, N_LWE);
        let noise = lwe::sample_error_vec(&mut rng, n);

        // query = A^T · s + e, with A regenerated from the seed (N_LWE × n).
        let matrix = lwe::gen_matrix(self.hints.seed(), N_LWE, n);
        let mut query = lwe::mat_t_vec_mul(&matrix, &secret, N_LWE, n);
        for (slot, err) in query[..n].iter_mut().zip(&noise[..n]) {
            *slot = slot.wrapping_add(*err);
        }

        // Mark the requested position by adding DELTA to it.
        query[index] = query[index].wrapping_add(DELTA);
        (query, ClientState { secret })
    }

    /// Recovers the plaintext record from the server's `answer`, using the
    /// state saved when the query was generated.
    ///
    /// The result has `cells_per_record` bytes.
    ///
    /// NOTE(review): `answer` is indexed up to `cells_per_record` without a
    /// prior length check, so a shorter answer panics.
    pub fn decode(&self, answer: &[u32], state: &ClientState) -> Vec<u8> {
        let ell = self.hints.cells_per_record;
        // mask = H · s: (ell × N_LWE) · (N_LWE) → ell entries.
        let mask = lwe::mat_vec_mul(self.hints.hint_matrix(), &state.secret, ell, N_LWE);

        let mut record = Vec::with_capacity(ell);
        for i in 0..ell {
            record.push(lwe::round_to_plaintext(answer[i].wrapping_sub(mask[i])));
        }
        record
    }
}
// ── Serialization helpers for query / answer vectors ────────────────
/// Encodes a `u32` vector as the concatenation of its elements in
/// little-endian byte order (4 bytes per element).
pub fn serialize_vec(v: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|word| word.to_le_bytes()));
    bytes
}
/// Decodes a byte buffer into `u32` values, reading consecutive 4-byte
/// little-endian groups.
///
/// Trailing bytes that do not fill a complete 4-byte group are ignored.
pub fn deserialize_vec(data: &[u8]) -> Vec<u32> {
    let mut words = Vec::with_capacity(data.len() / 4);
    for quad in data.chunks_exact(4) {
        words.push(u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    words
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` deterministic records of `rec_bytes` bytes each: byte 0 is
    /// the index, byte 1 is `index * 37` (both truncated to 8 bits) and, for
    /// records longer than two bytes, the last byte is `0xAA`.
    fn make_test_records(n: usize, rec_bytes: usize) -> Vec<Vec<u8>> {
        let mut records = Vec::with_capacity(n);
        for i in 0..n {
            let mut rec = vec![0u8; rec_bytes];
            rec[0] = i as u8;
            rec[1] = (i.wrapping_mul(37) & 0xFF) as u8;
            if rec_bytes > 2 {
                rec[rec_bytes - 1] = 0xAA;
            }
            records.push(rec);
        }
        records
    }

    /// Builds the test records and a server over a database containing them.
    fn build_server(n: usize, rec_bytes: usize) -> (Vec<Vec<u8>>, Server) {
        let records = make_test_records(n, rec_bytes);
        let server = {
            let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect();
            let db = Database::new(&refs, rec_bytes);
            Server::new(db)
        };
        (records, server)
    }

    /// Creates a client from the server's hints after a serialize/deserialize
    /// round trip, as a remote client would receive them.
    fn client_for(server: &Server) -> Client {
        let wire = server.hints().serialize();
        Client::new(Hints::deserialize(&wire).unwrap())
    }

    #[test]
    fn roundtrip_small() {
        let (records, server) = build_server(8, 128);
        let client = client_for(&server);

        // Every index must decode back to exactly the stored record.
        for (target, expected) in records.iter().enumerate() {
            let (query, state) = client.query(target);
            let answer = server.answer(&query);
            let decoded = client.decode(&answer, &state);
            assert_eq!(decoded, *expected);
        }
    }

    #[test]
    fn roundtrip_serialized_query_answer() {
        let (records, server) = build_server(4, 64);
        let client = client_for(&server);

        let (q, state) = client.query(2);

        // Query survives the wire encoding unchanged.
        let q2 = deserialize_vec(&serialize_vec(&q));
        assert_eq!(q, q2);

        // Answer is decoded after going through the wire encoding as well.
        let ans = server.answer(&q2);
        let ans2 = deserialize_vec(&serialize_vec(&ans));
        let decoded = client.decode(&ans2, &state);
        assert_eq!(decoded, records[2]);
    }

    #[test]
    fn version_changes_on_rebuild() {
        let (_, s1) = build_server(4, 64);
        let (_, s2) = build_server(4, 64);

        // Different seeds → different versions (with overwhelming probability).
        assert_ne!(s1.version(), s2.version());
    }

    #[test]
    fn hints_roundtrip() {
        let (_, server) = build_server(4, 64);
        let original = server.hints();
        let restored = Hints::deserialize(&original.serialize()).unwrap();

        assert_eq!(original.seed(), restored.seed());
        assert_eq!(original.n_records(), restored.n_records());
        assert_eq!(original.cells_per_record(), restored.cells_per_record());
        assert_eq!(original.hint_matrix(), restored.hint_matrix());
        assert_eq!(original.version(), restored.version());
    }
}