// labrador/verifier.rs

1use crate::core::{
2    aggregate, crs::PublicPrams, env_params::EnvironmentParameters, statement::Statement,
3};
4use crate::prover::{Challenges, Proof};
5use crate::ring::poly::{PolyRing, PolyVector};
6use crate::ring::zq::Zq;
7
8#[derive(Debug)]
9pub enum VerifierError {
10    NotSymmetric {
11        i: usize,
12        j: usize,
13        expected: PolyRing,
14        found: PolyRing,
15    },
16    B0Mismatch {
17        index: usize,
18        expected: Zq,
19        computed: Zq,
20    },
21    NormSumExceeded {
22        norm: Zq,
23        allowed: Zq,
24    },
25    AzError {
26        computed: PolyVector,
27        expected: PolyVector,
28    },
29    ZInnerError {
30        computed: PolyRing,
31        expected: PolyRing,
32    },
33    PhiError {
34        computed: PolyRing,
35        expected: PolyRing,
36    },
37    RelationCheckFailed,
38    OuterCommitError {
39        computed: PolyVector,
40        expected: PolyVector,
41    },
42}
43pub struct LabradorVerifier<'a> {
44    pub pp: &'a PublicPrams,
45    pub st: &'a Statement,
46    pub tr: &'a Challenges,
47}
48
49impl<'a> LabradorVerifier<'a> {
50    pub fn new(pp: &'a PublicPrams, st: &'a Statement, tr: &'a Challenges) -> Self {
51        Self { pp, st, tr }
52    }
53
54    /// All check conditions are from page 18
55    pub fn verify(&self, proof: &Proof, ep: &EnvironmentParameters) -> Result<bool, VerifierError> {
56        // 1. line 08: check g_ij ?= g_ji
57        // 2. line 09: check h_ij ?= h_ji
58        for i in 0..ep.r {
59            for j in (i + 1)..ep.r {
60                let g_ij = &proof.g_ij[i].get_elements()[j];
61                let g_ji = &proof.g_ij[j].get_elements()[i];
62                if g_ij != g_ji {
63                    return Err(VerifierError::NotSymmetric {
64                        i,
65                        j,
66                        expected: g_ji.clone(),
67                        found: g_ij.clone(),
68                    });
69                }
70                let h_ij = &proof.h_ij[i].get_elements()[j];
71                let h_ji = &proof.h_ij[j].get_elements()[i];
72                if h_ij != h_ji {
73                    return Err(VerifierError::NotSymmetric {
74                        i,
75                        j,
76                        expected: h_ji.clone(),
77                        found: h_ij.clone(),
78                    });
79                }
80            }
81        }
82
83        // check b_0^{''(k)} ?= <omega^(k),p> + \sum(psi_l^(k) * b_0^{'(l)})
84        Self::check_b_0_aggr(self, proof, ep).unwrap();
85
86        // 3. line 14: check norm_sum(z, t, g, h) <= (beta')^2
87
88        // decompose z into z = z^(0) + z^(1) * b, only two parts.
89        let z_ij = PolyVector::decompose(&proof.z, ep.b, 2);
90        let t_ij: Vec<Vec<PolyVector>> = proof
91            .t_i
92            .iter()
93            .map(|i| PolyVector::decompose(i, ep.b, ep.t_1))
94            .collect();
95        let g_ij: Vec<Vec<PolyVector>> = proof
96            .g_ij
97            .iter()
98            .map(|i| PolyVector::decompose(i, ep.b, ep.t_2))
99            .collect();
100        let h_ij: Vec<Vec<PolyVector>> = proof
101            .h_ij
102            .iter()
103            .map(|i| PolyVector::decompose(i, ep.b, ep.t_1))
104            .collect();
105        let norm_z_ij = z_ij
106            .iter()
107            .fold(Zq::ZERO, |acc, p| acc + p.compute_norm_squared());
108        let norm_t_ij = Self::norm_squared(&t_ij);
109        let norm_g_ij = Self::norm_squared(&g_ij);
110        let norm_h_ij = Self::norm_squared(&h_ij);
111        let norm_sum = norm_z_ij + norm_t_ij + norm_g_ij + norm_h_ij;
112
113        if norm_sum > ep.beta * ep.beta {
114            return Err(VerifierError::NormSumExceeded {
115                norm: norm_sum,
116                allowed: ep.beta * ep.beta,
117            });
118        }
119
120        // 4. line 15: check Az ?= c_1 * t_1 + ... + c_r * t_r
121
122        let az = &proof.z * &self.pp.matrix_a;
123        let ct_sum: PolyVector = aggregate::calculate_z(&proof.t_i, &self.tr.random_c);
124
125        if az != ct_sum {
126            return Err(VerifierError::AzError {
127                computed: az,
128                expected: ct_sum,
129            });
130        }
131
132        // 5. lne 16: check <z, z> ?= \sum(g_ij * c_i * c_j)
133
134        let z_inner = proof.z.inner_product_poly_vector(&proof.z);
135        let sum_gij_cij = Self::calculate_gh_ci_cj(&proof.g_ij, &self.tr.random_c, ep.r);
136
137        if z_inner != sum_gij_cij {
138            return Err(VerifierError::ZInnerError {
139                computed: z_inner,
140                expected: sum_gij_cij,
141            });
142        }
143
144        // 6. line 17: check \sum(<\phi_i, z>c_i) ?= \sum(h_ij * c_i * c_j)
145
146        let phi_ct_aggr = aggregate::AggregationOne::get_phi_ct_aggr(
147            &self.st.phi_ct,
148            &self.tr.pi,
149            &self.tr.psi,
150            &self.tr.omega,
151            ep,
152        );
153        let phi_i = aggregate::AggregationTwo::get_phi_i(
154            &self.st.phi_constraint,
155            &phi_ct_aggr,
156            &self.tr.random_alpha,
157            &self.tr.random_beta,
158            ep,
159        );
160        let sum_phi_z_c = Self::calculate_phi_z_c(&phi_i, &self.tr.random_c, &proof.z);
161        let sum_hij_cij = Self::calculate_gh_ci_cj(&proof.h_ij, &self.tr.random_c, ep.r);
162
163        // Left side multiple by 2 because of when we calculate h_ij, we didn't apply the division (divided by 2)
164        if &sum_phi_z_c * &Zq::TWO != sum_hij_cij {
165            return Err(VerifierError::PhiError {
166                computed: &sum_phi_z_c * &Zq::TWO,
167                expected: sum_hij_cij,
168            });
169        }
170
171        // 7. line 18: check \sum(a_ij * g_ij) + \sum(h_ii) - b ?= 0
172
173        let a_ct_aggr = aggregate::AggregationOne::get_a_ct_aggr(&self.tr.psi, &self.st.a_ct, ep);
174        let a_primes = aggregate::AggregationTwo::get_a_i(
175            &self.st.a_constraint,
176            &a_ct_aggr,
177            &self.tr.random_alpha,
178            &self.tr.random_beta,
179            ep,
180        );
181        let b_primes = aggregate::AggregationTwo::get_b_i(
182            &self.st.b_constraint,
183            &proof.b_ct_aggr,
184            &self.tr.random_alpha,
185            &self.tr.random_beta,
186            ep,
187        );
188
189        if !Self::check_relation(&a_primes, &b_primes, &proof.g_ij, &proof.h_ij) {
190            return Err(VerifierError::RelationCheckFailed);
191        }
192
193        // 8. line 19: u_1 ?= \sum(\sum(B_ik * t_i^(k))) + \sum(\sum(C_ijk * g_ij^(k)))
194
195        let u_1 = &proof.u_1;
196        let outer_commit_u_1 =
197            aggregate::calculate_u_1(&self.pp.matrix_b, &self.pp.matrix_c, &t_ij, &g_ij, ep);
198
199        if u_1 != &outer_commit_u_1 {
200            return Err(VerifierError::OuterCommitError {
201                computed: u_1.clone(),
202                expected: outer_commit_u_1,
203            });
204        }
205
206        // 9. line 20: u_2 ?= \sum(\sum(D_ijk * h_ij^(k)))
207
208        let u_2 = &proof.u_2;
209        let outer_commit_u_2 = aggregate::calculate_u_2(&self.pp.matrix_d, &h_ij, ep);
210
211        if u_2 != &outer_commit_u_2 {
212            return Err(VerifierError::OuterCommitError {
213                computed: u_2.clone(),
214                expected: outer_commit_u_2,
215            });
216        }
217
218        Ok(true)
219    }
220
221    /// calculate the right hand side of line 16 or line 17, \sum(g_ij * c_i * c_j) or \sum(h_ij * c_i * c_j)
222    fn calculate_gh_ci_cj(x_ij: &[PolyVector], random_c: &PolyVector, r: usize) -> PolyRing {
223        (0..r)
224            .map(|i| {
225                (0..r)
226                    .map(|j| {
227                        &(&x_ij[i].get_elements()[j] * &random_c.get_elements()[i])
228                            * &random_c.get_elements()[j]
229                    })
230                    .fold(PolyRing::zero_poly(), |acc, x| &acc + &x)
231            })
232            .fold(PolyRing::zero_poly(), |acc, x| &acc + &x)
233    }
234
235    /// calculate the left hand side of line 17, \sum(<\phi_z, z> * c_i)
236    fn calculate_phi_z_c(phi: &[PolyVector], c: &PolyVector, z: &PolyVector) -> PolyRing {
237        phi.iter()
238            .zip(c.iter())
239            .map(|(phi_i, c_i)| &(phi_i.inner_product_poly_vector(z)) * c_i)
240            .fold(PolyRing::zero_poly(), |acc, x| &acc + &x)
241    }
242
243    fn norm_squared(polys: &[Vec<PolyVector>]) -> Zq {
244        polys.iter().fold(Zq::ZERO, |acc, poly| {
245            acc + poly
246                .iter()
247                .fold(Zq::ZERO, |acc, p| acc + p.compute_norm_squared())
248        })
249    }
250
251    /// line 18, page 18: check if \sum(a_{ij} * g_{ij}) + \sum(h_{ii}) - b ?= 0
252    /// in the verifier process, page 18 from the paper.
253    ///
254    /// param: a_primes: a_{ij}^{''(k)}
255    /// param: b_primes: b^{''(k)}
256    /// param: g: g_{ij}
257    /// param: h: h_{ii}
258    ///
259    /// return: true if the relation holds, false otherwise
260    pub fn check_relation(
261        a_primes: &[PolyVector],
262        b_primes: &PolyRing,
263        g: &[PolyVector],
264        h: &[PolyVector],
265    ) -> bool {
266        let r = a_primes.len();
267        let d = a_primes[0].get_elements()[0].get_coeffs().len();
268
269        let sum_a_primes_g: PolyRing = a_primes
270            .iter()
271            .zip(g.iter())
272            .map(|(a_i, g_i)| {
273                a_i.iter()
274                    .zip(g_i.iter())
275                    .map(|(a_ij, g_ij)| a_ij * g_ij)
276                    .fold(PolyRing::new(vec![Zq::ZERO; d]), |acc, val| &acc + &val)
277            })
278            .fold(PolyRing::new(vec![Zq::ZERO; d]), |acc, val| &acc + &val);
279
280        let sum_h_ii: PolyRing = (0..r).fold(PolyRing::new(vec![Zq::ZERO; d]), |acc, i| {
281            &acc + &h[i].get_elements()[i]
282        });
283
284        let b_primes2 = b_primes * &Zq::TWO;
285        let sum_a_primes_g2 = &sum_a_primes_g * &Zq::TWO;
286
287        &sum_a_primes_g2 + &sum_h_ii == b_primes2
288    }
289
290    fn check_b_0_aggr(
291        &self,
292        proof: &Proof,
293        ep: &EnvironmentParameters,
294    ) -> Result<bool, VerifierError> {
295        for k in 0..ep.k {
296            let b_0_poly = proof.b_ct_aggr.get_elements()[k].get_coeffs()[0];
297            let mut b_0: Zq = (0..ep.constraint_l)
298                .map(|l| self.tr.psi[k].get_coeffs()[l] * self.st.b_0_ct.get_coeffs()[l])
299                .sum();
300            let inner_omega_p = self.tr.omega[k].inner_product(proof.p.get_projection());
301            b_0 += inner_omega_p;
302            if b_0 != b_0_poly {
303                return Err(VerifierError::B0Mismatch {
304                    index: k,
305                    expected: b_0_poly,
306                    computed: b_0,
307                });
308            }
309        }
310
311        Ok(true)
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prover::{LabradorProver, Witness};

    /// Round trip: a proof produced by an honest prover must be accepted by the verifier.
    #[test]
    fn test_verify() {
        // Default environment parameters used for testing.
        let ep = EnvironmentParameters::default();

        // Random witness, the public statement derived from it, the CRS matrices and the
        // verifier challenges, all built from `ep` (same creation order as before).
        let witness = Witness::new(&ep);
        let st: Statement = Statement::new(&witness, &ep);
        let pp = PublicPrams::new(&ep);
        let tr = Challenges::new(&ep);

        // Honest prover run.
        let prover = LabradorProver::new(&pp, &witness, &st, &tr);
        let proof = prover.prove(&ep).unwrap();

        // The verifier must accept the honest proof.
        let accepted = LabradorVerifier::new(&pp, &st, &tr)
            .verify(&proof, &ep)
            .unwrap();
        assert!(accepted);
    }
}