1#[cfg(feature = "parallel")]
31use rayon::prelude::*;
32
33use crate::{
34 shift_predicate::eval_shift_predicate,
35 sumcheck::{MLSumcheck, SumCheckError, SumcheckProof, verifier::Subclaim as SumcheckSubclaim},
36};
37use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
38use num_traits::Zero;
39use std::marker::PhantomData;
40use thiserror::Error;
41use zinc_poly::{
42 mle::DenseMultilinearExtension,
43 utils::{ArithErrors, build_eq_x_r_inner, build_next_c_r_mle},
44};
45use zinc_transcript::{
46 delegate_transcribable,
47 traits::{ConstTranscribable, Transcript},
48};
49use zinc_uair::ShiftSpec;
50use zinc_utils::{cfg_into_iter, inner_transparent_field::InnerTransparentField};
51
52#[derive(Clone, Debug, PartialEq, Eq)]
62pub struct Proof<F: PrimeField> {
63 pub sumcheck_proof: SumcheckProof<F>,
65}
66
// Transcript encoding for `Proof`. Judging by the macro name, this presumably
// implements the transcribable traits by delegating to the single
// `sumcheck_proof` field; see the macro definition in `zinc_transcript` to
// confirm.
delegate_transcribable!(Proof<F> { sumcheck_proof: SumcheckProof<F> }
    where F: PrimeField, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable);
69
/// State kept by the prover after `MultipointEval::prove_as_subprotocol`.
pub struct ProverState<F: PrimeField> {
    /// Sumcheck randomness `r_0`. The trace columns are opened at this point to
    /// produce the `open_evals` consumed by `MultipointEval::verify_subclaim`
    /// (this file's tests evaluate every trace MLE here).
    pub eval_point: Vec<F>,
}
75
76#[derive(Clone, Debug)]
83pub struct Subclaim<F: PrimeField> {
84 pub sumcheck_subclaim: SumcheckSubclaim<F>,
88 pub gammas: Vec<F>,
90 pub alphas: Vec<F>,
92 pub eq_at_r0: F,
94 pub shifts_at_r0: Vec<F>,
97}
98
/// Entry point for the multi-point evaluation subprotocol.
///
/// Zero-sized marker type: `PhantomData<F>` only fixes the field type, and the
/// prover/verifier are exposed as associated functions (no `self`).
pub struct MultipointEval<F>(PhantomData<F>);
104
105impl<F> MultipointEval<F>
106where
107 F: InnerTransparentField + FromPrimitiveWithConfig + Send + Sync + 'static,
108 F::Inner: ConstTranscribable + Zero + Default + Send + Sync,
109 F::Modulus: ConstTranscribable,
110{
111 #[allow(clippy::arithmetic_side_effects)]
119 pub fn prove_as_subprotocol(
120 transcript: &mut impl Transcript,
121 trace_mles: &[DenseMultilinearExtension<F::Inner>],
122 eval_point: &[F],
123 up_evals: &[F],
124 down_evals: &[F],
125 shifts: &[ShiftSpec],
126 field_cfg: &F::Config,
127 ) -> Result<(Proof<F>, ProverState<F>), MultipointEvalError<F>> {
128 let num_cols = trace_mles.len();
129 let num_down_cols = shifts.len();
130 let num_vars = eval_point.len();
131 let zero = F::zero_with_cfg(field_cfg);
132 let zero_inner = zero.inner();
133
134 let alphas: Vec<F> = transcript.get_field_challenges(num_down_cols, field_cfg);
137 let gammas: Vec<F> = transcript.get_field_challenges(num_cols, field_cfg);
138
139 let eq_r = build_eq_x_r_inner(eval_point, field_cfg)?;
143 let (next_mles, down_cols): (Vec<_>, Vec<_>) = shifts
144 .iter()
145 .map(|spec| {
146 let next = build_next_c_r_mle(eval_point, spec.shift_amount(), field_cfg)?;
147 let col = trace_mles[spec.source_col()].clone();
148 Ok((next, col))
149 })
150 .collect::<Result<Vec<_>, ArithErrors>>()?
151 .into_iter()
152 .unzip();
153
154 let precombined = {
156 let evaluations: Vec<_> = cfg_into_iter!(0..1 << num_vars)
157 .map(|b| {
158 gammas
159 .iter()
160 .enumerate()
161 .fold(zero.clone(), |acc, (i, gamma)| {
162 let eval_f = F::new_unchecked_with_cfg(
163 trace_mles[i].evaluations[b].clone(),
164 field_cfg,
165 );
166 acc + gamma.clone() * eval_f
167 })
168 .into_inner()
169 })
170 .collect();
171 DenseMultilinearExtension::from_evaluations_vec(
172 num_vars,
173 evaluations,
174 zero_inner.clone(),
175 )
176 };
177
178 let mut mles = Vec::with_capacity(2 + 2 * num_down_cols);
180 mles.push(eq_r);
181 mles.extend(next_mles);
182 mles.push(precombined);
183 mles.extend(down_cols);
184
185 let (sumcheck_proof, sumcheck_prover_state) = MLSumcheck::prove_as_subprotocol(
190 transcript,
191 mles,
192 num_vars,
193 2,
194 |mle_values: &[F]| {
195 let eq_val = &mle_values[0];
196 let precombined = &mle_values[num_down_cols + 1];
197 alphas
198 .iter()
199 .enumerate()
200 .fold(eq_val.clone() * precombined, |acc, (i, alpha)| {
201 let next = &mle_values[1 + i];
202 let down_col = &mle_values[num_down_cols + 2 + i];
203 acc + alpha.clone() * next * down_col
204 })
205 },
206 field_cfg,
207 );
208
209 debug_assert_eq!(
211 sumcheck_proof.claimed_sum,
212 compute_expected_sum(up_evals, down_evals, &gammas, &alphas, zero)
213 );
214
215 Ok((
216 Proof { sumcheck_proof },
217 ProverState {
218 eval_point: sumcheck_prover_state.randomness,
219 },
220 ))
221 }
222
223 #[allow(clippy::arithmetic_side_effects, clippy::too_many_arguments)]
232 pub fn verify_as_subprotocol(
233 transcript: &mut impl Transcript,
234 proof: Proof<F>,
235 eval_point: &[F],
236 up_evals: &[F],
237 down_evals: &[F],
238 shifts: &[ShiftSpec],
239 num_vars: usize,
240 field_cfg: &F::Config,
241 ) -> Result<Subclaim<F>, MultipointEvalError<F>> {
242 let num_cols = up_evals.len();
243 let num_down_cols = shifts.len();
244 let zero = F::zero_with_cfg(field_cfg);
245 let one = F::one_with_cfg(field_cfg);
246
247 let alphas: Vec<F> = transcript.get_field_challenges(num_down_cols, field_cfg);
249 let gammas: Vec<F> = transcript.get_field_challenges(num_cols, field_cfg);
250
251 let expected_sum: F =
253 compute_expected_sum(up_evals, down_evals, &gammas, &alphas, zero.clone());
254
255 if proof.sumcheck_proof.claimed_sum != expected_sum {
256 return Err(MultipointEvalError::WrongSumcheckSum {
257 got: proof.sumcheck_proof.claimed_sum.clone(),
258 expected: expected_sum,
259 });
260 }
261
262 let sumcheck_subclaim = MLSumcheck::verify_as_subprotocol(
264 transcript,
265 num_vars,
266 2,
267 &proof.sumcheck_proof,
268 field_cfg,
269 )?;
270
271 let r_0 = &sumcheck_subclaim.point;
272
273 let eq_at_r0 = zinc_poly::utils::eq_eval(r_0, eval_point, one)?;
275 let shifts_at_r0: Vec<F> = shifts
276 .iter()
277 .map(|spec| eval_shift_predicate(eval_point, r_0, spec.shift_amount(), field_cfg))
278 .collect();
279
280 Ok(Subclaim {
281 sumcheck_subclaim,
282 gammas,
283 alphas,
284 eq_at_r0,
285 shifts_at_r0,
286 })
287 }
288
289 #[allow(clippy::arithmetic_side_effects)]
297 pub fn verify_subclaim(
298 subclaim: &Subclaim<F>,
299 open_evals: &[F],
300 shifts: &[ShiftSpec],
301 field_cfg: &F::Config,
302 ) -> Result<(), MultipointEvalError<F>> {
303 let num_cols = subclaim.gammas.len();
304
305 if open_evals.len() != num_cols {
306 return Err(MultipointEvalError::WrongOpenEvalsNumber {
307 got: open_evals.len(),
308 expected: num_cols,
309 });
310 }
311
312 let zero = F::zero_with_cfg(field_cfg);
313
314 let batched_up: F = subclaim
315 .gammas
316 .iter()
317 .zip(open_evals.iter())
318 .fold(zero.clone(), |acc, (gamma, eval)| {
319 acc + gamma.clone() * eval
320 });
321
322 let batched_down: F = subclaim
326 .alphas
327 .iter()
328 .enumerate()
329 .zip(subclaim.shifts_at_r0.iter())
330 .fold(zero, |acc, ((k, alpha), shift_at_r0)| {
331 let src_col = shifts[k].source_col();
332 acc + alpha.clone() * shift_at_r0 * &open_evals[src_col]
333 });
334
335 let expected_evaluation = subclaim.eq_at_r0.clone() * &batched_up + batched_down;
336
337 if expected_evaluation != subclaim.sumcheck_subclaim.expected_evaluation {
338 return Err(MultipointEvalError::ClaimMismatch {
339 got: subclaim.sumcheck_subclaim.expected_evaluation.clone(),
340 expected: expected_evaluation,
341 });
342 }
343
344 Ok(())
345 }
346}
347
348fn compute_expected_sum<F: PrimeField>(
351 up_evals: &[F],
352 down_evals: &[F],
353 gammas: &[F],
354 alphas: &[F],
355 zero: F,
356) -> F {
357 let up_sum = gammas
358 .iter()
359 .zip(up_evals.iter())
360 .fold(zero, |acc, (gamma, up)| acc + gamma.clone() * up);
361
362 alphas
363 .iter()
364 .zip(down_evals.iter())
365 .fold(up_sum, |acc, (alpha, down)| acc + alpha.clone() * down)
366}
367
368#[derive(Debug, Error)]
373pub enum MultipointEvalError<F: PrimeField> {
374 #[error("wrong number of open evaluations: got {got}, expected {expected}")]
375 WrongOpenEvalsNumber { got: usize, expected: usize },
376 #[error("wrong sumcheck claimed sum: got {got}, expected {expected}")]
377 WrongSumcheckSum { got: F, expected: F },
378 #[error("multi-point eval claim mismatch: got {got}, expected {expected}")]
379 ClaimMismatch { got: F, expected: F },
380 #[error("sumcheck error: {0}")]
381 SumcheckError(#[from] SumCheckError<F>),
382 #[error("arithmetic error: {0}")]
383 ArithError(#[from] ArithErrors),
384}
385
#[cfg(test)]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss
)]
mod tests {
    use super::*;
    use crypto_bigint::{U128, const_monty_params};
    use crypto_primitives::crypto_bigint_const_monty::ConstMontyField;
    use num_traits::{ConstOne, ConstZero};
    use zinc_poly::mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig};
    use zinc_transcript::Blake3Transcript;

    const_monty_params!(Params, U128, "00000000b933426489189cb5b47d567f");
    type F = ConstMontyField<Params, { U128::LIMBS }>;

    /// Public inputs shared by prover and verifier.
    #[derive(Clone)]
    struct SharedSubprotocolInput {
        eval_point: Vec<F>,
        up_evals: Vec<F>,
        down_evals: Vec<F>,
        shifts: Vec<ShiftSpec>,
        num_vars: usize,
    }

    /// Everything the prover sends: the proof and the column openings at `r_0`.
    #[derive(Clone)]
    struct ProverMessage {
        proof: Proof<F>,
        open_evals: Vec<F>,
    }

    /// Fresh transcript with a fixed prefix, so prover and verifier agree.
    fn make_transcript() -> Blake3Transcript {
        let mut t = Blake3Transcript::default();
        t.absorb_slice(b"Lorem ipsum");
        t
    }

    /// Builds `num_cols` columns over `num_vars` variables with distinct
    /// entries (column `col`, row `i` holds `col * 2^num_vars + i + 1`),
    /// together with the matching honest public input. Each down evaluation
    /// is the source column shifted by `c` rows (zero-padded at the end),
    /// evaluated at `eval_point`.
    fn build_trace(
        num_vars: usize,
        num_cols: usize,
        shifts: &[ShiftSpec],
    ) -> (
        Vec<DenseMultilinearExtension<<F as crypto_primitives::Field>::Inner>>,
        SharedSubprotocolInput,
    ) {
        let n = 1usize << num_vars;
        let zero_inner = F::ZERO.into_inner();

        let trace_mles: Vec<DenseMultilinearExtension<_>> = (0..num_cols)
            .map(|col| {
                let evals: Vec<_> = (0..n)
                    .map(|i| F::from((col * n + i + 1) as u32).into_inner())
                    .collect();
                DenseMultilinearExtension::from_evaluations_vec(num_vars, evals, zero_inner)
            })
            .collect();

        let eval_point: Vec<F> = (0..num_vars).map(|i| F::from((i + 7) as u32)).collect();

        let up_evals: Vec<F> = trace_mles
            .iter()
            .map(|mle| mle.clone().evaluate_with_config(&eval_point, &()).unwrap())
            .collect();

        let down_evals: Vec<F> = shifts
            .iter()
            .map(|spec| {
                let mle = &trace_mles[spec.source_col()];
                let c = spec.shift_amount();
                let mut shifted = mle.evaluations[c..].to_vec();
                shifted.extend(vec![zero_inner; c]);
                let shifted_mle =
                    DenseMultilinearExtension::from_evaluations_vec(num_vars, shifted, zero_inner);
                shifted_mle.evaluate_with_config(&eval_point, &()).unwrap()
            })
            .collect();

        let public = SharedSubprotocolInput {
            eval_point,
            up_evals,
            down_evals,
            shifts: shifts.to_vec(),
            num_vars,
        };
        (trace_mles, public)
    }

    /// Runs the prover and opens every trace column at the sumcheck point.
    fn run_prover(
        trace_mles: &[DenseMultilinearExtension<<F as crypto_primitives::Field>::Inner>],
        public: &SharedSubprotocolInput,
    ) -> ProverMessage {
        let mut transcript = make_transcript();
        let (proof, prover_state) = MultipointEval::<F>::prove_as_subprotocol(
            &mut transcript,
            trace_mles,
            &public.eval_point,
            &public.up_evals,
            &public.down_evals,
            &public.shifts,
            &(),
        )
        .expect("prover should succeed");

        let r_0 = &prover_state.eval_point;
        let open_evals: Vec<F> = trace_mles
            .iter()
            .map(|mle| mle.clone().evaluate_with_config(r_0, &()).unwrap())
            .collect();

        ProverMessage { proof, open_evals }
    }

    /// Runs both verifier stages and returns the subclaim on success.
    fn run_verifier(
        public: &SharedSubprotocolInput,
        msg: &ProverMessage,
    ) -> Result<Subclaim<F>, MultipointEvalError<F>> {
        let subclaim = MultipointEval::<F>::verify_as_subprotocol(
            &mut make_transcript(),
            msg.proof.clone(),
            &public.eval_point,
            &public.up_evals,
            &public.down_evals,
            &public.shifts,
            public.num_vars,
            &(),
        )?;

        MultipointEval::<F>::verify_subclaim(&subclaim, &msg.open_evals, &public.shifts, &())?;

        Ok(subclaim)
    }

    /// Builds an honest trace and runs the honest prover on it.
    fn honest_interaction(
        num_vars: usize,
        num_cols: usize,
        shifts: &[ShiftSpec],
    ) -> (SharedSubprotocolInput, ProverMessage) {
        let (trace, public) = build_trace(num_vars, num_cols, shifts);
        let msg = run_prover(&trace, &public);
        (public, msg)
    }

    /// Shifts every one of the first `num_cols` columns by one row.
    fn all_shift_by_1(num_cols: usize) -> Vec<ShiftSpec> {
        (0..num_cols).map(|i| ShiftSpec::new(i, 1)).collect()
    }

    /// Column 0 shifted by 1 and column 1 shifted by 3.
    fn mixed_shifts() -> Vec<ShiftSpec> {
        vec![ShiftSpec::new(0, 1), ShiftSpec::new(1, 3)]
    }

    /// Corrupts the first sumcheck round message and asserts that the verifier
    /// rejects, either inside the sumcheck or at the final consistency check.
    fn assert_tampered_sumcheck_rejected(num_vars: usize, num_cols: usize, shifts: &[ShiftSpec]) {
        let (public, mut msg) = honest_interaction(num_vars, num_cols, shifts);
        msg.proof.sumcheck_proof.messages[0].0.tail_evaluations[0] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(
                err,
                MultipointEvalError::SumcheckError(_) | MultipointEvalError::ClaimMismatch { .. }
            ),
            "expected sumcheck or consistency error, got {err:?}",
        );
    }

    #[test]
    fn expected_sum_is_gamma_alpha_combination() {
        let up = [F::from(2_u32), F::from(3_u32)];
        let down = [F::from(11_u32)];
        let gammas = [F::from(5_u32), F::from(7_u32)];
        let alphas = [F::from(13_u32)];
        // 5*2 + 7*3 + 13*11 = 174
        assert_eq!(
            compute_expected_sum(&up, &down, &gammas, &alphas, F::ZERO),
            F::from(174_u32)
        );
    }

    #[test]
    fn prover_is_deterministic() {
        let (trace, public) = build_trace(3, 2, &all_shift_by_1(2));
        let first = run_prover(&trace, &public);
        let second = run_prover(&trace, &public);
        assert_eq!(first.proof, second.proof);
        assert_eq!(first.open_evals, second.open_evals);
    }

    #[test]
    fn honest_prove_verify_single_column() {
        let shifts = all_shift_by_1(1);
        let (public, msg) = honest_interaction(4, 1, &shifts);
        run_verifier(&public, &msg).unwrap();
    }

    #[test]
    fn honest_prove_verify_many_columns() {
        let shifts = all_shift_by_1(10);
        let (public, msg) = honest_interaction(3, 10, &shifts);
        run_verifier(&public, &msg).unwrap();
    }

    #[test]
    fn honest_prove_verify_no_shifts() {
        let (public, msg) = honest_interaction(3, 3, &[]);
        run_verifier(&public, &msg).unwrap();
    }

    #[test]
    fn honest_prove_verify_mixed_shifts() {
        let (public, msg) = honest_interaction(4, 3, &mixed_shifts());
        let subclaim = run_verifier(&public, &msg).unwrap();
        assert_eq!(subclaim.gammas.len(), 3);
        assert_eq!(subclaim.alphas.len(), 2);
        assert_eq!(subclaim.shifts_at_r0.len(), 2);
    }

    #[test]
    fn honest_prove_verify_shift_by_3() {
        let shifts = vec![
            ShiftSpec::new(0, 3),
            ShiftSpec::new(1, 3),
            ShiftSpec::new(2, 3),
        ];
        let (public, msg) = honest_interaction(4, 3, &shifts);
        run_verifier(&public, &msg).unwrap();
    }

    #[test]
    fn honest_prove_verify_same_col_different_shifts() {
        let shifts = vec![ShiftSpec::new(0, 2), ShiftSpec::new(0, 5)];
        let (public, msg) = honest_interaction(4, 3, &shifts);
        run_verifier(&public, &msg).unwrap();
    }

    #[test]
    fn bad_down_eval_rejected_mixed_shifts() {
        let (mut public, msg) = honest_interaction(4, 3, &mixed_shifts());
        public.down_evals[0] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::WrongSumcheckSum { .. }),
            "expected WrongSumcheckSum, got {err:?}",
        );
    }

    #[test]
    fn wrong_open_evals_count() {
        let shifts = all_shift_by_1(3);
        let (public, msg) = honest_interaction(3, 3, &shifts);

        let mut msg_short = msg.clone();
        msg_short.open_evals.pop();

        let mut msg_long = msg;
        msg_long.open_evals.push(F::from(42_u32));

        for bad_msg in [&msg_short, &msg_long] {
            let err = run_verifier(&public, bad_msg).unwrap_err();
            assert!(
                matches!(err, MultipointEvalError::WrongOpenEvalsNumber {
                    got,
                    expected: 3,
                } if got == bad_msg.open_evals.len()),
                "expected WrongOpenEvalsNumber, got {err:?}",
            );
        }
    }

    #[test]
    fn wrong_claimed_sum_via_corrupted_up_evals() {
        let shifts = all_shift_by_1(3);
        let (mut public, msg) = honest_interaction(3, 3, &shifts);
        public.up_evals[0] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::WrongSumcheckSum { .. }),
            "expected WrongSumcheckSum, got {err:?}",
        );
    }

    #[test]
    fn wrong_claimed_sum_via_corrupted_down_evals() {
        let shifts = all_shift_by_1(3);
        let (mut public, msg) = honest_interaction(3, 3, &shifts);
        public.down_evals[1] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::WrongSumcheckSum { .. }),
            "expected WrongSumcheckSum, got {err:?}",
        );
    }

    #[test]
    fn wrong_open_eval_value() {
        let shifts = all_shift_by_1(3);
        let (public, mut msg) = honest_interaction(3, 3, &shifts);
        msg.open_evals[0] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::ClaimMismatch { .. }),
            "expected ClaimMismatch, got {err:?}",
        );
    }

    #[test]
    fn all_open_evals_zeroed() {
        let shifts = all_shift_by_1(3);
        let (public, mut msg) = honest_interaction(3, 3, &shifts);
        for e in &mut msg.open_evals {
            *e = F::ZERO;
        }
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::ClaimMismatch { .. }),
            "expected ClaimMismatch, got {err:?}",
        );
    }

    #[test]
    fn mixed_shifts_corrupted_up_eval() {
        let (mut public, msg) = honest_interaction(4, 3, &mixed_shifts());
        public.up_evals[2] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::WrongSumcheckSum { .. }),
            "expected WrongSumcheckSum, got {err:?}",
        );
    }

    #[test]
    fn mixed_shifts_wrong_open_eval() {
        let (public, mut msg) = honest_interaction(4, 3, &mixed_shifts());
        msg.open_evals[1] += F::ONE;
        let err = run_verifier(&public, &msg).unwrap_err();
        assert!(
            matches!(err, MultipointEvalError::ClaimMismatch { .. }),
            "expected ClaimMismatch, got {err:?}",
        );
    }

    #[test]
    fn mixed_shifts_tampered_sumcheck() {
        assert_tampered_sumcheck_rejected(4, 3, &mixed_shifts());
    }

    #[test]
    fn tampered_sumcheck_round_message() {
        assert_tampered_sumcheck_rejected(3, 3, &all_shift_by_1(3));
    }
}