1pub mod multi_degree;
2pub mod prover;
3pub mod verifier;
5
6#[cfg(test)]
7mod tests;
8
9use self::verifier::Subclaim;
10use crate::sumcheck::{
11 prover::{NatEvaluatedPolyWithoutConstant, ProverMsg},
12 verifier::VerifierState,
13};
14use crypto_primitives::{FromPrimitiveWithConfig, PrimeField};
15use num_traits::Zero;
16use prover::ProverState;
17use std::marker::PhantomData;
18use thiserror::Error;
19use zinc_poly::{EvaluationError, mle::DenseMultilinearExtension, utils::ArithErrors};
20use zinc_transcript::traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript};
21use zinc_utils::{inner_transparent_field::InnerTransparentField, mul};
22
/// Zero-sized entry point for the multilinear sumcheck protocol.
///
/// It stores no data: the type parameter `F` only selects the field in which
/// `prove_as_subprotocol` / `verify_as_subprotocol` operate.
pub struct MLSumcheck<F>(PhantomData<F>);
25
26#[derive(Clone, Debug, PartialEq, Eq)]
28pub struct SumcheckProof<F> {
29 pub messages: Vec<ProverMsg<F>>,
31 pub claimed_sum: F,
33}
34
35impl<F: PrimeField> GenTranscribable for SumcheckProof<F>
36where
37 F::Inner: ConstTranscribable,
38 F::Modulus: ConstTranscribable,
39{
40 fn read_transcription_bytes_exact(bytes: &[u8]) -> Self {
41 let mod_size = F::Modulus::NUM_BYTES;
42 let cfg = zinc_transcript::read_field_cfg::<F>(&bytes[..mod_size]);
43 let bytes = &bytes[mod_size..];
44
45 let (n_msgs, mut bytes) = u32::read_transcription_bytes_subset(bytes);
46 let n_msgs = usize::try_from(n_msgs).expect("message count must fit into usize");
47
48 let mut messages = Vec::with_capacity(n_msgs);
49 for _ in 0..n_msgs {
50 let (len, rest) = u32::read_transcription_bytes_subset(bytes);
51 let len = usize::try_from(len).expect("polynomial length must fit into usize");
52 bytes = rest;
53 let end = mul!(len, F::Inner::NUM_BYTES);
54 let tail_evaluations = zinc_transcript::read_field_vec_with_cfg(&bytes[..end], &cfg);
55 messages.push(ProverMsg(NatEvaluatedPolyWithoutConstant {
56 tail_evaluations,
57 }));
58 bytes = &bytes[end..];
59 }
60
61 let claimed_sum = F::Inner::read_transcription_bytes_exact(bytes);
62 let claimed_sum = F::new_unchecked_with_cfg(claimed_sum, &cfg);
63 Self {
64 messages,
65 claimed_sum,
66 }
67 }
68
69 fn write_transcription_bytes_exact(&self, mut buf: &mut [u8]) {
70 buf = zinc_transcript::append_field_cfg::<F>(buf, &self.claimed_sum.modulus());
71 buf = {
72 let len =
73 u32::try_from(self.messages.len()).expect("messages length must fit into u32");
74 len.write_transcription_bytes_exact(&mut buf[0..u32::NUM_BYTES]);
75 &mut buf[u32::NUM_BYTES..]
76 };
77 for msg in &self.messages {
78 let evals = &msg.0.tail_evaluations;
79 buf = {
80 let len = u32::try_from(evals.len()).expect("messages length must fit into u32");
81 len.write_transcription_bytes_exact(&mut buf[0..u32::NUM_BYTES]);
82 &mut buf[u32::NUM_BYTES..]
83 };
84 buf = zinc_transcript::append_field_vec_inner(buf, evals);
85 }
86 self.claimed_sum
87 .inner()
88 .write_transcription_bytes_exact(buf);
89 }
90}
91
92impl<F: PrimeField> Transcribable for SumcheckProof<F>
93where
94 F::Inner: ConstTranscribable,
95 F::Modulus: ConstTranscribable,
96{
97 #[allow(clippy::arithmetic_side_effects)]
98 fn get_num_bytes(&self) -> usize {
99 let n_msgs = self.messages.len();
100 let total_evals: usize = self
101 .messages
102 .iter()
103 .map(|m| m.0.tail_evaluations.len())
104 .sum();
105 F::Modulus::NUM_BYTES
106 + u32::NUM_BYTES + n_msgs * u32::NUM_BYTES + total_evals * F::Inner::NUM_BYTES + F::Inner::NUM_BYTES }
111}
112
113impl<F: FromPrimitiveWithConfig> MLSumcheck<F> {
114 pub fn prove_as_subprotocol(
164 transcript: &mut impl Transcript,
165 mles: Vec<DenseMultilinearExtension<F::Inner>>,
166 nvars: usize,
167 degree: usize,
168 comb_fn: impl Fn(&[F]) -> F + Send + Sync,
169 config: &F::Config,
170 ) -> (SumcheckProof<F>, ProverState<F>)
171 where
172 F: InnerTransparentField,
173 F::Inner: ConstTranscribable + Zero,
174 F::Modulus: ConstTranscribable,
175 {
176 if nvars == 0 {
177 panic!("Attempt to prove a constant")
178 }
179
180 let mut buf = vec![0; F::Inner::NUM_BYTES];
181 let nvars_field = F::from_with_cfg(nvars as u64, config);
182 let degree_field = F::from_with_cfg(degree as u64, config);
183
184 transcript.absorb_random_field(&nvars_field, &mut buf);
185 transcript.absorb_random_field(°ree_field, &mut buf);
186
187 let mut prover_state = ProverState::new(mles, nvars, degree);
188 let mut verifier_msg = None;
189 let mut prover_msgs = Vec::with_capacity(nvars);
190
191 for _ in 0..nvars {
192 let prover_msg = prover_state.prove_round(&verifier_msg, &comb_fn, config);
193 transcript.absorb_random_field_slice(&prover_msg.0.tail_evaluations, &mut buf);
194 prover_msgs.push(prover_msg);
195 let next_verifier_msg = transcript.get_field_challenge(config);
196 transcript.absorb_random_field(&next_verifier_msg, &mut buf);
197
198 verifier_msg = Some(next_verifier_msg);
199 }
200 let asserted_sum = prover_state
201 .asserted_sum
202 .clone()
203 .expect("asserted sum should be recorded after the first prover round");
204 if let Some(vmsg) = verifier_msg {
205 prover_state.randomness.push(vmsg);
206 }
207
208 (
209 SumcheckProof {
210 messages: prover_msgs,
211 claimed_sum: asserted_sum,
212 },
213 prover_state,
214 )
215 }
216
217 pub fn verify_as_subprotocol(
269 transcript: &mut impl Transcript,
270 num_vars: usize,
271 degree: usize,
272 proof: &SumcheckProof<F>,
273 config: &F::Config,
274 ) -> Result<Subclaim<F>, SumCheckError<F>>
275 where
276 F::Inner: ConstTranscribable,
277 F::Modulus: ConstTranscribable,
278 {
279 if num_vars == 0 {
280 panic!("Attempt to verify a sumcheck claim for 0 variables")
281 }
282
283 let mut buf = vec![0; F::Inner::NUM_BYTES];
284
285 let (nvars_field, degree_field): (F, F) = {
286 (
287 F::from_with_cfg(num_vars as u64, config),
288 F::from_with_cfg(degree as u64, config),
289 )
290 };
291 transcript.absorb_random_field(&nvars_field, &mut buf);
292 transcript.absorb_random_field(°ree_field, &mut buf);
293
294 if proof.messages.len() != num_vars {
295 return Err(SumCheckError::InvalidProofLength {
296 expected: num_vars,
297 got: proof.messages.len(),
298 });
299 }
300
301 let mut verifier_state = VerifierState::new(num_vars, degree, config);
302
303 for i in 0..num_vars {
304 let prover_msg = &proof.messages[i];
305 transcript.absorb_random_field_slice(&prover_msg.0.tail_evaluations, &mut buf);
306 let verifier_msg = verifier_state.verify_round(prover_msg, transcript);
307 transcript.absorb_random_field(&verifier_msg, &mut buf);
308 }
309
310 verifier_state.check_and_generate_subclaim(proof.claimed_sum.clone())
311 }
312}
313
314#[derive(Error, Debug)]
315pub enum SumCheckError<F> {
316 #[error("univariate polynomial evaluation error")]
317 EvaluationError(ArithErrors),
318 #[error("incorrect sumcheck sum at round {0}. Expected `{1}`. Received `{2}`")]
319 SumCheckFailed(usize, Box<F>, Box<F>),
320 #[error("max degree exceeded")]
321 MaxDegreeExceeded,
322 #[error("invalid proof length: expected {expected}, got {got}")]
323 InvalidProofLength { expected: usize, got: usize },
324 #[error("verifier failed to evaluate a round polynomial: {0}")]
325 UnivariateEvaluationError(EvaluationError),
326}
327
328impl<F> From<ArithErrors> for SumCheckError<F> {
329 fn from(arith_error: ArithErrors) -> Self {
330 Self::EvaluationError(arith_error)
331 }
332}
333
334impl<F> From<EvaluationError> for SumCheckError<F> {
335 fn from(error: EvaluationError) -> Self {
336 Self::UnivariateEvaluationError(error)
337 }
338}