1use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
17use num_traits::Zero;
18#[cfg(feature = "parallel")]
19use rayon::prelude::*;
20use std::marker::PhantomData;
21use zinc_poly::mle::DenseMultilinearExtension;
22use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript};
23use zinc_utils::{
24 add, cfg_iter, cfg_iter_mut, inner_transparent_field::InnerTransparentField, mul,
25};
26
27use crate::CombFn;
28
29use super::{
30 SumCheckError,
31 prover::{
32 NatEvaluatedPolyWithoutConstant, ProverMsg as SumcheckProverMsg,
33 ProverState as SumcheckProverState,
34 },
35 verifier::VerifierState,
36};
37
/// One degree group of a [`MultiDegreeSumcheck`] instance.
///
/// A group holds a list of multilinear polynomials, the function that combines
/// their values, and the degree of the resulting per-round univariate
/// polynomial. Each group yields its own claimed sum in the batched proof.
pub struct MultiDegreeSumcheckGroup<F: PrimeField> {
    /// Degree of this group's round polynomial. It is handed to the prover state
    /// as its maximum degree, and every serialized round message of the group
    /// holds `degree` field elements.
    degree: usize,
    /// The group's multilinear polynomials, stored over the inner representation
    /// `F::Inner`.
    poly: Vec<DenseMultilinearExtension<F::Inner>>,
    /// Combines the polynomials' values into the summand. The tests in this
    /// module index its argument slice in the same order as `poly`.
    comb_fn: CombFn<F>,
}
49
50impl<F: PrimeField> MultiDegreeSumcheckGroup<F> {
51 pub fn new(
52 degree: usize,
53 poly: Vec<DenseMultilinearExtension<F::Inner>>,
54 comb_fn: CombFn<F>,
55 ) -> Self {
56 Self {
57 degree,
58 poly,
59 comb_fn,
60 }
61 }
62}
63
/// Proof produced by [`MultiDegreeSumcheck::prove_as_subprotocol`] for several
/// degree groups that share one verifier challenge per round.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MultiDegreeSumcheckProof<F> {
    /// Round messages indexed as `[group][round]`.
    group_messages: Vec<Vec<SumcheckProverMsg<F>>>,
    /// Claimed sum of each group, in group order.
    claimed_sums: Vec<F>,
    /// Round-polynomial degree of each group, in group order.
    degrees: Vec<usize>,
}
77
78impl<F> MultiDegreeSumcheckProof<F> {
79 pub fn claimed_sums(&self) -> &[F] {
82 &self.claimed_sums
83 }
84}
85
86impl<F: PrimeField> GenTranscribable for MultiDegreeSumcheckProof<F>
87where
88 F::Inner: ConstTranscribable,
89 F::Modulus: ConstTranscribable,
90{
91 fn read_transcription_bytes_exact(bytes: &[u8]) -> Self {
92 let mod_size = F::Modulus::NUM_BYTES;
93 let cfg = zinc_transcript::read_field_cfg::<F>(&bytes[..mod_size]);
94 let bytes = &bytes[mod_size..];
95
96 let (num_groups, bytes) = u32::read_transcription_bytes_subset(bytes);
97 let num_groups = usize::try_from(num_groups).expect("group count must fit into usize");
98
99 let (num_vars, mut bytes) = u32::read_transcription_bytes_subset(bytes);
100 let num_vars = usize::try_from(num_vars).expect("num_vars must fit into usize");
101
102 let mut degrees = Vec::with_capacity(num_groups);
103 for _ in 0..num_groups {
104 let (deg, rest) = u32::read_transcription_bytes_subset(bytes);
105 degrees.push(usize::try_from(deg).expect("degree must fit into usize"));
106 bytes = rest;
107 }
108
109 let mut group_messages = Vec::with_capacity(num_groups);
110 for ° in °rees {
111 let msg_bytes = mul!(deg, F::Inner::NUM_BYTES);
112 let mut msgs = Vec::with_capacity(num_vars);
113 for _ in 0..num_vars {
114 let tail_evaluations =
115 zinc_transcript::read_field_vec_with_cfg(&bytes[..msg_bytes], &cfg);
116 msgs.push(SumcheckProverMsg(NatEvaluatedPolyWithoutConstant {
117 tail_evaluations,
118 }));
119 bytes = &bytes[msg_bytes..];
120 }
121 group_messages.push(msgs);
122 }
123
124 let mut claimed_sums = Vec::with_capacity(num_groups);
125 for _ in 0..num_groups {
126 let cs = F::Inner::read_transcription_bytes_exact(&bytes[..F::Inner::NUM_BYTES]);
127 let cs = F::new_unchecked_with_cfg(cs, &cfg);
128 claimed_sums.push(cs);
129 bytes = &bytes[F::Inner::NUM_BYTES..];
130 }
131
132 Self {
133 group_messages,
134 claimed_sums,
135 degrees,
136 }
137 }
138
139 fn write_transcription_bytes_exact(&self, mut buf: &mut [u8]) {
140 buf = zinc_transcript::append_field_cfg::<F>(buf, &self.claimed_sums[0].modulus());
141
142 let num_groups =
143 u32::try_from(self.group_messages.len()).expect("num groups must fit into u32");
144 num_groups.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
145 buf = &mut buf[u32::NUM_BYTES..];
146
147 let num_vars =
149 u32::try_from(self.group_messages[0].len()).expect("num_vars must fit into u32");
150 num_vars.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
151 buf = &mut buf[u32::NUM_BYTES..];
152
153 for ° in &self.degrees {
154 let deg = u32::try_from(deg).expect("degree must fit into u32");
155 deg.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
156 buf = &mut buf[u32::NUM_BYTES..];
157 }
158
159 for group in &self.group_messages {
160 for msg in group {
161 buf = zinc_transcript::append_field_vec_inner(buf, &msg.0.tail_evaluations);
162 }
163 }
164
165 for cs in &self.claimed_sums {
166 cs.inner()
167 .write_transcription_bytes_exact(&mut buf[..F::Inner::NUM_BYTES]);
168 buf = &mut buf[F::Inner::NUM_BYTES..];
169 }
170 }
171}
172
173impl<F: PrimeField> Transcribable for MultiDegreeSumcheckProof<F>
174where
175 F::Inner: ConstTranscribable,
176 F::Modulus: ConstTranscribable,
177{
178 fn get_num_bytes(&self) -> usize {
179 let num_groups = self.group_messages.len();
180 let num_vars = self.group_messages[0].len();
181 let total_evals: usize = self.degrees.iter().map(|&d| mul!(d, num_vars)).sum();
183
184 let header = add!(F::Modulus::NUM_BYTES, add!(u32::NUM_BYTES, u32::NUM_BYTES));
186 let degrees = mul!(num_groups, u32::NUM_BYTES);
187 let eval_data = mul!(total_evals, F::Inner::NUM_BYTES);
188 let claimed = mul!(num_groups, F::Inner::NUM_BYTES);
189
190 add!(header, add!(degrees, add!(eval_data, claimed)))
191 }
192}
193
/// Output of [`MultiDegreeSumcheck::verify_as_subprotocol`]: the evaluation
/// claims the caller still has to check.
#[derive(Debug)]
pub struct MultiDegreeSubClaims<F> {
    /// Challenge point shared by all groups (presumably one coordinate per
    /// round, as the tests evaluate it against an `num_vars`-long point).
    point: Vec<F>,
    /// Per-group value that the group's combined polynomial must take at
    /// `point`, in group order.
    expected_evaluations: Vec<F>,
}
200
201impl<F> MultiDegreeSubClaims<F> {
202 pub fn point(&self) -> &[F] {
203 &self.point
204 }
205
206 pub fn expected_evaluations(&self) -> &[F] {
207 &self.expected_evaluations
208 }
209}
210
/// Entry point for the batched multi-degree sumcheck. All groups are proven in
/// lockstep and share one verifier challenge per round. The type carries no
/// data; it only fixes the field `F`.
pub struct MultiDegreeSumcheck<F>(PhantomData<F>);
216
217impl<F: FromPrimitiveWithConfig> MultiDegreeSumcheck<F> {
218 #[allow(clippy::type_complexity)]
259 pub fn prove_as_subprotocol(
260 transcript: &mut impl Transcript,
261 groups: Vec<MultiDegreeSumcheckGroup<F>>,
262 num_vars: usize,
263 config: &F::Config,
264 ) -> (MultiDegreeSumcheckProof<F>, Vec<SumcheckProverState<F>>)
265 where
266 F: InnerTransparentField + Send + Sync,
267 F::Inner: ConstTranscribable + Zero,
268 F::Modulus: ConstTranscribable,
269 {
270 assert!(
271 num_vars > 0,
272 "Attempts to prove a constant: num_vars must be > 0"
273 );
274 assert!(!groups.is_empty(), "need at least one degree group");
275
276 let num_groups = groups.len();
277 let mut buf = vec![0; F::Inner::NUM_BYTES];
278 let nvars_field = F::from_with_cfg(num_vars as u64, config);
279 let ngroups_field = F::from_with_cfg(num_groups as u64, config);
280 transcript.absorb_random_field(&nvars_field, &mut buf);
281 transcript.absorb_random_field(&ngroups_field, &mut buf);
282
283 let mut verifier_msg = None;
284 let mut group_messages: Vec<Vec<SumcheckProverMsg<F>>> = (0..num_groups)
285 .map(|_| Vec::with_capacity(num_vars))
286 .collect();
287 let mut claimed_sums = Vec::with_capacity(num_groups);
288
289 let (mut prover_states, comb_fns): (Vec<_>, Vec<_>) = groups
290 .into_iter()
291 .map(|group| {
292 let degree_field = F::from_with_cfg(group.degree as u64, config);
293 transcript.absorb_random_field(°ree_field, &mut buf);
294
295 (
296 SumcheckProverState::new(group.poly, num_vars, group.degree),
297 group.comb_fn,
298 )
299 })
300 .unzip();
301
302 for _ in 0..num_vars {
303 let round_msgs: Vec<SumcheckProverMsg<F>> = cfg_iter_mut!(prover_states)
305 .zip(cfg_iter!(comb_fns))
306 .map(|(state, comb_fn)| state.prove_round(&verifier_msg, comb_fn, config))
307 .collect();
308
309 for msg in &round_msgs {
311 transcript.absorb_random_field_slice(&msg.0.tail_evaluations, &mut buf);
312 }
313
314 for (j, msg) in round_msgs.into_iter().enumerate() {
315 group_messages[j].push(msg);
316 }
317
318 let next_verifier_msg = transcript.get_field_challenge(config);
319 transcript.absorb_random_field(&next_verifier_msg, &mut buf);
320
321 verifier_msg = Some(next_verifier_msg);
322 }
323
324 prover_states.iter_mut().for_each(|state| {
325 let sum = state
326 .asserted_sum
327 .clone()
328 .expect("asserted sum should be recorded after the first prover round");
329 claimed_sums.push(sum);
330
331 if let Some(ref vmsg) = verifier_msg {
332 state.randomness.push(vmsg.clone());
333 }
334 });
335
336 let degrees = prover_states.iter().map(|s| s.max_degree).collect();
337
338 (
339 MultiDegreeSumcheckProof {
340 group_messages,
341 claimed_sums,
342 degrees,
343 },
344 prover_states,
345 )
346 }
347
348 pub fn verify_as_subprotocol(
389 transcript: &mut impl Transcript,
390 num_vars: usize,
391 proof: &MultiDegreeSumcheckProof<F>,
392 config: &F::Config,
393 ) -> Result<MultiDegreeSubClaims<F>, SumCheckError<F>>
394 where
395 F: InnerTransparentField,
396 F::Inner: ConstTranscribable,
397 F::Modulus: ConstTranscribable,
398 {
399 assert!(
400 num_vars > 0,
401 "Attempts to prove a constant: num_vars must be > 0"
402 );
403 let num_groups = proof.degrees.len();
404 assert!(num_groups != 0, "need at least one degree group");
405
406 let mut buf = vec![0; F::Inner::NUM_BYTES];
407 let nvars_field = F::from_with_cfg(num_vars as u64, config);
408 let ngroups_field = F::from_with_cfg(num_groups as u64, config);
409 transcript.absorb_random_field(&nvars_field, &mut buf);
410 transcript.absorb_random_field(&ngroups_field, &mut buf);
411
412 let mut verifier_states: Vec<VerifierState<F>> = (0..num_groups)
413 .map(|j| {
414 let degree = proof.degrees[j];
415 let degree_field = F::from_with_cfg(degree as u64, config);
416 transcript.absorb_random_field(°ree_field, &mut buf);
417
418 VerifierState::new(num_vars, degree, config)
419 })
420 .collect();
421
422 for msgs in &proof.group_messages {
423 if msgs.len() != num_vars {
424 return Err(SumCheckError::InvalidProofLength {
425 expected: num_vars,
426 got: msgs.len(),
427 });
428 }
429 }
430
431 assert_eq!(
432 verifier_states.len(),
433 proof.group_messages.len(),
434 "number of verifier states ({}) must match number of proof groups ({})",
435 verifier_states.len(),
436 proof.group_messages.len(),
437 );
438
439 for i in 0..num_vars {
440 proof.group_messages.iter().for_each(|msg| {
441 transcript.absorb_random_field_slice(&msg[i].0.tail_evaluations, &mut buf)
442 });
443
444 let shared_challenge: F = transcript.get_field_challenge(config);
445 transcript.absorb_random_field(&shared_challenge, &mut buf);
446
447 verifier_states
448 .iter_mut()
449 .zip(proof.group_messages.iter())
450 .for_each(|(state, msg)| {
451 state.verify_round_with_challenge(&msg[i], shared_challenge.clone())
452 });
453 }
454
455 let mut shared_point: Option<Vec<F>> = None;
456 let mut expected_evaluations = Vec::with_capacity(num_groups);
457 for (j, state) in verifier_states.into_iter().enumerate() {
459 let subclaim = state.check_and_generate_subclaim(proof.claimed_sums[j].clone())?;
460 if let Some(ref p) = shared_point {
461 debug_assert_eq!(p, &subclaim.point);
462 } else {
463 shared_point = Some(subclaim.point)
464 }
465
466 expected_evaluations.push(subclaim.expected_evaluation);
467 }
468
469 Ok(MultiDegreeSubClaims {
470 point: shared_point.expect("at least one group"),
471 expected_evaluations,
472 })
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::{U128, const_monty_params};
    use crypto_primitives::crypto_bigint_const_monty::ConstMontyField;
    use zinc_poly::{mle::MultilinearExtensionWithConfig, utils::build_eq_x_r_inner};
    use zinc_transcript::Blake3Transcript;

    const_monty_params!(TestParams, U128, "00000000b933426489189cb5b47d567f");
    type F = ConstMontyField<TestParams, { U128::LIMBS }>;

    /// Two groups (degree 2 and degree 3) over the same polynomials. Each
    /// expected evaluation must equal its combination function applied to the
    /// polynomials evaluated at the shared point.
    #[test]
    fn multi_degree_two_groups() {
        let num_vars = 3;
        let cfg = &();

        let a_vals: Vec<F> = (0u32..8).map(|i| F::from(i + 1)).collect();
        let b_vals: Vec<F> = (0u32..8).map(|i| F::from(i + 10)).collect();
        let inner_zero = *F::from(0u32).inner();

        let a_mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            a_vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );
        let b_mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            b_vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(5u32), F::from(7u32), F::from(11u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        // Degree 2: eq(x, r) * (a(x) + b(x)).
        let g0 = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r.clone(), a_mle.clone(), b_mle.clone()],
            Box::new(|v: &[F]| v[0] * (v[1] + v[2])),
        );

        // Degree 3: eq(x, r) * a(x) * b(x).
        let g1 = MultiDegreeSumcheckGroup::new(
            3,
            vec![eq_r.clone(), a_mle.clone(), b_mle.clone()],
            Box::new(|v: &[F]| v[0] * v[1] * v[2]),
        );

        let mut pt = Blake3Transcript::new();
        let (proof, _states) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g0, g1], num_vars, cfg);

        let mut vt = Blake3Transcript::new();
        let subclaims =
            MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &proof, cfg)
                .expect("verification should succeed");

        assert_eq!(subclaims.expected_evaluations.len(), 2);

        let point = &subclaims.point;
        let eq_eval = zinc_poly::utils::eq_eval(point, &r, F::from(1u32)).unwrap();
        let a_eval = a_mle.evaluate_with_config(point, cfg).unwrap();
        let b_eval = b_mle.evaluate_with_config(point, cfg).unwrap();

        assert_eq!(
            subclaims.expected_evaluations[0],
            eq_eval * (a_eval + b_eval)
        );
        assert_eq!(subclaims.expected_evaluations[1], eq_eval * a_eval * b_eval);
    }

    /// A single group behaves like a plain sumcheck.
    #[test]
    fn multi_degree_single_group() {
        let num_vars = 2;
        let cfg = &();

        let vals: Vec<F> = (0u32..4).map(|i| F::from(i + 1)).collect();
        let inner_zero = *F::from(0u32).inner();
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(3u32), F::from(7u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        let g = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r.clone(), mle.clone()],
            Box::new(|v: &[F]| v[0] * v[1]),
        );

        let mut pt = Blake3Transcript::new();
        let (proof, _) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g], num_vars, cfg);

        let mut vt = Blake3Transcript::new();
        let subclaims =
            MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &proof, cfg)
                .expect("verification should succeed");

        let point = &subclaims.point;
        let eq_eval = zinc_poly::utils::eq_eval(point, &r, F::from(1u32)).unwrap();
        let a_eval = mle.clone().evaluate_with_config(point, cfg).unwrap();

        assert_eq!(subclaims.expected_evaluations[0], eq_eval * a_eval);
    }

    /// Checks the byte encoding of a multi-group proof:
    /// - `get_num_bytes` matches what `write_transcription_bytes_exact` fills;
    /// - decoding returns an identical proof;
    /// - the decoded proof still verifies.
    #[test]
    fn multi_degree_proof_transcription_roundtrip() {
        let num_vars = 2;
        let cfg = &();

        let inner_zero = *F::from(0u32).inner();
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            (0u32..4).map(|i| *F::from(i + 3).inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(2u32), F::from(9u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        let g0 = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r.clone(), mle.clone()],
            Box::new(|v: &[F]| v[0] * v[1]),
        );
        let g1 = MultiDegreeSumcheckGroup::new(
            3,
            vec![eq_r, mle.clone(), mle],
            Box::new(|v: &[F]| v[0] * v[1] * v[2]),
        );

        let mut pt = Blake3Transcript::new();
        let (proof, _) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g0, g1], num_vars, cfg);

        let mut bytes = vec![0u8; proof.get_num_bytes()];
        proof.write_transcription_bytes_exact(&mut bytes);
        let decoded = MultiDegreeSumcheckProof::<F>::read_transcription_bytes_exact(&bytes);
        assert_eq!(decoded, proof);

        let mut vt = Blake3Transcript::new();
        MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &decoded, cfg)
            .expect("decoded proof should verify");
    }
}
596}