288 lines
9.7 KiB
Rust
288 lines
9.7 KiB
Rust
//! SimplePIR-based single-server Private Information Retrieval.
//!
//! Implements the first layer of DoublePIR (Henzinger–Hong–Corrigan-Gibbs–
//! Meiklejohn–Vaikuntanathan, USENIX Security '23). For full-record retrieval
//! the second compression layer is unnecessary — the client needs the entire
//! column — so this is equivalent to SimplePIR. The second layer can be added
//! as an optimisation for very large pools without API changes.
//!
//! Security: decisional LWE with n=1024, q=2^32, σ=6.4 (128-bit classical).
|
||
|
||
use blake2::Blake2b512;
|
||
use digest::Digest;
|
||
use rand::Rng;
|
||
use rand::rngs::OsRng;
|
||
|
||
use super::params::*;
|
||
use super::lwe;
|
||
use super::db::Database;
|
||
|
||
// ── Hints (public matrix seed + precomputed H = D · A^T) ────────────
|
||
|
||
/// Public PIR material shared between server and clients.
///
/// Holds the seed of the public LWE matrix `A`, the two database dimensions,
/// and the precomputed hint `H = D · A^T`.
pub struct Hints {
    /// Seed handed to `lwe::gen_matrix` to regenerate the public matrix `A`.
    seed: [u8; 32],
    /// Number of database records; `A` is generated as `N_LWE × n_records`.
    n_records: usize,
    /// Number of cells in one record, i.e. the number of rows of `hint`.
    cells_per_record: usize,
    /// Row-major `cells_per_record × N_LWE` matrix (mod 2^32).
    hint: Vec<u32>,
}
|
||
|
||
impl Hints {
|
||
pub fn n_records(&self) -> usize { self.n_records }
|
||
pub fn cells_per_record(&self) -> usize { self.cells_per_record }
|
||
pub fn seed(&self) -> &[u8; 32] { &self.seed }
|
||
pub fn hint_matrix(&self) -> &[u32] { &self.hint }
|
||
|
||
pub fn serialize(&self) -> Vec<u8> {
|
||
let n = self.cells_per_record * N_LWE;
|
||
let mut buf = Vec::with_capacity(32 + 8 + 8 + n * 4);
|
||
buf.extend_from_slice(&self.seed);
|
||
buf.extend_from_slice(&(self.n_records as u64).to_le_bytes());
|
||
buf.extend_from_slice(&(self.cells_per_record as u64).to_le_bytes());
|
||
for &v in &self.hint {
|
||
buf.extend_from_slice(&v.to_le_bytes());
|
||
}
|
||
buf
|
||
}
|
||
|
||
pub fn deserialize(data: &[u8]) -> Result<Self, &'static str> {
|
||
if data.len() < 48 {
|
||
return Err("hint data too short");
|
||
}
|
||
let mut seed = [0u8; 32];
|
||
seed.copy_from_slice(&data[..32]);
|
||
let n_records = u64::from_le_bytes(data[32..40].try_into().unwrap()) as usize;
|
||
let cells_per_record = u64::from_le_bytes(data[40..48].try_into().unwrap()) as usize;
|
||
let n = cells_per_record.checked_mul(N_LWE).ok_or("overflow")?;
|
||
if data.len() != 48 + n * 4 {
|
||
return Err("hint data length mismatch");
|
||
}
|
||
let hint: Vec<u32> = data[48..]
|
||
.chunks_exact(4)
|
||
.map(|c| u32::from_le_bytes(c.try_into().unwrap()))
|
||
.collect();
|
||
Ok(Hints { seed, n_records, cells_per_record, hint })
|
||
}
|
||
|
||
pub fn version(&self) -> [u8; 32] {
|
||
let mut h = Blake2b512::new();
|
||
h.update(b"zkac-pir-hints-v1");
|
||
h.update(self.seed);
|
||
h.update((self.n_records as u64).to_le_bytes());
|
||
h.update((self.cells_per_record as u64).to_le_bytes());
|
||
for &v in &self.hint {
|
||
h.update(v.to_le_bytes());
|
||
}
|
||
let digest = h.finalize();
|
||
let mut out = [0u8; 32];
|
||
out.copy_from_slice(&digest[..32]);
|
||
out
|
||
}
|
||
}
|
||
|
||
// ── Server ──────────────────────────────────────────────────────────
|
||
|
||
pub struct Server {
|
||
db: Database,
|
||
hints: Hints,
|
||
}
|
||
|
||
impl Server {
|
||
/// Build server state: generates a random public matrix A and precomputes
|
||
/// the hint H = D · A^T. This is the expensive offline phase.
|
||
pub fn new(db: Database) -> Self {
|
||
let m = db.n_records();
|
||
let ell = db.cells_per_record();
|
||
|
||
let mut seed = [0u8; 32];
|
||
OsRng.fill(&mut seed);
|
||
|
||
let hint = if m == 0 {
|
||
Vec::new()
|
||
} else {
|
||
let a = lwe::gen_matrix(&seed, N_LWE, m);
|
||
lwe::mat_mul_bt(db.data(), &a, ell, m, N_LWE)
|
||
};
|
||
|
||
Server {
|
||
hints: Hints { seed, n_records: m, cells_per_record: ell, hint },
|
||
db,
|
||
}
|
||
}
|
||
|
||
pub fn hints(&self) -> &Hints { &self.hints }
|
||
|
||
pub fn version(&self) -> [u8; 32] { self.hints.version() }
|
||
|
||
/// Compute answer = D · query (mod 2^32). The query vector has `n_records`
|
||
/// entries; the answer has `cells_per_record` entries.
|
||
pub fn answer(&self, query: &[u32]) -> Vec<u32> {
|
||
lwe::mat_vec_mul(
|
||
self.db.data(), query,
|
||
self.db.cells_per_record(), self.db.n_records(),
|
||
)
|
||
}
|
||
|
||
pub fn n_records(&self) -> usize { self.db.n_records() }
|
||
pub fn record_bytes(&self) -> usize { self.db.record_bytes() }
|
||
}
|
||
|
||
// ── Client ──────────────────────────────────────────────────────────
|
||
|
||
/// Per-query client state returned by `Client::query` and consumed by
/// `Client::decode`.
pub struct ClientState {
    /// The LWE secret vector sampled for this query.
    secret: Vec<u32>,
}
|
||
|
||
pub struct Client {
|
||
hints: Hints,
|
||
}
|
||
|
||
impl Client {
|
||
pub fn new(hints: Hints) -> Self {
|
||
Client { hints }
|
||
}
|
||
|
||
pub fn version(&self) -> [u8; 32] { self.hints.version() }
|
||
pub fn n_records(&self) -> usize { self.hints.n_records }
|
||
pub fn record_bytes(&self) -> usize { self.hints.cells_per_record }
|
||
|
||
/// Generate a PIR query for `index`. Returns the query vector (to send to
|
||
/// the server) and opaque client state (needed for decoding the answer).
|
||
pub fn query(&self, index: usize) -> (Vec<u32>, ClientState) {
|
||
assert!(index < self.hints.n_records, "PIR query index out of range");
|
||
|
||
let mut rng = OsRng;
|
||
let s = lwe::sample_uniform_vec(&mut rng, N_LWE);
|
||
let e = lwe::sample_error_vec(&mut rng, self.hints.n_records);
|
||
|
||
// q = A^T · s + e (A is N_LWE × n_records)
|
||
let a = lwe::gen_matrix(self.hints.seed(), N_LWE, self.hints.n_records);
|
||
let mut q = lwe::mat_t_vec_mul(&a, &s, N_LWE, self.hints.n_records);
|
||
|
||
for j in 0..self.hints.n_records {
|
||
q[j] = q[j].wrapping_add(e[j]);
|
||
}
|
||
q[index] = q[index].wrapping_add(DELTA);
|
||
|
||
(q, ClientState { secret: s })
|
||
}
|
||
|
||
/// Decode the server's answer using the saved client state.
|
||
/// Returns the `record_bytes`-length plaintext record.
|
||
pub fn decode(&self, answer: &[u32], state: &ClientState) -> Vec<u8> {
|
||
// h_s = H · s (cells_per_record × N_LWE) · (N_LWE) → cells_per_record
|
||
let h_s = lwe::mat_vec_mul(
|
||
self.hints.hint_matrix(), &state.secret,
|
||
self.hints.cells_per_record, N_LWE,
|
||
);
|
||
|
||
(0..self.hints.cells_per_record)
|
||
.map(|i| lwe::round_to_plaintext(answer[i].wrapping_sub(h_s[i])))
|
||
.collect()
|
||
}
|
||
}
|
||
|
||
// ── Serialization helpers for query / answer vectors ────────────────
|
||
|
||
/// Flatten a `u32` vector into bytes, four little-endian bytes per element,
/// in element order.
pub fn serialize_vec(v: &[u32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|word| word.to_le_bytes()));
    bytes
}
|
||
|
||
/// Rebuild a `u32` vector from little-endian bytes, four bytes per element.
///
/// Input is consumed in whole 4-byte groups; up to three trailing bytes that
/// do not form a complete group are silently ignored.
pub fn deserialize_vec(data: &[u8]) -> Vec<u32> {
    let mut words = Vec::with_capacity(data.len() / 4);
    for chunk in data.chunks_exact(4) {
        words.push(u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    words
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `n` zero-filled records of `rec_bytes` bytes each, tagged so that
    /// records differ: byte 0 is the index, byte 1 is `index * 37` (low byte),
    /// and — for records longer than two bytes — the last byte is `0xAA`.
    fn make_test_records(n: usize, rec_bytes: usize) -> Vec<Vec<u8>> {
        let mut records = Vec::with_capacity(n);
        for i in 0..n {
            let mut rec = vec![0u8; rec_bytes];
            rec[0] = i as u8;
            rec[1] = (i.wrapping_mul(37) & 0xFF) as u8;
            if rec_bytes > 2 {
                rec[rec_bytes - 1] = 0xAA;
            }
            records.push(rec);
        }
        records
    }

    /// Every record of a small database is retrievable by a client that
    /// received the hints in serialized form.
    #[test]
    fn roundtrip_small() {
        let records = make_test_records(8, 128);
        let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
        let server = Server::new(Database::new(&refs, 128));

        let wire_hints = server.hints().serialize();
        let client = Client::new(Hints::deserialize(&wire_hints).unwrap());

        for (target, expected) in records.iter().enumerate() {
            let (query, state) = client.query(target);
            let answer = server.answer(&query);
            let decoded = client.decode(&answer, &state);
            assert_eq!(&decoded, expected);
        }
    }

    /// Query and answer survive a trip through the byte-level helpers.
    #[test]
    fn roundtrip_serialized_query_answer() {
        let records = make_test_records(4, 64);
        let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
        let server = Server::new(Database::new(&refs, 64));
        let client = Client::new(Hints::deserialize(&server.hints().serialize()).unwrap());

        let (query, state) = client.query(2);
        let query_back = deserialize_vec(&serialize_vec(&query));
        assert_eq!(query, query_back);

        let answer = server.answer(&query_back);
        let answer_back = deserialize_vec(&serialize_vec(&answer));
        let decoded = client.decode(&answer_back, &state);
        assert_eq!(decoded, records[2]);
    }

    /// Two servers built from identical data report different versions.
    #[test]
    fn version_changes_on_rebuild() {
        let records = make_test_records(4, 64);
        let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();

        let first = Server::new(Database::new(&refs, 64)).version();
        let second = Server::new(Database::new(&refs, 64)).version();

        // Different seeds → different versions (with overwhelming probability).
        assert_ne!(first, second);
    }

    /// Serializing and deserializing hints preserves every field and the
    /// version fingerprint.
    #[test]
    fn hints_roundtrip() {
        let records = make_test_records(4, 64);
        let refs: Vec<&[u8]> = records.iter().map(Vec::as_slice).collect();
        let server = Server::new(Database::new(&refs, 64));

        let original = server.hints();
        let restored = Hints::deserialize(&original.serialize()).unwrap();

        assert_eq!(original.seed(), restored.seed());
        assert_eq!(original.n_records(), restored.n_records());
        assert_eq!(original.cells_per_record(), restored.cells_per_record());
        assert_eq!(original.hint_matrix(), restored.hint_matrix());
        assert_eq!(original.version(), restored.version());
    }
}
|