Skip to main content

zinc_piop/sumcheck/
multi_degree.rs

1//! Multi-degree sumcheck: runs multiple degree groups in lockstep with
2//! shared verifier randomness, producing a common evaluation point.
3//!
4//! # Protocol
5//!
6//! Given G degree groups each with (degree_g, mles_g, comb_fn_g):
7//!
8//! 1. Absorb metadata: num_vars, num_groups, per-group degrees
//! 2. For each round `i = 1..=num_vars`:
10//!    - Each group computes its round polynomial `P_g` (parallelizable)
11//!    - Absorb all round messages in deterministic order
12//!    - Sample ONE shared challenge `r_i`
13//!    - All groups fix variable `i` at `r_i`
14//! 3. Each group produces a subclaim at the shared point r = (r_1, ..., r_n)
15
16use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
17use num_traits::Zero;
18#[cfg(feature = "parallel")]
19use rayon::prelude::*;
20use std::marker::PhantomData;
21use zinc_poly::mle::DenseMultilinearExtension;
22use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript};
23use zinc_utils::{
24    add, cfg_iter, cfg_iter_mut, inner_transparent_field::InnerTransparentField, mul,
25};
26
27use crate::CombFn;
28
29use super::{
30    SumCheckError,
31    prover::{
32        NatEvaluatedPolyWithoutConstant, ProverMsg as SumcheckProverMsg,
33        ProverState as SumcheckProverState,
34    },
35    verifier::VerifierState,
36};
37
38// ---------------------------------------------------------------------------
39// Types
40// ---------------------------------------------------------------------------
41
42/// A single degree group for the multi-degree sumcheck: (degree, mles,
43/// comb_fn).
44pub struct MultiDegreeSumcheckGroup<F: PrimeField> {
45    degree: usize,
46    poly: Vec<DenseMultilinearExtension<F::Inner>>,
47    comb_fn: CombFn<F>,
48}
49
50impl<F: PrimeField> MultiDegreeSumcheckGroup<F> {
51    pub fn new(
52        degree: usize,
53        poly: Vec<DenseMultilinearExtension<F::Inner>>,
54        comb_fn: CombFn<F>,
55    ) -> Self {
56        Self {
57            degree,
58            poly,
59            comb_fn,
60        }
61    }
62}
63
/// Proof for a multi-degree sumcheck.
///
/// `group_messages[g][round]` is the prover message for group `g` in that
/// round. All groups share the verifier challenges, and therefore share a
/// common evaluation point.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MultiDegreeSumcheckProof<F> {
    /// Prover messages: one inner vector per group, one message per round.
    group_messages: Vec<Vec<SumcheckProverMsg<F>>>,
    /// Claimed sum of each group, recorded by the prover in the first round.
    claimed_sums: Vec<F>,
    /// Degree bound of each group, in the same order as `group_messages`.
    degrees: Vec<usize>,
}
77
78impl<F> MultiDegreeSumcheckProof<F> {
79    /// Needed by the verifier to check against expected
80    /// sums before running the sumcheck.
81    pub fn claimed_sums(&self) -> &[F] {
82        &self.claimed_sums
83    }
84}
85
86impl<F: PrimeField> GenTranscribable for MultiDegreeSumcheckProof<F>
87where
88    F::Inner: ConstTranscribable,
89    F::Modulus: ConstTranscribable,
90{
91    fn read_transcription_bytes_exact(bytes: &[u8]) -> Self {
92        let mod_size = F::Modulus::NUM_BYTES;
93        let cfg = zinc_transcript::read_field_cfg::<F>(&bytes[..mod_size]);
94        let bytes = &bytes[mod_size..];
95
96        let (num_groups, bytes) = u32::read_transcription_bytes_subset(bytes);
97        let num_groups = usize::try_from(num_groups).expect("group count must fit into usize");
98
99        let (num_vars, mut bytes) = u32::read_transcription_bytes_subset(bytes);
100        let num_vars = usize::try_from(num_vars).expect("num_vars must fit into usize");
101
102        let mut degrees = Vec::with_capacity(num_groups);
103        for _ in 0..num_groups {
104            let (deg, rest) = u32::read_transcription_bytes_subset(bytes);
105            degrees.push(usize::try_from(deg).expect("degree must fit into usize"));
106            bytes = rest;
107        }
108
109        let mut group_messages = Vec::with_capacity(num_groups);
110        for &deg in &degrees {
111            let msg_bytes = mul!(deg, F::Inner::NUM_BYTES);
112            let mut msgs = Vec::with_capacity(num_vars);
113            for _ in 0..num_vars {
114                let tail_evaluations =
115                    zinc_transcript::read_field_vec_with_cfg(&bytes[..msg_bytes], &cfg);
116                msgs.push(SumcheckProverMsg(NatEvaluatedPolyWithoutConstant {
117                    tail_evaluations,
118                }));
119                bytes = &bytes[msg_bytes..];
120            }
121            group_messages.push(msgs);
122        }
123
124        let mut claimed_sums = Vec::with_capacity(num_groups);
125        for _ in 0..num_groups {
126            let cs = F::Inner::read_transcription_bytes_exact(&bytes[..F::Inner::NUM_BYTES]);
127            let cs = F::new_unchecked_with_cfg(cs, &cfg);
128            claimed_sums.push(cs);
129            bytes = &bytes[F::Inner::NUM_BYTES..];
130        }
131
132        Self {
133            group_messages,
134            claimed_sums,
135            degrees,
136        }
137    }
138
139    fn write_transcription_bytes_exact(&self, mut buf: &mut [u8]) {
140        buf = zinc_transcript::append_field_cfg::<F>(buf, &self.claimed_sums[0].modulus());
141
142        let num_groups =
143            u32::try_from(self.group_messages.len()).expect("num groups must fit into u32");
144        num_groups.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
145        buf = &mut buf[u32::NUM_BYTES..];
146
147        // All groups share the same number of rounds (num_vars).
148        let num_vars =
149            u32::try_from(self.group_messages[0].len()).expect("num_vars must fit into u32");
150        num_vars.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
151        buf = &mut buf[u32::NUM_BYTES..];
152
153        for &deg in &self.degrees {
154            let deg = u32::try_from(deg).expect("degree must fit into u32");
155            deg.write_transcription_bytes_exact(&mut buf[..u32::NUM_BYTES]);
156            buf = &mut buf[u32::NUM_BYTES..];
157        }
158
159        for group in &self.group_messages {
160            for msg in group {
161                buf = zinc_transcript::append_field_vec_inner(buf, &msg.0.tail_evaluations);
162            }
163        }
164
165        for cs in &self.claimed_sums {
166            cs.inner()
167                .write_transcription_bytes_exact(&mut buf[..F::Inner::NUM_BYTES]);
168            buf = &mut buf[F::Inner::NUM_BYTES..];
169        }
170    }
171}
172
173impl<F: PrimeField> Transcribable for MultiDegreeSumcheckProof<F>
174where
175    F::Inner: ConstTranscribable,
176    F::Modulus: ConstTranscribable,
177{
178    fn get_num_bytes(&self) -> usize {
179        let num_groups = self.group_messages.len();
180        let num_vars = self.group_messages[0].len();
181        // total_evals = Σ_g (degree_g × num_vars)
182        let total_evals: usize = self.degrees.iter().map(|&d| mul!(d, num_vars)).sum();
183
184        // [field_cfg][num_groups][num_vars][deg₀..degₙ][evals...][claimed_sums]
185        let header = add!(F::Modulus::NUM_BYTES, add!(u32::NUM_BYTES, u32::NUM_BYTES));
186        let degrees = mul!(num_groups, u32::NUM_BYTES);
187        let eval_data = mul!(total_evals, F::Inner::NUM_BYTES);
188        let claimed = mul!(num_groups, F::Inner::NUM_BYTES);
189
190        add!(header, add!(degrees, add!(eval_data, claimed)))
191    }
192}
193
/// Result of a successful multi-degree sumcheck verification.
///
/// Holds the evaluation point shared by every group and one expected
/// evaluation per group, in the order of the groups in the proof.
#[derive(Debug)]
pub struct MultiDegreeSubClaims<F> {
    /// Shared point `r = (r_1, ..., r_n)` at which every group is checked.
    point: Vec<F>,
    /// `expected_evaluations[g]` is the value group `g` must take at `point`.
    expected_evaluations: Vec<F>,
}

impl<F> MultiDegreeSubClaims<F> {
    /// Shared evaluation point.
    pub fn point(&self) -> &[F] {
        self.point.as_slice()
    }

    /// Per-group expected evaluations at [`Self::point`].
    pub fn expected_evaluations(&self) -> &[F] {
        self.expected_evaluations.as_slice()
    }
}
210
211// ---------------------------------------------------------------------------
212// MultiDegreeSumcheck
213// ---------------------------------------------------------------------------
214
215pub struct MultiDegreeSumcheck<F>(PhantomData<F>);
216
217impl<F: FromPrimitiveWithConfig> MultiDegreeSumcheck<F> {
218    /// Multi-degree sumcheck prover.
219    ///
220    /// Runs the prover side of the sumcheck protocol for G degree groups
221    /// sharing one verifier challenge per round. Proves the claim:
222    ///
223    /// $$
224    /// \sum_{x \in \{0, 1\}^{\text{num\\_vars}}} G_g(x) =
225    /// \text{claimed\\_sum}_g \quad \forall g
226    /// $$
227    ///
228    /// where $G_g(x) = \text{comb\\_fn}_g(\text{mles}_g(x))$ is the combination
229    /// function for group $g$ applied to its MLEs.
230    ///
231    /// It is designed to be used as a subprotocol within a larger system.
232    /// since it takes the FS transcript (`transcript` argument) as input
233    /// and returns the **internal ProverState** alongside the sumcheck proof.
234    ///
235    /// Claimed sums are derived by the prover during the first round.
236    ///
237    /// # Arguments
238    ///
239    /// * `transcript`: Fiat-Shamir transcript.
240    /// * `groups`: One [`MultiDegreeSumcheckGroup`] per degree bucket, each
241    ///   carrying its MLEs and combination function.
242    /// * `num_vars`: Number of variables (must be consistent across all
243    ///   groups).
244    /// * `config`: Field configuration.
245    ///
246    /// # Returns
247    ///
248    /// A tuple containing:
249    ///
250    /// 1. [`MultiDegreeSumcheckProof<F>`]: The proof (group messages, claimed
251    ///    sums, degrees).
252    /// 2. `Vec<ProverState<F>>`: Per-group prover states — needed by the caller
253    ///    to evaluate MLEs at the shared point after the sumcheck.
254    ///
255    /// # Panics
256    ///
257    /// * Panics if `num_vars == 0` or `groups` is empty.
258    #[allow(clippy::type_complexity)]
259    pub fn prove_as_subprotocol(
260        transcript: &mut impl Transcript,
261        groups: Vec<MultiDegreeSumcheckGroup<F>>,
262        num_vars: usize,
263        config: &F::Config,
264    ) -> (MultiDegreeSumcheckProof<F>, Vec<SumcheckProverState<F>>)
265    where
266        F: InnerTransparentField + Send + Sync,
267        F::Inner: ConstTranscribable + Zero,
268        F::Modulus: ConstTranscribable,
269    {
270        assert!(
271            num_vars > 0,
272            "Attempts to prove a constant: num_vars must be > 0"
273        );
274        assert!(!groups.is_empty(), "need at least one degree group");
275
276        let num_groups = groups.len();
277        let mut buf = vec![0; F::Inner::NUM_BYTES];
278        let nvars_field = F::from_with_cfg(num_vars as u64, config);
279        let ngroups_field = F::from_with_cfg(num_groups as u64, config);
280        transcript.absorb_random_field(&nvars_field, &mut buf);
281        transcript.absorb_random_field(&ngroups_field, &mut buf);
282
283        let mut verifier_msg = None;
284        let mut group_messages: Vec<Vec<SumcheckProverMsg<F>>> = (0..num_groups)
285            .map(|_| Vec::with_capacity(num_vars))
286            .collect();
287        let mut claimed_sums = Vec::with_capacity(num_groups);
288
289        let (mut prover_states, comb_fns): (Vec<_>, Vec<_>) = groups
290            .into_iter()
291            .map(|group| {
292                let degree_field = F::from_with_cfg(group.degree as u64, config);
293                transcript.absorb_random_field(&degree_field, &mut buf);
294
295                (
296                    SumcheckProverState::new(group.poly, num_vars, group.degree),
297                    group.comb_fn,
298                )
299            })
300            .unzip();
301
302        for _ in 0..num_vars {
303            // Parallel: each group computes its round polynomial independently
304            let round_msgs: Vec<SumcheckProverMsg<F>> = cfg_iter_mut!(prover_states)
305                .zip(cfg_iter!(comb_fns))
306                .map(|(state, comb_fn)| state.prove_round(&verifier_msg, comb_fn, config))
307                .collect();
308
309            // Sequential: absorb in deterministic order, sample one shared challenge
310            for msg in &round_msgs {
311                transcript.absorb_random_field_slice(&msg.0.tail_evaluations, &mut buf);
312            }
313
314            for (j, msg) in round_msgs.into_iter().enumerate() {
315                group_messages[j].push(msg);
316            }
317
318            let next_verifier_msg = transcript.get_field_challenge(config);
319            transcript.absorb_random_field(&next_verifier_msg, &mut buf);
320
321            verifier_msg = Some(next_verifier_msg);
322        }
323
324        prover_states.iter_mut().for_each(|state| {
325            let sum = state
326                .asserted_sum
327                .clone()
328                .expect("asserted sum should be recorded after the first prover round");
329            claimed_sums.push(sum);
330
331            if let Some(ref vmsg) = verifier_msg {
332                state.randomness.push(vmsg.clone());
333            }
334        });
335
336        let degrees = prover_states.iter().map(|s| s.max_degree).collect();
337
338        (
339            MultiDegreeSumcheckProof {
340                group_messages,
341                claimed_sums,
342                degrees,
343            },
344            prover_states,
345        )
346    }
347
348    /// Multi-degree sumcheck verifier.
349    ///
350    /// Runs the verifier side of the sumcheck protocol for G degree groups
351    /// sharing one verifier challenge per round. Verifies the claim:
352    ///
353    /// $$
354    /// \sum_{x \in \{0, 1\}^{\text{num\\_vars}}} G_g(x) =
355    /// \text{claimed\\_sum}_g \quad \forall g
356    /// $$
357    ///
358    /// where $G_g(x) = \text{comb\\_fn}_g(\text{mles}_g(x))$.
359    ///
360    /// It is designed to be used as a subprotocol within a larger system.
361    /// If successful, it returns **Subclaim** for each group, a final equation
362    /// that the outer protocol must satisfy for the overall sumcheck proof
363    /// to be valid.
364    ///
365    /// Mirrors the prover transcript exactly: absorbs metadata, then per-round
366    /// absorbs all group messages, samples one shared challenge, and calls
367    /// [`VerifierState::check_and_generate_subclaim`] per group. Per-group
368    /// degrees are read from the proof — no external degree parameter needed.
369    ///
370    /// # Arguments
371    ///
372    /// * `transcript`: Fiat-Shamir transcript (must match prover state at the
373    ///   start of the sumcheck).
374    /// * `num_vars`: Number of variables (sumcheck rounds).
375    /// * `proof`: The [`MultiDegreeSumcheckProof`] produced by the prover.
376    /// * `config`: Field configuration.
377    ///
378    /// # Returns
379    ///
380    /// * `Ok(MultiDegreeSubClaims<F>)`: Shared evaluation point `r*` and
381    ///   per-group expected evaluations. The caller must verify each group's
382    ///   MLE combination at `r*` equals its expected evaluation.
383    /// * `Err(SumCheckError<F>)`: If any round check fails.
384    ///
385    /// # Panics
386    ///
387    /// * Panics if `num_vars == 0` or the proof has no groups.
388    pub fn verify_as_subprotocol(
389        transcript: &mut impl Transcript,
390        num_vars: usize,
391        proof: &MultiDegreeSumcheckProof<F>,
392        config: &F::Config,
393    ) -> Result<MultiDegreeSubClaims<F>, SumCheckError<F>>
394    where
395        F: InnerTransparentField,
396        F::Inner: ConstTranscribable,
397        F::Modulus: ConstTranscribable,
398    {
399        assert!(
400            num_vars > 0,
401            "Attempts to prove a constant: num_vars must be > 0"
402        );
403        let num_groups = proof.degrees.len();
404        assert!(num_groups != 0, "need at least one degree group");
405
406        let mut buf = vec![0; F::Inner::NUM_BYTES];
407        let nvars_field = F::from_with_cfg(num_vars as u64, config);
408        let ngroups_field = F::from_with_cfg(num_groups as u64, config);
409        transcript.absorb_random_field(&nvars_field, &mut buf);
410        transcript.absorb_random_field(&ngroups_field, &mut buf);
411
412        let mut verifier_states: Vec<VerifierState<F>> = (0..num_groups)
413            .map(|j| {
414                let degree = proof.degrees[j];
415                let degree_field = F::from_with_cfg(degree as u64, config);
416                transcript.absorb_random_field(&degree_field, &mut buf);
417
418                VerifierState::new(num_vars, degree, config)
419            })
420            .collect();
421
422        for msgs in &proof.group_messages {
423            if msgs.len() != num_vars {
424                return Err(SumCheckError::InvalidProofLength {
425                    expected: num_vars,
426                    got: msgs.len(),
427                });
428            }
429        }
430
431        assert_eq!(
432            verifier_states.len(),
433            proof.group_messages.len(),
434            "number of verifier states ({}) must match number of proof groups ({})",
435            verifier_states.len(),
436            proof.group_messages.len(),
437        );
438
439        for i in 0..num_vars {
440            proof.group_messages.iter().for_each(|msg| {
441                transcript.absorb_random_field_slice(&msg[i].0.tail_evaluations, &mut buf)
442            });
443
444            let shared_challenge: F = transcript.get_field_challenge(config);
445            transcript.absorb_random_field(&shared_challenge, &mut buf);
446
447            verifier_states
448                .iter_mut()
449                .zip(proof.group_messages.iter())
450                .for_each(|(state, msg)| {
451                    state.verify_round_with_challenge(&msg[i], shared_challenge.clone())
452                });
453        }
454
455        let mut shared_point: Option<Vec<F>> = None;
456        let mut expected_evaluations = Vec::with_capacity(num_groups);
457        // TODO: parallelize when multiple lookup groups exist
458        for (j, state) in verifier_states.into_iter().enumerate() {
459            let subclaim = state.check_and_generate_subclaim(proof.claimed_sums[j].clone())?;
460            if let Some(ref p) = shared_point {
461                debug_assert_eq!(p, &subclaim.point);
462            } else {
463                shared_point = Some(subclaim.point)
464            }
465
466            expected_evaluations.push(subclaim.expected_evaluation);
467        }
468
469        Ok(MultiDegreeSubClaims {
470            point: shared_point.expect("at least one group"),
471            expected_evaluations,
472        })
473    }
474}
475
476// ---------------------------------------------------------------------------
477// Tests
478// ---------------------------------------------------------------------------
479
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::{U128, const_monty_params};
    use crypto_primitives::crypto_bigint_const_monty::ConstMontyField;
    use zinc_poly::{mle::MultilinearExtensionWithConfig, utils::build_eq_x_r_inner};
    use zinc_transcript::Blake3Transcript;

    const_monty_params!(TestParams, U128, "00000000b933426489189cb5b47d567f");
    type F = ConstMontyField<TestParams, { U128::LIMBS }>;

    /// Two degree groups sharing the same evaluation point.
    ///
    /// - Group 0 (degree 2): `eq(y, r) · (a(y) + b(y))`
    /// - Group 1 (degree 3): `eq(y, r) · a(y) · b(y)`
    #[test]
    fn multi_degree_two_groups() {
        let num_vars = 3;
        let cfg = &();

        let a_vals: Vec<F> = (0u32..8).map(|i| F::from(i + 1)).collect();
        let b_vals: Vec<F> = (0u32..8).map(|i| F::from(i + 10)).collect();
        let inner_zero = *F::from(0u32).inner();

        let a_mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            a_vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );
        let b_mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            b_vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(5u32), F::from(7u32), F::from(11u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        // Group 0 (degree 2): eq · (a + b)
        let g0 = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r.clone(), a_mle.clone(), b_mle.clone()],
            Box::new(|v: &[F]| v[0] * (v[1] + v[2])),
        );

        // Group 1 (degree 3): eq · a · b
        let g1 = MultiDegreeSumcheckGroup::new(
            3,
            vec![eq_r.clone(), a_mle.clone(), b_mle.clone()],
            Box::new(|v: &[F]| v[0] * v[1] * v[2]),
        );

        // Prove
        let mut pt = Blake3Transcript::new();
        let (proof, _states) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g0, g1], num_vars, cfg);

        // Verify
        let mut vt = Blake3Transcript::new();
        let subclaims =
            MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &proof, cfg)
                .expect("verification should succeed");

        assert_eq!(subclaims.expected_evaluations.len(), 2);

        // Check final evaluations manually
        let point = &subclaims.point;
        let eq_eval = zinc_poly::utils::eq_eval(point, &r, F::from(1u32)).unwrap();
        let a_eval = a_mle.evaluate_with_config(point, cfg).unwrap();
        let b_eval = b_mle.evaluate_with_config(point, cfg).unwrap();

        assert_eq!(
            subclaims.expected_evaluations[0],
            eq_eval * (a_eval + b_eval)
        );
        assert_eq!(subclaims.expected_evaluations[1], eq_eval * a_eval * b_eval);
    }

    /// Multi-degree sumcheck with a single group produces a valid subclaim.
    #[test]
    fn multi_degree_single_group() {
        let num_vars = 2;
        let cfg = &();

        let vals: Vec<F> = (0u32..4).map(|i| F::from(i + 1)).collect();
        let inner_zero = *F::from(0u32).inner();
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(3u32), F::from(7u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        let g = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r.clone(), mle.clone()],
            Box::new(|v: &[F]| v[0] * v[1]),
        );

        let mut pt = Blake3Transcript::new();
        let (proof, _) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g], num_vars, cfg);

        let mut vt = Blake3Transcript::new();
        let subclaims =
            MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &proof, cfg)
                .expect("verification should succeed");

        let point = &subclaims.point;
        let eq_eval = zinc_poly::utils::eq_eval(point, &r, F::from(1u32)).unwrap();
        let a_eval = mle.clone().evaluate_with_config(point, cfg).unwrap();

        assert_eq!(subclaims.expected_evaluations[0], eq_eval * a_eval);
    }

    /// A proof in which a group is missing its last round message is rejected
    /// with `InvalidProofLength` rather than verified or panicking.
    #[test]
    fn multi_degree_rejects_truncated_group() {
        let num_vars = 2;
        let cfg = &();

        let vals: Vec<F> = (0u32..4).map(|i| F::from(i + 1)).collect();
        let inner_zero = *F::from(0u32).inner();
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            num_vars,
            vals.iter().map(|x| *x.inner()).collect(),
            inner_zero,
        );

        let r: Vec<F> = vec![F::from(3u32), F::from(7u32)];
        let eq_r = build_eq_x_r_inner(&r, cfg).unwrap();

        let g = MultiDegreeSumcheckGroup::new(
            2,
            vec![eq_r, mle],
            Box::new(|v: &[F]| v[0] * v[1]),
        );

        let mut pt = Blake3Transcript::new();
        let (mut proof, _) =
            MultiDegreeSumcheck::<F>::prove_as_subprotocol(&mut pt, vec![g], num_vars, cfg);

        // Drop the final round message of the only group.
        proof.group_messages[0].pop();

        let mut vt = Blake3Transcript::new();
        let result =
            MultiDegreeSumcheck::<F>::verify_as_subprotocol(&mut vt, num_vars, &proof, cfg);

        assert!(matches!(
            result,
            Err(SumCheckError::InvalidProofLength {
                expected: 2,
                got: 1
            })
        ));
    }
}