Skip to main content

zip_plus/pcs/
phase_prove.rs

1use crate::{
2    ZipError,
3    code::LinearCode,
4    pcs::{
5        structs::{ZipPlus, ZipPlusHint, ZipPlusParams, ZipTypes},
6        utils::{point_to_tensor, validate_input},
7    },
8    pcs_transcript::PcsProverTranscript,
9};
10use crypto_primitives::{FromWithConfig, IntoWithConfig, PrimeField};
11use itertools::Itertools;
12use num_traits::{ConstOne, ConstZero, Zero};
13#[cfg(feature = "parallel")]
14use rayon::prelude::*;
15use zinc_poly::{Polynomial, mle::DenseMultilinearExtension};
16use zinc_transcript::traits::{Transcribable, Transcript};
17use zinc_utils::{
18    UNCHECKED, cfg_chunks, cfg_iter, cfg_iter_mut,
19    from_ref::FromRef,
20    inner_product::{InnerProduct, MBSInnerProduct},
21    mul_by_scalar::MulByScalar,
22};
23
24impl<Zt: ZipTypes, Lc: LinearCode<Zt>> ZipPlus<Zt, Lc> {
25    /// Generates an opening proof for one or more committed multilinear
26    /// polynomials at an evaluation point, using the Zip+ protocol.
27    ///
28    /// This replaces the old two-phase (test + evaluate) approach with a single
29    /// merged phase. The key idea: alpha-projection (Eval → CombR) is used for
30    /// *both* the proximity argument and the evaluation claim, eliminating the
31    /// separate field-domain projection via `projecting_element` γ.
32    ///
33    /// # Algorithm
34    /// 1. Computes points: `(q_0, q_1) = point_to_tensor(point)` where `q_0`
35    ///    (length `num_rows`) combines rows and `q_1` (length `row_len`)
36    ///    combines columns.
37    /// 2. Per polynomial, samples random challenges `alphas` (`[α_0, …, α_d]`).
38    ///    For each decoded row `w_j` takes the inner product `<entry, alphas>`
39    ///    of every entry in the row, producing `w'_j` — a row of `CombR`
40    ///    integers.
41    /// 3. Computes `b` (length `num_rows`), accumulated across all polys: `b_j
42    ///    += <w'_j, q_1>` for each row `j`.
43    /// 4. Writes `b` to the transcript and computes `eval = <q_0, b>`.
44    /// 5. Samples combination coefficients `betas` (or hardcodes `[1]` when
45    ///    `num_rows == 1`) and computes `combined_row` (CombR, length
46    ///    `row_len`) = `sum_i(sum_j(s_j * w'_ij))`, accumulated across all
47    ///    polynomials
48    /// 6. Writes `combined_row` to the transcript.
49    /// 7. Opens `NUM_COLUMN_OPENINGS` Merkle columns: for each, squeezes a
50    ///    column index, writes per-polynomial column values (Cw entries), and
51    ///    appends the Merkle proof.
52    ///
53    /// # Transcript layout
54    /// ```text
55    /// [field_cfg sampled]
56    /// [per-poly alphas sampled]
57    /// [b written as F elements]
58    /// [coeffs s sampled (or hardcoded [1])]
59    /// [combined_row written as CombR]
60    /// [column openings: idx, per-poly column values, merkle proof] × NUM_COLUMN_OPENINGS
61    /// ```
62    ///
63    /// # Parameters
64    /// - `pp`: Public parameters containing `num_vars`, `num_rows`, and the
65    ///   linear code configuration.
66    /// - `polys`: Slice of multilinear polynomials (batch). All must have
67    ///   `num_vars` variables matching `pp`.
68    /// - `point`: The evaluation point (in `Zt::Pt` coordinates, length
69    ///   `num_vars`).
70    /// - `commit_hint`: The `ZipPlusHint` returned by `commit`, containing
71    ///   per-polynomial codeword matrices and the shared Merkle tree.
72    ///
73    /// # Returns
74    /// A `Result` containing:
75    /// - `F`: The combined evaluation `<q_0, b>`, which equals
76    ///   `sum_i(alpha_projected_eval_i(point))` across all batched polys.
77    /// - `ZipPlusProof`: The serialized transcript (b, combined_row, column
78    ///   openings + Merkle proofs) for the verifier.
79    ///
80    /// # Errors
81    /// - Returns `ZipError::InvalidPcsParam` if any polynomial has more
82    ///   variables than `pp` supports.
83    /// - Returns `ZipError::OverflowError` (when `CHECK_FOR_OVERFLOW` is true)
84    ///   if intermediate CombR sums exceed the integer precision.
85    pub fn prove<F, const CHECK_FOR_OVERFLOW: bool>(
86        transcript: &mut PcsProverTranscript,
87        pp: &ZipPlusParams<Zt, Lc>,
88        polys: &[DenseMultilinearExtension<Zt::Eval>],
89        point: &[Zt::Pt],
90        commit_hint: &ZipPlusHint<Zt::Cw>,
91        field_cfg: &F::Config,
92    ) -> Result<F, ZipError>
93    where
94        F: PrimeField
95            + for<'a> FromWithConfig<&'a Zt::CombR>
96            + for<'a> FromWithConfig<&'a Zt::Pt>
97            + for<'a> MulByScalar<&'a F>
98            + FromRef<F>,
99        F::Inner: Transcribable,
100        F::Modulus: Transcribable,
101    {
102        let point = point
103            .iter()
104            .map(|v| v.into_with_cfg(field_cfg))
105            .collect::<Vec<F>>();
106        Self::prove_f::<F, CHECK_FOR_OVERFLOW>(
107            transcript,
108            pp,
109            polys,
110            &point,
111            commit_hint,
112            field_cfg,
113        )
114    }
115
116    /// See [`Self::prove`] for details.
117    /// This version takes the evaluation point already mapped to the field
118    #[allow(clippy::arithmetic_side_effects)]
119    pub fn prove_f<F, const CHECK_FOR_OVERFLOW: bool>(
120        transcript: &mut PcsProverTranscript,
121        pp: &ZipPlusParams<Zt, Lc>,
122        polys: &[DenseMultilinearExtension<Zt::Eval>],
123        point: &[F],
124        commit_hint: &ZipPlusHint<Zt::Cw>,
125        field_cfg: &F::Config,
126    ) -> Result<F, ZipError>
127    where
128        F: PrimeField
129            + for<'a> FromWithConfig<&'a Zt::CombR>
130            + for<'a> MulByScalar<&'a F>
131            + FromRef<F>,
132        F::Inner: Transcribable,
133        F::Modulus: Transcribable,
134    {
135        let batch_size = polys.len();
136        validate_input::<Zt, Lc, _>(
137            "prove",
138            pp.num_vars,
139            pp.linear_code.row_len(),
140            batch_size,
141            polys,
142            &[point],
143        )?;
144
145        let num_rows = pp.num_rows;
146        let row_len = pp.linear_code.row_len();
147
148        // TODO Lift q0, q1 back to int and take following dot products on ints instead
149        // of MBSInnerProduct in field (see comboned row) We prove evaluations
150        // over the field, so integers need to be mapped to field elements first
151        let (q_0, q_1) = point_to_tensor(num_rows, point, field_cfg)?;
152
153        let degree_bound = Zt::Comb::DEGREE_BOUND;
154        let polys_as_comb_r: Vec<Vec<Zt::CombR>> = polys
155            .iter()
156            .map(|poly| {
157                let alphas = if degree_bound.is_zero() {
158                    vec![Zt::Chal::ONE]
159                } else {
160                    transcript.fs_transcript.get_challenges(degree_bound + 1)
161                };
162
163                cfg_iter!(poly.evaluations)
164                    .map(|eval| {
165                        Zt::EvalDotChal::inner_product::<CHECK_FOR_OVERFLOW>(
166                            eval,
167                            &alphas,
168                            Zt::CombR::ZERO,
169                        )
170                        .map_err(ZipError::from)
171                    })
172                    .collect()
173            })
174            .try_collect()?;
175
176        let zero_f = F::zero_with_cfg(field_cfg);
177
178        // Compute per-polynomial row dot products, then sum across polynomials.
179        let b = {
180            let per_poly_b: Vec<Vec<F>> = cfg_iter!(polys_as_comb_r)
181                .map(|poly_comb_r| {
182                    cfg_chunks!(poly_comb_r, row_len)
183                        .map(|row| MBSInnerProduct::inner_product_field(row, &q_1, zero_f.clone()))
184                        .collect::<Result<Vec<F>, _>>()
185                })
186                .collect::<Result<_, _>>()?;
187
188            let mut b = vec![zero_f.clone(); num_rows];
189            for poly_b in &per_poly_b {
190                b.iter_mut().zip(poly_b).for_each(|(a, d)| *a += d);
191            }
192            b
193        };
194
195        transcript.write_field_elements(&b)?;
196        // Compute eval = <q_0, b> (inner product in field), <q_2, b> in paper
197        // It is safe to use inner_product_unchecked because we're in a field.
198        let eval = MBSInnerProduct::inner_product::<UNCHECKED>(&q_0, &b, zero_f.clone())?;
199
200        // Matrix-vector product over the flat poly_comb_r layout:
201        // Each poly is a row-major (num_rows x row_len) matrix, and coeffs is the
202        // vector.
203        // combined_row[col] = sum_i sum_j (coeffs[j] * poly_i[j * row_len + col])
204
205        let coeffs = if pp.num_rows == 1 {
206            vec![Zt::Chal::ONE]
207        } else {
208            transcript
209                .fs_transcript
210                .get_challenges::<Zt::Chal>(num_rows)
211        };
212
213        let combined_row: Vec<Zt::CombR> = {
214            let mut combined = vec![Zt::CombR::ZERO; row_len];
215            cfg_iter_mut!(combined).enumerate().try_for_each(
216                |(col, acc)| -> Result<(), ZipError> {
217                    for poly_comb_r in &polys_as_comb_r {
218                        // Strided access: skip to column `col`, then step by `row_len`
219                        // to pick the col-th entry of each logical row.
220                        for (eval, coeff) in poly_comb_r
221                            .iter()
222                            .skip(col)
223                            .step_by(row_len)
224                            .zip(coeffs.iter())
225                        {
226                            let scaled: Zt::CombR = eval
227                                .mul_by_scalar::<CHECK_FOR_OVERFLOW>(coeff)
228                                .expect("Cannot multiply evaluation by coefficient");
229                            if CHECK_FOR_OVERFLOW {
230                                *acc = zinc_utils::add!(
231                                    *acc,
232                                    &scaled,
233                                    "Addition overflow while combining rows across polys"
234                                );
235                            } else {
236                                *acc += scaled;
237                            }
238                        }
239                    }
240                    Ok(())
241                },
242            )?;
243            combined
244        };
245
246        transcript.write_const_many(&combined_row)?;
247        for _ in 0..Zt::NUM_COLUMN_OPENINGS {
248            let column_idx = transcript.squeeze_challenge_idx(pp.linear_code.codeword_len());
249            Self::open_merkle_trees_for_column(transcript, commit_hint, column_idx)?;
250        }
251
252        Ok(eval)
253    }
254
255    /// See [`Self::prove`] for details.
256    #[inline(always)]
257    pub fn prove_single<F, const CHECK_FOR_OVERFLOW: bool>(
258        transcript: &mut PcsProverTranscript,
259        pp: &ZipPlusParams<Zt, Lc>,
260        poly: &DenseMultilinearExtension<Zt::Eval>,
261        point: &[Zt::Pt],
262        commit_hint: &ZipPlusHint<Zt::Cw>,
263        field_cfg: &F::Config,
264    ) -> Result<F, ZipError>
265    where
266        F: PrimeField
267            + for<'a> FromWithConfig<&'a Zt::CombR>
268            + for<'a> FromWithConfig<&'a Zt::Chal>
269            + for<'a> FromWithConfig<&'a Zt::Pt>
270            + for<'a> MulByScalar<&'a F>
271            + FromRef<F>,
272        F::Inner: Transcribable,
273        F::Modulus: FromRef<Zt::Fmod> + Transcribable,
274    {
275        Self::prove::<F, CHECK_FOR_OVERFLOW>(
276            transcript,
277            pp,
278            std::slice::from_ref(poly),
279            point,
280            commit_hint,
281            field_cfg,
282        )
283    }
284
285    pub(super) fn open_merkle_trees_for_column(
286        transcript: &mut PcsProverTranscript,
287        commit_hint: &ZipPlusHint<Zt::Cw>,
288        column_idx: usize,
289    ) -> Result<(), ZipError> {
290        for cw_matrix in &commit_hint.cw_matrices {
291            let column_values = cw_matrix.as_rows().map(|row| &row[column_idx]);
292            transcript.write_const_many_iter(column_values, cw_matrix.num_rows)?;
293        }
294
295        let merkle_proof = commit_hint
296            .merkle_tree
297            .prove(column_idx)
298            .map_err(|_| ZipError::InvalidPcsOpen("Failed to open merkle tree".into()))?;
299        transcript
300            .write_merkle_proof(&merkle_proof)
301            .map_err(|_| ZipError::InvalidPcsOpen("Failed to write a merkle tree proof".into()))?;
302
303        Ok(())
304    }
305}
306
#[cfg(test)]
// Test arithmetic and casts work on small, fixed sizes (num_vars = 10, batch
// sizes <= 5), so the overflow/truncation lints are silenced for this module.
#[allow(
    clippy::arithmetic_side_effects,
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap
)]
mod tests {
    use crate::{
        code::iprs::IprsCode,
        merkle::MerkleTree,
        pcs::{
            structs::{ZipPlus, ZipPlusHint},
            test_utils::*,
        },
        pcs_transcript::PcsProverTranscript,
    };
    use crypto_bigint::U64;
    use crypto_primitives::{
        IntoWithConfig, crypto_bigint_boxed_monty::BoxedMontyField, crypto_bigint_int::Int,
    };
    use num_traits::{ConstOne, Zero};
    use zinc_poly::mle::DenseMultilinearExtension;
    use zinc_utils::{CHECKED, from_ref::FromRef};

    const INT_LIMBS: usize = U64::LIMBS;

    const N: usize = INT_LIMBS;
    const K: usize = INT_LIMBS * 4;
    const M: usize = INT_LIMBS * 8;
    const DEGREE_PLUS_ONE: usize = 3;

    type F = BoxedMontyField;

    type Zt = TestZipTypes<N, K, M>;
    type C = IprsCode<Zt, TestIprsConfig, REP_FACTOR, CHECKED>;

    type PolyZt = TestBinPolyZipTypes<K, M, DEGREE_PLUS_ONE>;
    type PolyC = IprsCode<PolyZt, TestIprsConfig, REP_FACTOR, CHECKED>;

    type TestZip = ZipPlus<Zt, C>;
    type TestPolyZip = ZipPlus<PolyZt, PolyC>;

    /// Deterministic evaluation point `[2, 3, …, num_vars + 1]`.
    fn test_point(num_vars: usize) -> Vec<Int<INT_LIMBS>> {
        (0..num_vars).map(|i| Int::from(i as i32 + 2)).collect()
    }

    #[test]
    fn prove_succeeds_for_single_poly() {
        let num_vars = 10;
        let (pp, poly) = setup_test_params(num_vars);
        let (hint, comm) = TestZip::commit_single(&pp, &poly).unwrap();
        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result = TestZip::prove_single::<F, CHECKED>(
            &mut transcript,
            &pp,
            &poly,
            &point,
            &hint,
            &field_cfg,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn prove_succeeds_for_poly_type() {
        let num_vars = 10;
        let (pp, poly) = setup_poly_test_params(num_vars);
        let (hint, comm) = TestPolyZip::commit_single(&pp, &poly).unwrap();
        let point: Vec<i128> = (0..num_vars).map(|i| i as i128 + 2).collect();

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        // NOTE(review): the field config is derived via `Zt`, not `PolyZt`,
        // even though this test exercises `TestPolyZip` — confirm intended.
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result = TestPolyZip::prove_single::<F, CHECKED>(
            &mut transcript,
            &pp,
            &poly,
            &point,
            &hint,
            &field_cfg,
        );
        assert!(result.is_ok());
    }

    /// The prover does not check codeword consistency, so proving with a
    /// corrupted codeword (and a Merkle tree rebuilt over it) still succeeds;
    /// detecting the corruption is left to the verifier.
    #[test]
    fn prove_succeeds_with_corrupted_codeword() {
        let num_vars = 10;
        let (pp, poly) = setup_test_params(num_vars);
        let (mut hint, comm) = TestZip::commit_single(&pp, &poly).unwrap();

        {
            let mut rows = hint.cw_matrices[0].to_rows_slices_mut();
            assert!(!rows.is_empty());
            rows[0][0] += Int::ONE;
        }

        let corrupted_tree = {
            let all_rows: Vec<&[_]> = hint.cw_matrices.iter().flat_map(|m| m.as_rows()).collect();
            MerkleTree::new(&all_rows)
        };
        let corrupted_hint = ZipPlusHint::new(hint.cw_matrices, corrupted_tree);

        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result = TestZip::prove_single::<F, CHECKED>(
            &mut transcript,
            &pp,
            &poly,
            &point,
            &corrupted_hint,
            &field_cfg,
        );
        assert!(result.is_ok());
    }

    /// A polynomial with `num_vars + 1` variables exceeds `pp` and must be
    /// rejected.
    #[test]
    fn prove_rejects_oversized_polynomial() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let oversized_poly: DenseMultilinearExtension<_> =
            (0..1 << (num_vars + 1)).map(Int::from).collect();

        let (hint, comm) =
            TestZip::commit_single(&pp, &setup_test_params::<N, K, M>(num_vars).1).unwrap();

        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result = TestZip::prove_single::<F, CHECKED>(
            &mut transcript,
            &pp,
            &oversized_poly,
            &point,
            &hint,
            &field_cfg,
        );
        assert!(result.is_err());
    }

    /// For TestZipTypes (degree_bound = 0), alphas = [1] so prove eval
    /// equals poly(point) lifted to F.
    #[test]
    fn prove_returns_correct_evaluation() {
        let num_vars = 10;
        let (pp, poly) = setup_test_params(num_vars);
        let (hint, comm) = TestZip::commit_single(&pp, &poly).unwrap();
        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let eval_f = TestZip::prove_single::<F, CHECKED>(
            &mut transcript,
            &pp,
            &poly,
            &point,
            &hint,
            &field_cfg,
        )
        .unwrap();

        let poly_wide: DenseMultilinearExtension<Int<M>> =
            poly.evaluations.iter().map(Int::from_ref).collect();
        let expected_int = poly_wide.evaluate(&point, Zero::zero()).unwrap();
        let expected_f: F = (&expected_int).into_with_cfg(&field_cfg);

        assert_eq!(eval_f, expected_f);
    }

    /// Builds `batch_size` polynomials with `num_vars` variables each.
    /// Polynomial `b` has the consecutive evaluations
    /// `b * 2^num_vars + 1 ..= (b + 1) * 2^num_vars`, so every poly differs.
    fn make_batch_polys(
        num_vars: usize,
        batch_size: usize,
    ) -> Vec<DenseMultilinearExtension<Int<INT_LIMBS>>> {
        let poly_size = 1 << num_vars;
        (0..batch_size)
            .map(|b| {
                let base = (b * poly_size) as i32;
                (base + 1..=base + poly_size as i32)
                    .map(Int::from)
                    .collect()
            })
            .collect()
    }

    #[test]
    fn prove_succeeds_for_batch() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let polys = make_batch_polys(num_vars, 2);

        let (hint, comm) = TestZip::commit(&pp, &polys).unwrap();
        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result =
            TestZip::prove::<F, CHECKED>(&mut transcript, &pp, &polys, &point, &hint, &field_cfg);
        assert!(result.is_ok())
    }

    #[test]
    fn prove_succeeds_for_batch_5() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let polys = make_batch_polys(num_vars, 5);

        let (hint, comm) = TestZip::commit(&pp, &polys).unwrap();
        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result =
            TestZip::prove::<F, CHECKED>(&mut transcript, &pp, &polys, &point, &hint, &field_cfg);
        assert!(result.is_ok())
    }

    /// Batch counterpart of `prove_succeeds_with_corrupted_codeword`: the
    /// prover still succeeds when one poly's codeword is corrupted.
    #[test]
    fn prove_with_corrupted_codeword_for_batch() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let polys = make_batch_polys(num_vars, 2);

        let (mut hint, comm) = TestZip::commit(&pp, &polys).unwrap();

        hint.cw_matrices[0].to_rows_slices_mut()[0][0] += Int::ONE;

        let corrupted_tree = {
            let all_rows: Vec<&[_]> = hint.cw_matrices.iter().flat_map(|m| m.as_rows()).collect();
            MerkleTree::new(&all_rows)
        };
        let corrupted_hint = ZipPlusHint::new(hint.cw_matrices, corrupted_tree);

        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result = TestZip::prove::<F, CHECKED>(
            &mut transcript,
            &pp,
            &polys,
            &point,
            &corrupted_hint,
            &field_cfg,
        );
        assert!(result.is_ok());
    }

    /// A batch in which one polynomial has more variables than `pp` supports
    /// (`num_vars + 1`) must be rejected, even though the other polynomial
    /// has the expected size.
    #[test]
    fn prove_rejects_oversized_polynomial_in_batch() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let oversized: DenseMultilinearExtension<_> =
            (0..1 << (num_vars + 1)).map(Int::from).collect();
        let mut polys = make_batch_polys(num_vars, 1);
        polys.push(oversized);

        let (hint, comm) = TestZip::commit(&pp, &make_batch_polys(num_vars, 2)).unwrap();

        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result =
            TestZip::prove::<F, CHECKED>(&mut transcript, &pp, &polys, &point, &hint, &field_cfg);
        assert!(result.is_err());
    }

    /// A batch whose polynomials do not match `pp.num_vars` (4 and 5
    /// variables against `pp.num_vars == 10`) makes `prove` return an error.
    #[test]
    fn prove_rejects_batch_with_mismatched_num_vars() {
        let num_vars = 10;
        let (pp, _) = setup_test_params(num_vars);
        let five_vars: DenseMultilinearExtension<_> = (0..1 << 5).map(Int::from).collect();
        let four_vars: DenseMultilinearExtension<_> = (1..=16).map(Int::from).collect();
        let polys = vec![four_vars, five_vars];

        let (hint, comm) = TestZip::commit(&pp, &make_batch_polys(num_vars, 2)).unwrap();

        let point = test_point(num_vars);

        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        let field_cfg = get_field_cfg::<Zt, F>(&mut transcript.fs_transcript);

        let result =
            TestZip::prove::<F, CHECKED>(&mut transcript, &pp, &polys, &point, &hint, &field_cfg);
        assert!(result.is_err());
    }
}