Skip to main content

zinc_piop/
sumcheck.rs

1pub mod multi_degree;
2pub mod prover;
3// pub mod utils;
4pub mod verifier;
5
6#[cfg(test)]
7mod tests;
8
9use self::verifier::Subclaim;
10use crate::sumcheck::{
11    prover::{NatEvaluatedPolyWithoutConstant, ProverMsg},
12    verifier::VerifierState,
13};
14use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
15use num_traits::Zero;
16use prover::ProverState;
17use std::marker::PhantomData;
18use thiserror::Error;
19use zinc_poly::{EvaluationError, mle::DenseMultilinearExtension, utils::ArithErrors};
20use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript};
21use zinc_utils::{inner_transparent_field::InnerTransparentField, mul};
22
/// Multilinear sumcheck protocol (prover and verifier entry points).
///
/// The summand is an arbitrary combination function `comb_fn` applied to the
/// evaluations of a list of multilinear extensions (for example, a product of
/// MLEs). The type carries no data; it only fixes the field `F` used by
/// [`MLSumcheck::prove_as_subprotocol`] and
/// [`MLSumcheck::verify_as_subprotocol`].
pub struct MLSumcheck<F>(PhantomData<F>);
25
26/// Proof generated by the sumcheck prover.
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub struct SumcheckProof<F> {
29    /// List of prover messages, one for each round.
30    pub messages: Vec<ProverMsg<F>>,
31    /// The claimed sum for the first round polynomial.
32    pub claimed_sum: F,
33}
34
35impl<F: PrimeField> GenTranscribable for SumcheckProof<F>
36where
37    F::Inner: ConstTranscribable,
38    F::Modulus: ConstTranscribable,
39{
40    fn read_transcription_bytes_exact(bytes: &[u8]) -> Self {
41        let mod_size = F::Modulus::NUM_BYTES;
42        let cfg = zinc_transcript::read_field_cfg::<F>(&bytes[..mod_size]);
43        let bytes = &bytes[mod_size..];
44
45        let (n_msgs, mut bytes) = u32::read_transcription_bytes_subset(bytes);
46        let n_msgs = usize::try_from(n_msgs).expect("message count must fit into usize");
47
48        let mut messages = Vec::with_capacity(n_msgs);
49        for _ in 0..n_msgs {
50            let (len, rest) = u32::read_transcription_bytes_subset(bytes);
51            let len = usize::try_from(len).expect("polynomial length must fit into usize");
52            bytes = rest;
53            let end = mul!(len, F::Inner::NUM_BYTES);
54            let tail_evaluations = zinc_transcript::read_field_vec_with_cfg(&bytes[..end], &cfg);
55            messages.push(ProverMsg(NatEvaluatedPolyWithoutConstant {
56                tail_evaluations,
57            }));
58            bytes = &bytes[end..];
59        }
60
61        let claimed_sum = F::Inner::read_transcription_bytes_exact(bytes);
62        let claimed_sum = F::new_unchecked_with_cfg(claimed_sum, &cfg);
63        Self {
64            messages,
65            claimed_sum,
66        }
67    }
68
69    fn write_transcription_bytes_exact(&self, mut buf: &mut [u8]) {
70        buf = zinc_transcript::append_field_cfg::<F>(buf, &self.claimed_sum.modulus());
71        buf = {
72            let len =
73                u32::try_from(self.messages.len()).expect("messages length must fit into u32");
74            len.write_transcription_bytes_exact(&mut buf[0..u32::NUM_BYTES]);
75            &mut buf[u32::NUM_BYTES..]
76        };
77        for msg in &self.messages {
78            let evals = &msg.0.tail_evaluations;
79            buf = {
80                let len = u32::try_from(evals.len()).expect("messages length must fit into u32");
81                len.write_transcription_bytes_exact(&mut buf[0..u32::NUM_BYTES]);
82                &mut buf[u32::NUM_BYTES..]
83            };
84            buf = zinc_transcript::append_field_vec_inner(buf, evals);
85        }
86        self.claimed_sum
87            .inner()
88            .write_transcription_bytes_exact(buf);
89    }
90}
91
92impl<F: PrimeField> Transcribable for SumcheckProof<F>
93where
94    F::Inner: ConstTranscribable,
95    F::Modulus: ConstTranscribable,
96{
97    #[allow(clippy::arithmetic_side_effects)]
98    fn get_num_bytes(&self) -> usize {
99        let n_msgs = self.messages.len();
100        let total_evals: usize = self
101            .messages
102            .iter()
103            .map(|m| m.0.tail_evaluations.len())
104            .sum();
105        F::Modulus::NUM_BYTES
106            + u32::NUM_BYTES // n_msgs
107            + n_msgs * u32::NUM_BYTES // n_evals for each message
108            + total_evals * F::Inner::NUM_BYTES // evals
109            + F::Inner::NUM_BYTES // claimed_sum
110    }
111}
112
113impl<F: FromPrimitiveWithConfig> MLSumcheck<F> {
114    /// Sumcheck prover main entry point.
115    ///
116    /// This function executes the Prover side of the Sumcheck protocol.
117    /// It verifies a claim of the form:
118    ///
119    /// $$
120    /// \sum_{x \in \{0, 1\}^{\text{nvars}}} \text{comb\\_fn}(\text{mles}(x)) =
121    /// \text{claimed\\_sum}. $$
122    ///
123    /// It is designed to be used as a subprotocol within a larger system
124    /// since it takes the FS transcript (`transcript` argument) as input
125    /// and returns the **internal ProverState** alongside the final proof.
126    ///
127    /// The claimed sum is derived by the prover.
128    ///
129    /// ---
130    ///
131    /// # Arguments
132    ///
133    /// * `transcript`: A mutable reference to a Fiat-Shamir `Transcript`.
134    /// * `mles`: A `Vec` of dense multilinear extension over the base field
135    ///   `F`. The sumcheck polynomial is made over the combined result of these
136    ///   multilinear extensions.
137    /// * `nvars`: The number of variables over which the `mles` are defined.
138    ///   This must be consistent across all `mles`.
139    /// * `degree`: The maximum combined degree of the `mles` under the
140    ///   `comb_fn`.
141    /// * `comb_fn`: A closure that defines the combination function
142    ///   $G(\text{mles}(x))$. It takes a slice of field elements (the
143    ///   evaluations of the `mles` at a point $x$) and returns a single field
144    ///   element.
145    /// * `config`: The configuration for the underlying field used in the
146    ///   protocol.
147    ///
148    /// ---
149    ///
150    /// # Returns
151    ///
152    /// A tuple containing:
153    ///
154    /// 1. `SumcheckProof<F>`: The final sumcheck proof.
155    /// 2. `ProverState<F>`: The state of the Prover after the protocol
156    ///    completes.
157    ///
158    /// ---
159    ///
160    /// # Panics
161    ///
162    /// * Panics if the number of variables is `0`.
163    pub fn prove_as_subprotocol(
164        transcript: &mut impl Transcript,
165        mles: Vec<DenseMultilinearExtension<F::Inner>>,
166        nvars: usize,
167        degree: usize,
168        comb_fn: impl Fn(&[F]) -> F + Send + Sync,
169        config: &F::Config,
170    ) -> (SumcheckProof<F>, ProverState<F>)
171    where
172        F: InnerTransparentField,
173        F::Inner: ConstTranscribable + Zero,
174        F::Modulus: ConstTranscribable,
175    {
176        if nvars == 0 {
177            panic!("Attempt to prove a constant")
178        }
179
180        let mut buf = vec![0; F::Inner::NUM_BYTES];
181        let nvars_field = F::from_with_cfg(nvars as u64, config);
182        let degree_field = F::from_with_cfg(degree as u64, config);
183
184        transcript.absorb_random_field(&nvars_field, &mut buf);
185        transcript.absorb_random_field(&degree_field, &mut buf);
186
187        let mut prover_state = ProverState::new(mles, nvars, degree);
188        let mut verifier_msg = None;
189        let mut prover_msgs = Vec::with_capacity(nvars);
190
191        for _ in 0..nvars {
192            let prover_msg = prover_state.prove_round(&verifier_msg, &comb_fn, config);
193            transcript.absorb_random_field_slice(&prover_msg.0.tail_evaluations, &mut buf);
194            prover_msgs.push(prover_msg);
195            let next_verifier_msg = transcript.get_field_challenge(config);
196            transcript.absorb_random_field(&next_verifier_msg, &mut buf);
197
198            verifier_msg = Some(next_verifier_msg);
199        }
200        let asserted_sum = prover_state
201            .asserted_sum
202            .clone()
203            .expect("asserted sum should be recorded after the first prover round");
204        if let Some(vmsg) = verifier_msg {
205            prover_state.randomness.push(vmsg);
206        }
207
208        (
209            SumcheckProof {
210                messages: prover_msgs,
211                claimed_sum: asserted_sum,
212            },
213            prover_state,
214        )
215    }
216
217    /// Sumcheck verifier main entry point.
218    ///
219    /// This function executes the Verifier side of the Sumcheck protocol.
220    /// It takes a `proof` and a `claimed_sum` and verifies the
221    /// intermediate steps of the sumcheck.
222    ///
223    /// The sumcheck verifies the claim:
224    ///
225    /// $$
226    /// \sum_{x \in \{0, 1\}^{\text{num\\_vars}}} G(x) = \text{claimed\\_sum}.
227    /// $$
228    ///
229    /// It is designed to be used as a subprotocol within a larger system.
230    /// If successful, it returns a **Subclaim**, a final equation that the
231    /// outer protocol must satisfy for the overall proof to be valid.
232    ///
233    /// ---
234    ///
235    /// # Arguments
236    ///
237    /// * `transcript`: A mutable reference to a Fiat-Shamir `Transcript`.
238    /// * `num_vars`: The number of variables over which the sum was originally
239    ///   computed.
240    /// * `degree`: The maximum combined degree of the underlying polynomial
241    ///   $G(x)$. This must match the degree used by the Prover.
242    /// * `proof`: A reference to the `SumcheckProof<F>` generated by the
243    ///   Prover.
244    /// * `config`: The configuration for the underlying field used in the
245    ///   protocol.
246    ///
247    /// ---
248    ///
249    /// # Returns
250    ///
251    /// A `Result` which is:
252    ///
253    /// * `Ok(Subclaim<F>)`: If the Sumcheck protocol passes successfully, it
254    ///   returns a `Subclaim`. This claim consists of:
255    ///     1. The final random challenge point $r \in
256    ///        \text{F}^{\text{num\\_vars}}$.
257    ///     2. The expected evaluation $v$ of the combined polynomial $G(r)$ at
258    ///        that point.
259    ///
260    /// * `Err(SumCheckError<F>)`: If any of the round checks fail during the
261    ///   protocol.
262    ///
263    /// ---
264    ///
265    /// # Panics
266    ///
267    /// * Panics if the number of variables is `0`.
268    pub fn verify_as_subprotocol(
269        transcript: &mut impl Transcript,
270        num_vars: usize,
271        degree: usize,
272        proof: &SumcheckProof<F>,
273        config: &F::Config,
274    ) -> Result<Subclaim<F>, SumCheckError<F>>
275    where
276        F::Inner: ConstTranscribable,
277        F::Modulus: ConstTranscribable,
278    {
279        if num_vars == 0 {
280            panic!("Attempt to verify a sumcheck claim for 0 variables")
281        }
282
283        let mut buf = vec![0; F::Inner::NUM_BYTES];
284
285        let (nvars_field, degree_field): (F, F) = {
286            (
287                F::from_with_cfg(num_vars as u64, config),
288                F::from_with_cfg(degree as u64, config),
289            )
290        };
291        transcript.absorb_random_field(&nvars_field, &mut buf);
292        transcript.absorb_random_field(&degree_field, &mut buf);
293
294        if proof.messages.len() != num_vars {
295            return Err(SumCheckError::InvalidProofLength {
296                expected: num_vars,
297                got: proof.messages.len(),
298            });
299        }
300
301        let mut verifier_state = VerifierState::new(num_vars, degree, config);
302
303        for i in 0..num_vars {
304            let prover_msg = &proof.messages[i];
305            transcript.absorb_random_field_slice(&prover_msg.0.tail_evaluations, &mut buf);
306            let verifier_msg = verifier_state.verify_round(prover_msg, transcript);
307            transcript.absorb_random_field(&verifier_msg, &mut buf);
308        }
309
310        verifier_state.check_and_generate_subclaim(proof.claimed_sum.clone())
311    }
312}
313
314#[derive(Error, Debug)]
315pub enum SumCheckError<F> {
316    #[error("univariate polynomial evaluation error")]
317    EvaluationError(ArithErrors),
318    #[error("incorrect sumcheck sum at round {0}. Expected `{1}`. Received `{2}`")]
319    SumCheckFailed(usize, Box<F>, Box<F>),
320    #[error("max degree exceeded")]
321    MaxDegreeExceeded,
322    #[error("invalid proof length: expected {expected}, got {got}")]
323    InvalidProofLength { expected: usize, got: usize },
324    #[error("verifier failed to evaluate a round polynomial: {0}")]
325    UnivariateEvaluationError(EvaluationError),
326}
327
328impl<F> From<ArithErrors> for SumCheckError<F> {
329    fn from(arith_error: ArithErrors) -> Self {
330        Self::EvaluationError(arith_error)
331    }
332}
333
334impl<F> From<EvaluationError> for SumCheckError<F> {
335    fn from(error: EvaluationError) -> Self {
336        Self::UnivariateEvaluationError(error)
337    }
338}