//! SimplePIR-based single-server Private Information Retrieval. //! //! Implements the first layer of DoublePIR (Henzinger–Hong–Corrigan-Gibbs– //! Meiklejohn–Vaikuntanathan, USENIX Security '23). For full-record retrieval //! the second compression layer is unnecessary — the client needs the entire //! column — so this is equivalent to SimplePIR. The second layer can be added //! as an optimisation for very large pools without API changes. //! //! Security: decisional LWE with n=1024, q=2^32, σ=6.4 (128-bit classical). use blake2::Blake2b512; use digest::Digest; use rand::Rng; use rand::rngs::OsRng; use super::params::*; use super::lwe; use super::db::Database; // ── Hints (public matrix seed + precomputed H = D · A^T) ──────────── pub struct Hints { seed: [u8; 32], n_records: usize, cells_per_record: usize, /// Row-major `cells_per_record × N_LWE` matrix (mod 2^32). hint: Vec, } impl Hints { pub fn n_records(&self) -> usize { self.n_records } pub fn cells_per_record(&self) -> usize { self.cells_per_record } pub fn seed(&self) -> &[u8; 32] { &self.seed } pub fn hint_matrix(&self) -> &[u32] { &self.hint } pub fn serialize(&self) -> Vec { let n = self.cells_per_record * N_LWE; let mut buf = Vec::with_capacity(32 + 8 + 8 + n * 4); buf.extend_from_slice(&self.seed); buf.extend_from_slice(&(self.n_records as u64).to_le_bytes()); buf.extend_from_slice(&(self.cells_per_record as u64).to_le_bytes()); for &v in &self.hint { buf.extend_from_slice(&v.to_le_bytes()); } buf } pub fn deserialize(data: &[u8]) -> Result { if data.len() < 48 { return Err("hint data too short"); } let mut seed = [0u8; 32]; seed.copy_from_slice(&data[..32]); let n_records = u64::from_le_bytes(data[32..40].try_into().unwrap()) as usize; let cells_per_record = u64::from_le_bytes(data[40..48].try_into().unwrap()) as usize; let n = cells_per_record.checked_mul(N_LWE).ok_or("overflow")?; if data.len() != 48 + n * 4 { return Err("hint data length mismatch"); } let hint: Vec = data[48..] .chunks_exact(4) .map(|c| u32::from_le_bytes(c.try_into().unwrap())) .collect(); Ok(Hints { seed, n_records, cells_per_record, hint }) } pub fn version(&self) -> [u8; 32] { let mut h = Blake2b512::new(); h.update(b"zkac-pir-hints-v1"); h.update(self.seed); h.update((self.n_records as u64).to_le_bytes()); h.update((self.cells_per_record as u64).to_le_bytes()); for &v in &self.hint { h.update(v.to_le_bytes()); } let digest = h.finalize(); let mut out = [0u8; 32]; out.copy_from_slice(&digest[..32]); out } } // ── Server ────────────────────────────────────────────────────────── pub struct Server { db: Database, hints: Hints, } impl Server { /// Build server state: generates a random public matrix A and precomputes /// the hint H = D · A^T. This is the expensive offline phase. pub fn new(db: Database) -> Self { let m = db.n_records(); let ell = db.cells_per_record(); let mut seed = [0u8; 32]; OsRng.fill(&mut seed); let hint = if m == 0 { Vec::new() } else { let a = lwe::gen_matrix(&seed, N_LWE, m); lwe::mat_mul_bt(db.data(), &a, ell, m, N_LWE) }; Server { hints: Hints { seed, n_records: m, cells_per_record: ell, hint }, db, } } pub fn hints(&self) -> &Hints { &self.hints } pub fn version(&self) -> [u8; 32] { self.hints.version() } /// Compute answer = D · query (mod 2^32). The query vector has `n_records` /// entries; the answer has `cells_per_record` entries. pub fn answer(&self, query: &[u32]) -> Vec { lwe::mat_vec_mul( self.db.data(), query, self.db.cells_per_record(), self.db.n_records(), ) } pub fn n_records(&self) -> usize { self.db.n_records() } pub fn record_bytes(&self) -> usize { self.db.record_bytes() } } // ── Client ────────────────────────────────────────────────────────── pub struct ClientState { secret: Vec, } pub struct Client { hints: Hints, } impl Client { pub fn new(hints: Hints) -> Self { Client { hints } } pub fn version(&self) -> [u8; 32] { self.hints.version() } pub fn n_records(&self) -> usize { self.hints.n_records } pub fn record_bytes(&self) -> usize { self.hints.cells_per_record } /// Generate a PIR query for `index`. Returns the query vector (to send to /// the server) and opaque client state (needed for decoding the answer). pub fn query(&self, index: usize) -> (Vec, ClientState) { assert!(index < self.hints.n_records, "PIR query index out of range"); let mut rng = OsRng; let s = lwe::sample_uniform_vec(&mut rng, N_LWE); let e = lwe::sample_error_vec(&mut rng, self.hints.n_records); // q = A^T · s + e (A is N_LWE × n_records) let a = lwe::gen_matrix(self.hints.seed(), N_LWE, self.hints.n_records); let mut q = lwe::mat_t_vec_mul(&a, &s, N_LWE, self.hints.n_records); for j in 0..self.hints.n_records { q[j] = q[j].wrapping_add(e[j]); } q[index] = q[index].wrapping_add(DELTA); (q, ClientState { secret: s }) } /// Decode the server's answer using the saved client state. /// Returns the `record_bytes`-length plaintext record. pub fn decode(&self, answer: &[u32], state: &ClientState) -> Vec { // h_s = H · s (cells_per_record × N_LWE) · (N_LWE) → cells_per_record let h_s = lwe::mat_vec_mul( self.hints.hint_matrix(), &state.secret, self.hints.cells_per_record, N_LWE, ); (0..self.hints.cells_per_record) .map(|i| lwe::round_to_plaintext(answer[i].wrapping_sub(h_s[i]))) .collect() } } // ── Serialization helpers for query / answer vectors ──────────────── pub fn serialize_vec(v: &[u32]) -> Vec { let mut buf = Vec::with_capacity(v.len() * 4); for &val in v { buf.extend_from_slice(&val.to_le_bytes()); } buf } pub fn deserialize_vec(data: &[u8]) -> Vec { data.chunks_exact(4) .map(|c| u32::from_le_bytes(c.try_into().unwrap())) .collect() } #[cfg(test)] mod tests { use super::*; fn make_test_records(n: usize, rec_bytes: usize) -> Vec> { (0..n).map(|i| { let mut rec = vec![0u8; rec_bytes]; rec[0] = i as u8; rec[1] = (i.wrapping_mul(37) & 0xFF) as u8; if rec_bytes > 2 { rec[rec_bytes - 1] = 0xAA; } rec }).collect() } #[test] fn roundtrip_small() { let records = make_test_records(8, 128); let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); let db = Database::new(&refs, 128); let server = Server::new(db); let hints_bytes = server.hints().serialize(); let client = Client::new(Hints::deserialize(&hints_bytes).unwrap()); for target in 0..8 { let (query, state) = client.query(target); let answer = server.answer(&query); let decoded = client.decode(&answer, &state); assert_eq!(decoded, records[target]); } } #[test] fn roundtrip_serialized_query_answer() { let records = make_test_records(4, 64); let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); let db = Database::new(&refs, 64); let server = Server::new(db); let client = Client::new(Hints::deserialize(&server.hints().serialize()).unwrap()); let (q, state) = client.query(2); let q_bytes = serialize_vec(&q); let q2 = deserialize_vec(&q_bytes); assert_eq!(q, q2); let ans = server.answer(&q2); let ans_bytes = serialize_vec(&ans); let ans2 = deserialize_vec(&ans_bytes); let decoded = client.decode(&ans2, &state); assert_eq!(decoded, records[2]); } #[test] fn version_changes_on_rebuild() { let records = make_test_records(4, 64); let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); let db1 = Database::new(&refs, 64); let s1 = Server::new(db1); let v1 = s1.version(); let db2 = Database::new(&refs, 64); let s2 = Server::new(db2); let v2 = s2.version(); // Different seeds → different versions (with overwhelming probability). assert_ne!(v1, v2); } #[test] fn hints_roundtrip() { let records = make_test_records(4, 64); let refs: Vec<&[u8]> = records.iter().map(|r| r.as_slice()).collect(); let db = Database::new(&refs, 64); let server = Server::new(db); let original = server.hints(); let bytes = original.serialize(); let restored = Hints::deserialize(&bytes).unwrap(); assert_eq!(original.seed(), restored.seed()); assert_eq!(original.n_records(), restored.n_records()); assert_eq!(original.cells_per_record(), restored.cells_per_record()); assert_eq!(original.hint_matrix(), restored.hint_matrix()); assert_eq!(original.version(), restored.version()); } }