1mod folder;
4mod structs;
5
6pub use structs::*;
7
8use crate::{
9 CombFn,
10 combined_poly_resolver::{
11 folder::ConstraintFolder,
12 structs::{Proof as CprProof, ProverState as CprProverState},
13 },
14 ideal_check,
15 sumcheck::{
16 SumCheckError, multi_degree::MultiDegreeSumcheckGroup,
17 prover::ProverState as SumcheckProverState,
18 },
19};
20use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
21use itertools::Itertools;
22use num_traits::Zero;
23#[cfg(feature = "parallel")]
24use rayon::prelude::*;
25use std::{collections::HashMap, marker::PhantomData, slice};
26use thiserror::Error;
27use zinc_poly::{
28 EvaluationError,
29 mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig},
30 univariate::dynamic::over_field::{DynamicPolyFInnerProduct, DynamicPolynomialF},
31 utils::{ArithErrors, build_eq_x_r_inner, eq_eval},
32};
33use zinc_transcript::traits::{ConstTranscribable, Transcript};
34use zinc_uair::{TraceRow, Uair, ideal::ImpossibleIdeal};
35use zinc_utils::{
36 UNCHECKED, add, cfg_iter, from_ref::FromRef, inner_product::InnerProduct,
37 inner_transparent_field::InnerTransparentField, powers,
38};
39
40pub struct CombinedPolyResolver<F: InnerTransparentField>(PhantomData<F>);
41
42impl<F: InnerTransparentField + FromPrimitiveWithConfig + Send + Sync> CombinedPolyResolver<F> {
43 #[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)]
71 pub fn prepare_sumcheck_group<U>(
72 transcript: &mut impl Transcript,
73 trace_matrix: Vec<DenseMultilinearExtension<F::Inner>>,
74 evaluation_point: &[F],
75 projected_scalars: &HashMap<U::Scalar, F>,
76 num_constraints: usize,
77 num_vars: usize,
78 max_degree: usize,
79 field_cfg: &F::Config,
80 ) -> Result<(MultiDegreeSumcheckGroup<F>, CprProverAncillary), CombinedPolyResolverError<F>>
81 where
82 F::Inner: ConstTranscribable + Send + Sync + Zero + Default,
83 F::Modulus: ConstTranscribable,
84 F: 'static,
85 U::Scalar: 'static,
86 U: Uair,
87 {
88 debug_assert_ne!(
89 num_vars, 1,
90 "The protocol is not needed when the number of variables is 1 :)"
91 );
92
93 let zero = F::zero_with_cfg(field_cfg);
94 let one = F::one_with_cfg(field_cfg);
95
96 let uair_sig = U::signature();
102 let zero_inner = zero.clone().into_inner();
103 let n = 1usize << num_vars;
104 let down: Vec<DenseMultilinearExtension<F::Inner>> = cfg_iter!(uair_sig.shifts())
105 .map(|spec| {
106 let mut evals = trace_matrix[spec.source_col()][spec.shift_amount()..].to_vec();
107 evals.resize(n, zero_inner.clone());
108 DenseMultilinearExtension {
109 evaluations: evals,
110 num_vars,
111 }
112 })
113 .collect();
114
115 let eq_r = build_eq_x_r_inner(evaluation_point, field_cfg)?;
116 let last_row_selector = DenseMultilinearExtension {
120 num_vars,
121 evaluations: {
122 let mut evals = vec![zero.inner().clone(); 1 << num_vars];
123 evals[(1 << num_vars) - 1] = one.inner().clone();
124 evals
125 },
126 };
127
128 let folding_challenge: F = transcript.get_field_challenge(field_cfg);
130
131 let folding_challenge_powers: Vec<F> =
132 powers(folding_challenge, one.clone(), num_constraints);
133
134 let num_cols = trace_matrix.len();
135 let num_down_cols = down.len();
136 let mles: Vec<DenseMultilinearExtension<F::Inner>> = {
137 let mut mles = Vec::with_capacity(2 + num_cols + num_down_cols);
138
139 mles.push(last_row_selector);
140 mles.push(eq_r);
141
142 mles.extend(trace_matrix);
143 mles.extend(down);
144
145 mles
146 };
147
148 let projected_scalars = projected_scalars.clone();
149 let comb_fn: CombFn<F> = Box::new(move |mle_values: &[F]| {
150 let uair_sig = U::signature();
151 let up_layout = uair_sig.total_cols().as_column_layout();
152 let down_layout = uair_sig.down_cols().as_column_layout();
153
154 let selector = &mle_values[0];
155 let eq_r = &mle_values[1];
156
157 let mut folder = ConstraintFolder::new(&folding_challenge_powers, &zero);
158
159 let project = |scalar: &U::Scalar| {
160 projected_scalars
161 .get(scalar)
162 .cloned()
163 .expect("all scalars should have been projected at this point")
164 };
165
166 U::constrain_general(
167 &mut folder,
168 TraceRow::from_slice_with_layout(&mle_values[2..num_cols + 2], up_layout),
169 TraceRow::from_slice_with_layout(&mle_values[num_cols + 2..], down_layout),
170 project,
171 |x, y| Some(project(y) * x),
172 ImpossibleIdeal::from_ref,
173 );
174
175 folder.folded_constraints * (one.clone() - selector) * eq_r
176 });
177
178 Ok((
179 MultiDegreeSumcheckGroup::new(max_degree + 2, mles, comb_fn),
180 CprProverAncillary {
181 num_cols,
182 num_down_cols,
183 num_vars,
184 },
185 ))
186 }
187
188 #[allow(clippy::arithmetic_side_effects)]
198 pub fn finalize_prover(
199 transcript: &mut impl Transcript,
200 sumcheck_prover_state: SumcheckProverState<F>,
201 ancillary: CprProverAncillary,
202 field_cfg: &F::Config,
203 ) -> Result<(CprProof<F>, CprProverState<F>), CombinedPolyResolverError<F>>
204 where
205 F::Inner: ConstTranscribable + Zero,
206 F::Modulus: ConstTranscribable,
207 {
208 debug_assert!(
214 sumcheck_prover_state
215 .mles
216 .iter()
217 .all(|mle| mle.num_vars == 1)
218 );
219
220 let last_sumcheck_challenge = sumcheck_prover_state
221 .randomness
222 .last()
223 .expect("sumcheck could not have had 0 rounds");
224
225 let mut mles = sumcheck_prover_state.mles;
226 let evals: Vec<F> = mles
227 .drain(2..)
228 .map(|mle| {
229 mle.evaluate_with_config(slice::from_ref(last_sumcheck_challenge), field_cfg)
230 })
231 .try_collect()?;
232
233 debug_assert_eq!(evals.len(), ancillary.num_cols + ancillary.num_down_cols);
234 let mut transcription_buf: Vec<u8> = vec![0; F::Inner::NUM_BYTES];
235 transcript.absorb_random_field_slice(&evals, &mut transcription_buf);
236 let (up_evals, down_evals) = (
237 evals[0..ancillary.num_cols].to_vec(),
238 evals[ancillary.num_cols..].to_vec(),
239 );
240 Ok((
241 CprProof {
242 up_evals,
243 down_evals,
244 },
245 CprProverState {
246 evaluation_point: sumcheck_prover_state.randomness,
247 },
248 ))
249 }
250
251 #[allow(clippy::too_many_arguments)]
268 pub fn prepare_verifier<U>(
269 transcript: &mut impl Transcript,
270 proof: &CprProof<F>,
271 claimed_sum: F,
272 ic_check_subclaim: &ideal_check::VerifierSubclaim<F>,
273 num_constraints: usize,
274 num_vars: usize,
275 projecting_element: &F,
276 field_cfg: &F::Config,
277 ) -> Result<CprVerifierAncillary<F>, CombinedPolyResolverError<F>>
278 where
279 F::Inner: ConstTranscribable,
280 F::Modulus: ConstTranscribable,
281 U: Uair,
282 {
283 let uair_sig = U::signature();
284 proof
285 .validate_evaluation_sizes(uair_sig.total_cols().cols(), uair_sig.down_cols().cols())?;
286
287 let zero = F::zero_with_cfg(field_cfg);
288 let one = F::one_with_cfg(field_cfg);
289
290 let projection_powers: Vec<F> = {
292 let max_coeffs_len = ic_check_subclaim
293 .values
294 .iter()
295 .map(|poly| poly.degree().map_or(0, |d| add!(d, 1)))
296 .max()
297 .unwrap_or(0)
298 .max(1);
299 powers(projecting_element.clone(), one.clone(), max_coeffs_len)
300 };
301
302 let folding_challenge: F = transcript.get_field_challenge(field_cfg);
303
304 let folding_challenge_powers: Vec<F> =
305 powers(folding_challenge, one.clone(), num_constraints);
306
307 let expected_sum = ic_check_subclaim
310 .values
311 .iter()
312 .zip(&folding_challenge_powers)
313 .map(|(claimed_value, random_coeff)| {
314 let deg = claimed_value.degree().map_or(0, |d| add!(d, 1));
315 DynamicPolyFInnerProduct::inner_product::<UNCHECKED>(
316 &claimed_value.coeffs[..deg],
317 &projection_powers[..deg],
318 zero.clone(),
319 )
320 .expect("inner product cannot fail here")
321 * random_coeff
322 })
323 .fold(zero.clone(), |acc, term| acc + term);
324
325 if claimed_sum != expected_sum {
326 return Err(CombinedPolyResolverError::WrongSumcheckSum {
327 got: claimed_sum,
328 expected: expected_sum,
329 });
330 }
331
332 Ok(CprVerifierAncillary {
333 folding_challenge_powers,
334 ic_evaluation_point: ic_check_subclaim.evaluation_point.clone(),
335 num_vars,
336 })
337 }
338
339 #[allow(clippy::too_many_arguments)]
356 pub fn finalize_verifier<U>(
357 transcript: &mut impl Transcript,
358 proof: CprProof<F>,
359 shared_point: Vec<F>,
360 expected_evaluation: F,
361 ancillary: CprVerifierAncillary<F>,
362 projected_scalars: &HashMap<U::Scalar, F>,
363 field_cfg: &F::Config,
364 ) -> Result<VerifierSubclaim<F>, CombinedPolyResolverError<F>>
365 where
366 F::Inner: ConstTranscribable,
367 F::Modulus: ConstTranscribable,
368 U: Uair,
369 {
370 let uair_sig = U::signature();
371 let down_layout = uair_sig.down_cols().as_column_layout();
372 let zero = F::zero_with_cfg(field_cfg);
373 let one = F::one_with_cfg(field_cfg);
374
375 let eq_r_value = eq_eval(&shared_point, &ancillary.ic_evaluation_point, one.clone())?;
376 let selector_value = eq_eval(
377 &shared_point,
378 &vec![one.clone(); ancillary.num_vars],
379 one.clone(),
380 )?;
381
382 let mut folder = ConstraintFolder::new(&ancillary.folding_challenge_powers, &zero);
383
384 let project = |scalar: &U::Scalar| {
385 projected_scalars
386 .get(scalar)
387 .cloned()
388 .expect("all scalars should have been projected at this point")
389 };
390
391 U::constrain_general(
392 &mut folder,
393 TraceRow::from_slice_with_layout(
394 &proof.up_evals,
395 uair_sig.total_cols().as_column_layout(),
396 ),
397 TraceRow::from_slice_with_layout(&proof.down_evals, down_layout),
398 project,
399 |x, y| Some(project(y) * x),
400 ImpossibleIdeal::from_ref,
401 );
402
403 let expected_claim_value = eq_r_value * (one - selector_value) * folder.folded_constraints;
404
405 if expected_claim_value != expected_evaluation {
406 return Err(CombinedPolyResolverError::ClaimValueDoesNotMatch {
407 got: expected_evaluation,
408 expected: expected_claim_value,
409 });
410 }
411
412 let mut transcription_buf: Vec<u8> = vec![0; F::Inner::NUM_BYTES];
413 transcript.absorb_random_field_slice(&proof.up_evals, &mut transcription_buf);
414 transcript.absorb_random_field_slice(&proof.down_evals, &mut transcription_buf);
415
416 Ok(VerifierSubclaim {
417 up_evals: proof.up_evals,
418 down_evals: proof.down_evals,
419 evaluation_point: shared_point,
420 })
421 }
422}
423
424#[derive(Debug, Error)]
425pub enum CombinedPolyResolverError<F: PrimeField> {
426 #[error("failed to build eq_r: {0}")]
427 EqrError(ArithErrors),
428 #[error("error evaluating MLE: {0}")]
429 MleEvaluationError(EvaluationError),
430 #[error("error projecting polynomial {0} by point {1}: {2}")]
431 ProjectionError(DynamicPolynomialF<F>, F, EvaluationError),
432 #[error("wrong trace columns evaluations number: got {got}, expected {expected}")]
433 WrongUpEvalsNumber { got: usize, expected: usize },
434 #[error("wrong shifted trace columns evaluations number: got {got}, expected {expected}")]
435 WrongDownEvalsNumber { got: usize, expected: usize },
436 #[error("sumcheck verification failed: {0}")]
437 SumcheckError(SumCheckError<F>),
438 #[error("wrong sumcheck claimed sum: received {got}, expected {expected}")]
439 WrongSumcheckSum { got: F, expected: F },
440 #[error("resulting claim value does not match: received {got}, expected {expected}")]
441 ClaimValueDoesNotMatch { got: F, expected: F },
442}
443
444impl<F: PrimeField> From<EvaluationError> for CombinedPolyResolverError<F> {
445 fn from(eval_error: EvaluationError) -> Self {
446 Self::MleEvaluationError(eval_error)
447 }
448}
449
450impl<F: PrimeField> From<ArithErrors> for CombinedPolyResolverError<F> {
451 fn from(arith_error: ArithErrors) -> Self {
452 Self::EqrError(arith_error)
453 }
454}
455
456impl<F: PrimeField> From<SumCheckError<F>> for CombinedPolyResolverError<F> {
457 fn from(sumcheck_error: SumCheckError<F>) -> Self {
458 Self::SumcheckError(sumcheck_error)
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ideal_check::IdealCheckProtocol,
        projections::{ProjectedTrace, evaluate_trace_to_column_mles, project_scalars_to_field},
        sumcheck::multi_degree::MultiDegreeSumcheck,
        test_utils::{LIMBS, run_ideal_check_prover_combined, test_config},
    };
    use crypto_primitives::{crypto_bigint_int::Int, crypto_bigint_monty::MontyField};
    use rand::rng;
    use zinc_poly::univariate::dense::DensePolynomial;
    use zinc_test_uair::{
        GenerateRandomTrace, TestUairNoMultiplication, TestUairSimpleMultiplication,
    };
    use zinc_transcript::Blake3Transcript;
    use zinc_uair::{
        constraint_counter::count_constraints,
        degree_counter::count_max_degree,
        ideal::{DegreeOneIdeal, Ideal, IdealCheck},
        ideal_collector::IdealOrZero,
    };

    /// Honest end-to-end run for UAIR `U` on a random trace with `num_vars` variables.
    ///
    /// The run covers these steps:
    /// 1. Ideal-check prover and verifier.
    /// 2. CPR sumcheck-group preparation.
    /// 3. Multi-degree sumcheck proving.
    /// 4. CPR prover finalization.
    /// 5. CPR verifier preparation and sumcheck verification.
    /// 6. CPR verifier finalization.
    ///
    /// Panics (with the underlying error) if any step rejects.
    /// `ideal_over_f_from_ref` maps each of `U`'s ideals to the ideal the ideal-check
    /// verifier uses over the field.
    fn test_successful_verification_generic<
        U,
        IdealOverF,
        IdealOverFFromRef,
        const DEGREE_PLUS_ONE: usize,
    >(
        num_vars: usize,
        ideal_over_f_from_ref: IdealOverFFromRef,
    ) where
        U: Uair<Scalar = DensePolynomial<Int<5>, DEGREE_PLUS_ONE>>
            + GenerateRandomTrace<DEGREE_PLUS_ONE, PolyCoeff = Int<5>, Int = Int<5>>
            + IdealCheckProtocol,
        IdealOverF: Ideal + IdealCheck<DynamicPolynomialF<MontyField<LIMBS>>>,
        IdealOverFFromRef: Fn(&IdealOrZero<U::Ideal>) -> IdealOverF,
    {
        let mut rng = rng();

        // Prover and verifier start from identical transcript states.
        let mut prover_transcript = Blake3Transcript::new();
        let mut verifier_transcript = prover_transcript.clone();

        let trace = U::generate_random_trace(num_vars, &mut rng);

        let (ic_proof, ic_prover_state, projected_scalars, projected_trace) =
            run_ideal_check_prover_combined::<U, DEGREE_PLUS_ONE>(
                num_vars,
                &trace,
                &mut prover_transcript,
            );

        let num_constraints = count_constraints::<U>();

        let ic_check_subclaim = U::verify_as_subprotocol(
            &mut verifier_transcript,
            ic_proof,
            num_constraints,
            num_vars,
            ideal_over_f_from_ref,
            &test_config(),
        )
        .expect("Verification failed");

        let max_degree = count_max_degree::<U>();

        // ---- Prover side ----
        let projecting_element: MontyField<LIMBS> =
            prover_transcript.get_field_challenge(&test_config());

        let projected_scalars =
            project_scalars_to_field(projected_scalars, &projecting_element).unwrap();

        let (cpr_group, cpr_ancillary) = CombinedPolyResolver::prepare_sumcheck_group::<U>(
            &mut prover_transcript,
            evaluate_trace_to_column_mles(
                &ProjectedTrace::RowMajor(projected_trace),
                &projecting_element,
            ),
            &ic_prover_state.evaluation_point,
            &projected_scalars,
            num_constraints,
            num_vars,
            max_degree,
            &test_config(),
        )
        .expect("CPR prepare failed");

        let (md_proof, states) = MultiDegreeSumcheck::prove_as_subprotocol(
            &mut prover_transcript,
            vec![cpr_group],
            num_vars,
            &test_config(),
        );

        let (proof, _) = CombinedPolyResolver::finalize_prover(
            &mut prover_transcript,
            states.into_iter().next().unwrap(),
            cpr_ancillary,
            &test_config(),
        )
        .expect("CPR finalize failed");

        // ---- Verifier side ----
        let projecting_element: MontyField<LIMBS> =
            verifier_transcript.get_field_challenge(&test_config());

        let cpr_verifier_ancillary = CombinedPolyResolver::prepare_verifier::<U>(
            &mut verifier_transcript,
            &proof,
            md_proof.claimed_sums()[0].clone(),
            &ic_check_subclaim,
            num_constraints,
            num_vars,
            &projecting_element,
            &test_config(),
        )
        .expect("CPR prepare_verifier failed");

        let md_subclaims = MultiDegreeSumcheck::verify_as_subprotocol(
            &mut verifier_transcript,
            num_vars,
            &md_proof,
            &test_config(),
        )
        .expect("MultiDegreeSumcheck verify failed");

        // `expect` (rather than `assert!(is_ok())`) reports the actual error on failure.
        let _verifier_subclaim = CombinedPolyResolver::finalize_verifier::<U>(
            &mut verifier_transcript,
            proof,
            md_subclaims.point().to_vec(),
            md_subclaims.expected_evaluations()[0].clone(),
            cpr_verifier_ancillary,
            &projected_scalars,
            &test_config(),
        )
        .expect("CPR finalize_verifier failed");
    }

    /// Honest proofs for a linear-only UAIR and for a UAIR with a multiplication
    /// constraint must both verify.
    #[test]
    fn test_successful_verification() {
        let field_cfg = test_config();

        let num_vars = 2;

        test_successful_verification_generic::<TestUairNoMultiplication<Int<5>>, _, _, 32>(
            num_vars,
            |ideal_over_ring| ideal_over_ring.map(|i| DegreeOneIdeal::from_with_cfg(i, &field_cfg)),
        );
        test_successful_verification_generic::<TestUairSimpleMultiplication<Int<5>>, _, _, 32>(
            num_vars,
            |_ideal_over_ring| IdealOrZero::<DegreeOneIdeal<_>>::zero(),
        );
    }
}