labrador/
verifier.rs

1#![allow(clippy::result_large_err)]
2
3use thiserror::Error;
4
5use crate::commitments::common_instances::AjtaiInstances;
6use crate::commitments::outer_commitments::{self, DecompositionParameters};
7use crate::core::aggregate::{FunctionsAggregation, ZeroConstantFunctionsAggregation};
8use crate::core::inner_product;
9use crate::relation::env_params;
10use crate::relation::{env_params::EnvironmentParameters, statement::Statement};
11use crate::ring::rq::Rq;
12use crate::ring::rq_matrix::RqMatrix;
13use crate::ring::rq_vector::RqVector;
14use crate::ring::zq::Zq;
15use crate::transcript::{LabradorTranscript, Sponge};
16
/// Errors reported by [`LabradorVerifier`] when a proof is rejected.
///
/// Most variants correspond to one verification check and carry the two
/// values that were compared so that a failure can be diagnosed.
#[derive(Debug, Error)]
pub enum VerifierError {
    /// A matrix entry at position `(i, j)` differs from the expected value.
    ///
    /// NOTE(review): this variant is not constructed by any code visible in
    /// this file — confirm whether a symmetry check is still performed elsewhere.
    #[error("matrix not symmetric at ({i},{j}): expected {expected:?}, found {found:?}")]
    NotSymmetric {
        i: usize,
        j: usize,
        expected: Rq,
        found: Rq,
    },
    /// The constant coefficient of the `index`-th aggregated polynomial in the
    /// proof (`expected`) differs from the value recomputed by the verifier
    /// (`computed`).
    #[error("B0 mismatch at index {index}: expected {expected}, computed {computed}")]
    B0Mismatch {
        index: usize,
        expected: Zq,
        computed: Zq,
    },
    /// The summed squared norms of the decomposed `z`, `t`, `g` and `h`
    /// (`norm`) are larger than the permitted bound (`allowed`).
    #[error("‖z‖² = {norm} exceeds allowed bound {allowed}")]
    NormSumExceeded { norm: Zq, allowed: Zq },
    /// `A·z` (`computed`) differs from the challenge-weighted combination of
    /// the `t` vectors (`expected`).
    #[error("A·z check failed: expected {expected:?}, computed {computed:?}")]
    AzError {
        computed: RqVector,
        expected: RqVector,
    },
    /// `<z, z>` (`computed`) differs from `\sum(g_ij * c_i * c_j)` (`expected`).
    #[error("⟨z,z⟩ mismatch: expected {expected:?}, computed {computed:?}")]
    ZInnerError { computed: Rq, expected: Rq },
    /// The doubled aggregated `\sum(<\phi_i, z> * c_i)` (`computed`) differs
    /// from `\sum(h_ij * c_i * c_j)` (`expected`).
    #[error("φ(z) mismatch: expected {expected:?}, computed {computed:?}")]
    PhiError { computed: Rq, expected: Rq },
    /// `LabradorVerifier::check_relation` returned `false`.
    #[error("relation check failed")]
    RelationCheckFailed,
    /// An outer commitment (`u1` or `u2`) recomputed by the verifier does not
    /// match the one contained in the proof.
    #[error("outer commitment mismatch: expected {expected:?}, computed {computed:?}")]
    OuterCommitError {
        computed: RqVector,
        expected: RqVector,
    },
    /// Error converted from the outer-commitment module; presumably raised by
    /// `DecompositionParameters::new` when the parameters are rejected.
    #[error(transparent)]
    DecompositionError(#[from] outer_commitments::DecompositionError),
}
/// Verifier side of the protocol: checks a proof transcript against a public
/// statement, using borrowed parameters and a borrowed common reference string.
pub struct LabradorVerifier<'a> {
    /// Environment parameters read by the checks (rank, multiplicity, norm
    /// bound `beta`, decomposition base `b` and part counts `t_1`/`t_2`, ...).
    params: &'a EnvironmentParameters,
    /// Common reference string; supplies the commitment matrix `A` and is
    /// passed to the outer-commitment computations.
    crs: &'a AjtaiInstances,
    /// Public statement the proof is verified against.
    st: &'a Statement,
    // Aggregation instances
    /// Aggregates the `phi_ct` / `a_ct` parts of the statement (the
    /// "double prime" values).
    constant_aggregator: ZeroConstantFunctionsAggregation<'a>,
    /// Aggregates the constraint functions (`phi`, `a`, `b`) using the
    /// `alpha` / `beta` challenge vectors.
    funcs_aggregator: FunctionsAggregation<'a>,
}
61
62impl<'a> LabradorVerifier<'a> {
63    pub fn new(
64        params: &'a EnvironmentParameters,
65        crs: &'a AjtaiInstances,
66        st: &'a Statement,
67    ) -> Self {
68        Self {
69            params,
70            crs,
71            st,
72            constant_aggregator: ZeroConstantFunctionsAggregation::new(params),
73            funcs_aggregator: FunctionsAggregation::new(params),
74        }
75    }
76
77    /// All check conditions are from page 18
78    pub fn verify<S: Sponge>(
79        &mut self,
80        proof: &LabradorTranscript<S>,
81    ) -> Result<bool, VerifierError> {
82        let mut transcript = LabradorTranscript::new(S::default());
83
84        transcript.absorb_u1(&proof.u1);
85        let projections = transcript.generate_projections(
86            env_params::SECURITY_PARAMETER,
87            self.params.rank,
88            self.params.multiplicity,
89        );
90        transcript.absorb_vector_p(&proof.vector_p);
91        let size_of_psi = usize::div_ceil(env_params::SECURITY_PARAMETER, self.params.log_q);
92        let size_of_omega = size_of_psi;
93        let psi = transcript.generate_vector_psi(size_of_psi, self.params.constraint_l);
94        let omega = transcript.generate_vector_omega(size_of_omega, env_params::SECURITY_PARAMETER);
95        transcript.absorb_vector_b_ct_aggr(&proof.b_ct_aggr);
96        let vector_alpha = transcript.generate_rq_vector(self.params.constraint_k);
97        let size_of_beta = size_of_psi;
98        let vector_beta = transcript.generate_rq_vector(size_of_beta);
99        transcript.absorb_u2(&proof.u2);
100        let challenges =
101            transcript.generate_challenges(env_params::OPERATOR_NORM, self.params.multiplicity);
102
103        // check b_0^{''(k)} ?= <omega^(k),p> + \sum(psi_l^(k) * b_0^{'(l)})
104        Self::check_b_0_aggr(self, proof, self.params, &psi, &omega)?;
105
106        // 3. line 14: check norm_sum(z, t, g, h) <= (beta')^2
107
108        // decompose z into z = z^(0) + z^(1) * b, only two parts.
109        let z_ij = RqVector::decompose(&proof.z, self.params.b, 2);
110        let t_ij: Vec<Vec<RqVector>> = proof
111            .t
112            .get_elements()
113            .iter()
114            .map(|i| RqVector::decompose(i, self.params.b, self.params.t_1))
115            .collect();
116        let g_ij: Vec<Vec<RqVector>> = proof
117            .g
118            .get_elements()
119            .iter()
120            .map(|i| RqVector::decompose(i, self.params.b, self.params.t_2))
121            .collect();
122        let h_ij: Vec<Vec<RqVector>> = proof
123            .h
124            .get_elements()
125            .iter()
126            .map(|i| RqVector::decompose(i, self.params.b, self.params.t_1))
127            .collect();
128        let norm_z_ij = z_ij
129            .iter()
130            .fold(Zq::ZERO, |acc, p| acc + p.l2_norm_squared());
131        let norm_t_ij = Self::norm_squared(&t_ij);
132        let norm_g_ij = Self::norm_squared(&g_ij);
133        let norm_h_ij = Self::norm_squared(&h_ij);
134        let norm_sum = norm_z_ij + norm_t_ij + norm_g_ij + norm_h_ij;
135
136        if norm_sum > self.params.beta * self.params.beta {
137            return Err(VerifierError::NormSumExceeded {
138                norm: norm_sum,
139                allowed: self.params.beta * self.params.beta,
140            });
141        }
142
143        // 4. line 15: check Az ?= c_1 * t_1 + ... + c_r * t_r
144        let az = self.crs.commitment_scheme_a.matrix() * &proof.z;
145        let ct_sum = inner_product::compute_linear_combination(
146            proof.t.get_elements(),
147            challenges.get_elements(),
148        );
149        if az != ct_sum {
150            return Err(VerifierError::AzError {
151                computed: az,
152                expected: ct_sum,
153            });
154        }
155
156        // 5. lne 16: check <z, z> ?= \sum(g_ij * c_i * c_j)
157
158        let z_inner = inner_product::compute_linear_combination(
159            proof.z.get_elements(),
160            proof.z.get_elements(),
161        );
162        let sum_gij_cij = Self::calculate_gh_ci_cj(&proof.g, &challenges, self.params.multiplicity);
163
164        if z_inner != sum_gij_cij {
165            return Err(VerifierError::ZInnerError {
166                computed: z_inner,
167                expected: sum_gij_cij,
168            });
169        }
170
171        // 6. line 17: check \sum(<\phi_i, z>c_i) ?= \sum(h_ij * c_i * c_j)
172        self.constant_aggregator.calculate_agg_phi_double_prime(
173            &self.st.phi_ct,
174            &projections.get_conjugated_projection_matrices(),
175            &psi,
176            &omega,
177        );
178        self.funcs_aggregator.calculate_aggr_phi(
179            &self.st.phi_constraint,
180            self.constant_aggregator.get_phi_double_prime(),
181            &vector_alpha,
182            &vector_beta,
183        );
184        let sum_phi_z_c =
185            Self::calculate_phi_z_c(self.funcs_aggregator.get_appr_phi(), &challenges, &proof.z);
186        let sum_hij_cij = Self::calculate_gh_ci_cj(&proof.h, &challenges, self.params.multiplicity);
187
188        // Left side multiple by 2 because of when we calculate h_ij, we didn't apply the division (divided by 2)
189        if &sum_phi_z_c * &Zq::TWO != sum_hij_cij {
190            return Err(VerifierError::PhiError {
191                computed: &sum_phi_z_c * &Zq::TWO,
192                expected: sum_hij_cij,
193            });
194        }
195
196        // 7. line 18: check \sum(a_ij * g_ij) + \sum(h_ii) - b ?= 0
197
198        self.constant_aggregator
199            .calculate_agg_a_double_prime(&psi, &self.st.a_ct);
200        self.funcs_aggregator.calculate_agg_a(
201            &self.st.a_constraint,
202            self.constant_aggregator.get_alpha_double_prime(),
203            &vector_alpha,
204            &vector_beta,
205        );
206
207        self.funcs_aggregator.calculate_aggr_b(
208            &self.st.b_constraint,
209            &proof.b_ct_aggr,
210            &vector_alpha,
211            &vector_beta,
212        );
213
214        if !Self::check_relation(
215            self.funcs_aggregator.get_agg_a(),
216            self.funcs_aggregator.get_aggr_b(),
217            &proof.g,
218            &proof.h,
219        ) {
220            return Err(VerifierError::RelationCheckFailed);
221        }
222
223        // 8. line 19: u_1 ?= \sum(\sum(B_ik * t_i^(k))) + \sum(\sum(C_ijk * g_ij^(k)))
224
225        let u_1 = &proof.u1;
226        let commitment_u1 = outer_commitments::compute_u1(
227            self.crs,
228            &proof.t,
229            DecompositionParameters::new(self.params.b, self.params.t_1)?,
230            &proof.g,
231            DecompositionParameters::new(self.params.b, self.params.t_2)?,
232        );
233
234        if proof.u1 != commitment_u1 {
235            return Err(VerifierError::OuterCommitError {
236                computed: u_1.clone(),
237                expected: commitment_u1,
238            });
239        }
240
241        // 9. line 20: u_2 ?= \sum(\sum(D_ijk * h_ij^(k)))
242        let commitment_u2 = outer_commitments::compute_u2(
243            self.crs,
244            &proof.h,
245            DecompositionParameters::new(self.params.b, self.params.t_1)?,
246        );
247
248        if proof.u2 != commitment_u2 {
249            return Err(VerifierError::OuterCommitError {
250                computed: commitment_u2,
251                expected: proof.u2.clone(),
252            });
253        }
254
255        Ok(true)
256    }
257
258    /// calculate the right hand side of line 16 or line 17, \sum(g_ij * c_i * c_j) or \sum(h_ij * c_i * c_j)
259    fn calculate_gh_ci_cj(x_ij: &RqMatrix, random_c: &RqVector, r: usize) -> Rq {
260        (0..r)
261            .map(|i| {
262                (0..r)
263                    .map(|j| {
264                        &(x_ij.get_cell(i, j) * &random_c.get_elements()[i])
265                            * &random_c.get_elements()[j]
266                    })
267                    .fold(Rq::zero(), |acc, x| &acc + &x)
268            })
269            .fold(Rq::zero(), |acc, x| &acc + &x)
270    }
271
272    /// calculate the left hand side of line 17, \sum(<\phi_z, z> * c_i)
273    fn calculate_phi_z_c(phi: &[RqVector], c: &RqVector, z: &RqVector) -> Rq {
274        phi.iter()
275            .zip(c.get_elements())
276            .map(|(phi_i, c_i)| {
277                &(inner_product::compute_linear_combination(phi_i.get_elements(), z.get_elements()))
278                    * c_i
279            })
280            .fold(Rq::zero(), |acc, x| &acc + &x)
281    }
282
283    fn norm_squared(polys: &[Vec<RqVector>]) -> Zq {
284        polys.iter().fold(Zq::ZERO, |acc, poly| {
285            acc + poly
286                .iter()
287                .fold(Zq::ZERO, |acc, p| acc + p.l2_norm_squared())
288        })
289    }
290
291    /// line 18, page 18: check if \sum(a_{ij} * g_{ij}) + \sum(h_{ii}) - b ?= 0
292    /// in the verifier process, page 18 from the paper.
293    ///
294    /// param: a_primes: a_{ij}^{''(k)}
295    /// param: b_primes: b^{''(k)}
296    /// param: g: g_{ij}
297    /// param: h: h_{ii}
298    ///
299    /// return: true if the relation holds, false otherwise
300    pub fn check_relation(a_primes: &RqMatrix, b_primes: &Rq, g: &RqMatrix, h: &RqMatrix) -> bool {
301        let r = a_primes.get_elements().len();
302
303        let mut sum_a_primes_g = Rq::zero();
304        // walk only over the stored half: i ≤ j
305        for i in 0..r {
306            for j in 0..r {
307                sum_a_primes_g = &sum_a_primes_g + &(a_primes.get_cell(i, j) * g.get_cell(i, j));
308            }
309        }
310
311        let sum_h_ii = (0..r).fold(Rq::zero(), |acc, i| &acc + h.get_cell(i, i));
312
313        let b_primes2 = b_primes * &Zq::TWO;
314        let sum_a_primes_g2 = &sum_a_primes_g * &Zq::TWO;
315
316        &sum_a_primes_g2 + &sum_h_ii == b_primes2
317    }
318
319    fn check_b_0_aggr<S: Sponge>(
320        &self,
321        proof: &LabradorTranscript<S>,
322        ep: &EnvironmentParameters,
323        psi: &[Vec<Zq>],
324        omega: &[Vec<Zq>],
325    ) -> Result<bool, VerifierError> {
326        for k in 0..ep.kappa {
327            let b_0_poly = proof.b_ct_aggr.get_elements()[k].get_coefficients()[0];
328            let mut b_0: Zq = (0..ep.constraint_l)
329                .map(|l| psi[k][l] * self.st.b_0_ct[l])
330                .sum();
331
332            let inner_omega_p =
333                inner_product::compute_linear_combination(&omega[k], &proof.vector_p);
334            b_0 += inner_omega_p;
335            if b_0 != b_0_poly {
336                return Err(VerifierError::B0Mismatch {
337                    index: k,
338                    expected: b_0_poly,
339                    computed: b_0,
340                });
341            }
342        }
343
344        Ok(true)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prover::LabradorProver;
    use crate::relation::witness::Witness;
    use crate::transcript::sponges::shake::ShakeSponge;

    /// End-to-end round trip: a proof produced by the prover for a freshly
    /// generated witness and statement must be accepted by the verifier.
    #[test]
    fn test_verify() {
        // Default parameters drive the witness, statement and CRS generation.
        let params = EnvironmentParameters::default();
        let witness = Witness::new(params.rank, params.multiplicity, params.beta);
        let statement: Statement = Statement::new(&witness, &params);
        let crs = AjtaiInstances::new(&params);

        // Prover side: produce the proof transcript.
        let mut prover = LabradorProver::new(&params, &crs, &witness, &statement);
        let proof: LabradorTranscript<ShakeSponge> = prover.prove().unwrap();

        // Verifier side: the proof must be accepted.
        let mut verifier = LabradorVerifier::new(&params, &crs, &statement);
        assert!(verifier.verify(&proof).unwrap());
    }
}