// labrador/prover.rs

1use crate::commitments::common_instances::AjtaiInstances;
2use crate::commitments::outer_commitments::DecompositionParameters;
3use crate::commitments::outer_commitments::OuterCommitment;
4use crate::core::garbage_polynomials::GarbagePolynomials;
5use crate::ring::rq_matrix::RqMatrix;
6use crate::ring::zq::Zq;
7use crate::ring::zq::ZqVector;
8use crate::{
9    core::{
10        aggregate,
11        challenge_set::ChallengeSet,
12        env_params::EnvironmentParameters,
13        jl::{ProjectionMatrix, Projections},
14        statement::Statement,
15    },
16    ring::rq_vector::RqVector,
17};
18use rand::rng;
19
20#[derive(Debug)]
21pub enum ProverError {
22    /// Indicates that the L2 norm (squared) of the witness exceeded the allowed threshold.
23    WitnessL2NormViolated { norm_squared: Zq, allowed: Zq },
24    ProjectionError {
25        index: usize,
26        expected: Zq,
27        computed: Zq,
28    },
29}
30
31// Proof contains the parameters will be sent to verifier
32// All parameters are from tr, line 2 on page 18
33pub struct Proof {
34    pub u_1: RqVector,
35    pub p: Projections,
36    pub b_ct_aggr: RqVector,
37    pub u_2: RqVector,
38    pub z: RqVector,
39    pub t_i: Vec<RqVector>,
40    pub g_ij: RqMatrix,
41    pub h_ij: RqMatrix,
42}
43
44// pub struct Challenges just for testing, should be replaced by the Transcript
45pub struct Challenges {
46    pub pi: Vec<Vec<Vec<Zq>>>,
47    pub psi: Vec<Vec<Zq>>,
48    pub omega: Vec<Vec<Zq>>,
49    pub random_alpha: RqVector,
50    pub random_beta: RqVector,
51    pub random_c: RqVector,
52}
53
54impl Challenges {
55    pub fn new(ep: &EnvironmentParameters) -> Self {
56        // generate random psi with size: k * constraint_l, each element is Zq
57        let psi: Vec<Vec<Zq>> = (0..ep.kappa)
58            .map(|_| Vec::<Zq>::random(&mut rng(), ep.constraint_l))
59            .collect();
60
61        // generate randm omega is with size: k * lambda2, each element is Zq
62        let omega: Vec<Vec<Zq>> = (0..ep.kappa)
63            .map(|_| Vec::<Zq>::random(&mut rng(), 2 * ep.lambda))
64            .collect();
65
66        // \pi is from JL projection, pi contains r matrices and each matrix: security_level2 * (n*d), (security_level2 is 256 in the paper).
67        let pi: Vec<Vec<Vec<Zq>>> = Self::get_pi(ep.r, ep.n);
68
69        // generate random alpha and beta from challenge set
70        let cs_alpha: ChallengeSet = ChallengeSet::new();
71        let random_alpha: RqVector = (0..ep.constraint_k)
72            .map(|_| cs_alpha.get_challenges().clone())
73            .collect();
74
75        let cs_beta: ChallengeSet = ChallengeSet::new();
76        let random_beta: RqVector = (0..ep.constraint_k)
77            .map(|_| cs_beta.get_challenges().clone())
78            .collect();
79
80        let cs_c: ChallengeSet = ChallengeSet::new();
81        let random_c: RqVector = (0..ep.r).map(|_| cs_c.get_challenges().clone()).collect();
82
83        Self {
84            pi,
85            psi,
86            omega,
87            random_alpha,
88            random_beta,
89            random_c,
90        }
91    }
92
93    pub fn get_pi(r: usize, n: usize) -> Vec<Vec<Vec<Zq>>> {
94        (0..r)
95            .map(|_| ProjectionMatrix::new(n).get_matrix().clone())
96            .collect()
97    }
98}
99pub struct Witness {
100    pub s: Vec<RqVector>,
101}
102
103impl Witness {
104    pub fn new(ep: &EnvironmentParameters) -> Self {
105        let s = (0..ep.r)
106            .map(|_| RqVector::random_ternary(&mut rng(), ep.n))
107            .collect();
108        Self { s }
109    }
110}
111
112pub struct LabradorProver<'a> {
113    pub pp: &'a AjtaiInstances,
114    pub witness: &'a Witness,
115    pub st: &'a Statement,
116    pub tr: &'a Challenges,
117}
118
119impl<'a> LabradorProver<'a> {
120    pub fn new(
121        pp: &'a AjtaiInstances,
122        witness: &'a Witness,
123        st: &'a Statement,
124        tr: &'a Challenges,
125    ) -> Self {
126        Self {
127            pp,
128            witness,
129            st,
130            tr,
131        }
132    }
133
134    /// all prove steps are from page 17
135    pub fn prove(&self, ep: &EnvironmentParameters) -> Result<Proof, ProverError> {
136        // check the L2 norm of the witness
137        // not sure whether this should be handled during the proving or managed by the witness generator.
138        Self::check_witness_l2norm(self, ep).unwrap();
139        // Step 1: Outer commitments u_1 starts: --------------------------------------------
140
141        // Ajtai Commitments t_i = A * s_i
142        let t_i: Vec<RqVector> = self
143            .witness
144            .s
145            .iter()
146            .map(|s_i| self.pp.commitment_scheme_a.commit(s_i).unwrap())
147            .collect();
148
149        // This replaces the following code
150        let mut garbage_polynomials = GarbagePolynomials::new(self.witness.s.clone());
151        garbage_polynomials.compute_g();
152        // calculate outer commitment u_1 = \sum(B_ik * t_i^(k)) + \sum(C_ijk * g_ij^(k))
153        let mut outer_commitments = OuterCommitment::new(self.pp);
154        outer_commitments.compute_u1(
155            RqMatrix::new(t_i.clone()),
156            DecompositionParameters::new(ep.b, ep.t_1).unwrap(),
157            garbage_polynomials.g.clone(),
158            DecompositionParameters::new(ep.b, ep.t_2).unwrap(),
159        );
160        // Step 1: Outer commitments u_1 ends: ----------------------------------------------
161
162        // Step 2: JL projection starts: ----------------------------------------------------
163
164        // JL projection p_j + check p_j = ct(sum(<\sigma_{-1}(pi_i^(j)), s_i>))
165        let matrices = &self.tr.pi;
166        let p = Projections::new(matrices, &self.witness.s);
167
168        // Notice that this check is resource-intensive due to the multiplication of two ZqVector<256> instances,
169        // followed by the removal of high-degree terms. It might not be a necessary check.
170        Self::check_projection(self, p.get_projection()).unwrap();
171
172        // Step 2: JL projection ends: ------------------------------------------------------
173
174        // Step 3: Aggregation starts: --------------------------------------------------------------
175
176        // first aggregation
177        let aggr_1 = aggregate::AggregationOne::new(self.witness, self.st, ep, self.tr);
178        // second aggregation
179        let aggr_2 = aggregate::AggregationTwo::new(&aggr_1, self.st, ep, self.tr);
180
181        // Aggregation ends: ----------------------------------------------------------------
182
183        // Step 4: Calculate h_ij, u_2, and z starts: ---------------------------------------
184
185        let phi_i = aggr_2.phi_i;
186        garbage_polynomials.compute_h(&phi_i);
187        outer_commitments.compute_u2(
188            garbage_polynomials.h.clone(),
189            DecompositionParameters::new(ep.b, ep.t_1).unwrap(),
190        );
191
192        // calculate z = c_1*s_1 + ... + c_r*s_r
193        let z = aggregate::calculate_z(&self.witness.s, &self.tr.random_c);
194
195        // Step 4: Calculate h_ij, u_2, and z ends: -----------------------------------------
196
197        Ok(Proof {
198            u_1: outer_commitments.u_1,
199            p,
200            b_ct_aggr: aggr_1.b_ct_aggr,
201            u_2: outer_commitments.u_2,
202            z,
203            t_i,
204            g_ij: garbage_polynomials.g,
205            h_ij: garbage_polynomials.h,
206        })
207    }
208
209    /// check p_j? = ct(sum(<σ−1(pi_i^(j)), s_i>))
210    fn check_projection(&self, p: &[Zq]) -> Result<bool, ProverError> {
211        let s_coeffs: Vec<Vec<Zq>> = self
212            .witness
213            .s
214            .iter()
215            .map(|s_i| {
216                s_i.iter()
217                    .flat_map(|s_i_p| *s_i_p.get_coefficients())
218                    .collect()
219            })
220            .collect();
221
222        for (j, &p_j) in p.iter().enumerate() {
223            let mut poly = vec![Zq::ZERO; p.len()];
224            for (i, s_i) in s_coeffs.iter().enumerate() {
225                let pi_ele = &self.tr.pi[i][j];
226                let pi_ele_ca = pi_ele.conjugate_automorphism();
227                poly = poly.add(&(pi_ele_ca.multiply(s_i)));
228            }
229
230            if poly[0] != p_j {
231                return Err(ProverError::ProjectionError {
232                    index: j,
233                    expected: p_j,
234                    computed: poly[0],
235                });
236            }
237        }
238
239        Ok(true)
240    }
241
242    /// check the L2 norm of the witness, || s_i || <= beta
243    fn check_witness_l2norm(&self, ep: &EnvironmentParameters) -> Result<bool, ProverError> {
244        let beta2 = ep.beta * ep.beta;
245        for polys in &self.witness.s {
246            let witness_l2norm_squared = RqVector::compute_norm_squared(polys);
247            if witness_l2norm_squared > beta2 {
248                return Err(ProverError::WitnessL2NormViolated {
249                    norm_squared: witness_l2norm_squared,
250                    allowed: beta2,
251                });
252            }
253        }
254        Ok(true)
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: a randomly generated witness with its derived statement
    /// must pass every prover self-check and yield a well-shaped proof.
    #[test]
    fn test_prove() {
        // Default environment parameters, used for testing.
        let ep = EnvironmentParameters::default();
        // Random witness and the public statement derived from it.
        let witness = Witness::new(&ep);
        let st: Statement = Statement::new(&witness, &ep);
        // Common reference string matrices A, B, C, D.
        let pp = AjtaiInstances::new(&ep);
        // Random challenges standing in for the verifier.
        let tr = Challenges::new(&ep);

        let prover = LabradorProver::new(&pp, &witness, &st, &tr);
        let proof = prover
            .prove(&ep)
            .expect("a generated witness must pass the prover's self-checks");

        // One inner commitment t_i per witness vector s_i.
        assert_eq!(proof.t_i.len(), witness.s.len());
        assert_eq!(proof.t_i.len(), ep.r);
    }
}