Skip to main content

zinc_piop/sumcheck/
prover.rs

//! Sumcheck prover: prover state and per-round message generation.
2
3use std::slice;
4
5use crypto_primitives::PrimeField;
6#[cfg(feature = "parallel")]
7use rayon::iter::*;
8use zinc_poly::mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig};
9use zinc_transcript::{delegate_transcribable, traits::ConstTranscribable};
10use zinc_utils::{cfg_into_iter, cfg_iter_mut, inner_transparent_field::InnerTransparentField};
11
/// Evaluation of a polynomial on natural points without the constant term.
///
/// Holds `P(1), P(2), …` in that order; `P(0)` is deliberately omitted.
/// The wrapper dereferences to `[F]`, so the evaluations can be read and
/// updated like a slice.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NatEvaluatedPolyWithoutConstant<F> {
    /// Evaluations at 1, 2, ... (P(0) is omitted).
    pub tail_evaluations: Vec<F>,
}

impl<F> NatEvaluatedPolyWithoutConstant<F> {
    /// Wraps the evaluations `P(1), P(2), …` (without `P(0)`).
    #[must_use]
    pub fn new(tail_evaluations: Vec<F>) -> Self {
        Self { tail_evaluations }
    }
}

impl<F> std::ops::Deref for NatEvaluatedPolyWithoutConstant<F> {
    type Target = [F];

    /// Borrows the stored evaluations `P(1), P(2), …` as a slice.
    fn deref(&self) -> &Self::Target {
        &self.tail_evaluations
    }
}

impl<F> std::ops::DerefMut for NatEvaluatedPolyWithoutConstant<F> {
    /// Mutably borrows the stored evaluations `P(1), P(2), …` as a slice.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.tail_evaluations
    }
}
39
40delegate_transcribable!(NatEvaluatedPolyWithoutConstant<F> { tail_evaluations: Vec<F> }
41    where F: PrimeField, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable);
42
43#[repr(transparent)]
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct ProverMsg<F>(pub NatEvaluatedPolyWithoutConstant<F>);
46
47delegate_transcribable!(ProverMsg<F>(NatEvaluatedPolyWithoutConstant<F>)
48    where F: PrimeField, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable);
49
50/// Sumcheck Prover State.
51pub struct ProverState<F: PrimeField> {
52    /// Sampled randomness given by the verifier.
53    pub randomness: Vec<F>,
54    /// Stores the list of multilinear extensions
55    /// the sumcheck polynomial is comprised of.
56    pub mles: Vec<DenseMultilinearExtension<F::Inner>>,
57    /// Number of variables.
58    pub num_vars: usize,
59    /// Max degree.
60    pub max_degree: usize,
61    /// The current round number.
62    pub round: usize,
63    /// Claimed sum for the first round polynomial.
64    pub asserted_sum: Option<F>,
65}
66
67impl<F: PrimeField> ProverState<F> {
68    /// Initialize the prover to argue for the sum of products of
69    /// MLE's in {0,1}^`num_vars`.
70    pub fn new(
71        mles: Vec<DenseMultilinearExtension<F::Inner>>,
72        nvars: usize,
73        degree: usize,
74    ) -> Self {
75        Self {
76            randomness: Vec::with_capacity(nvars),
77            mles,
78            num_vars: nvars,
79            max_degree: degree,
80            round: 0,
81            asserted_sum: None,
82        }
83    }
84}
85
86impl<F> ProverState<F>
87where
88    F: InnerTransparentField,
89{
90    /// Receive message from verifier, generate prover message, and proceed to
91    /// next round.
92    ///
93    /// Adapted Jolt's sumcheck implementation.
94    #[allow(clippy::arithmetic_side_effects)]
95    pub fn prove_round(
96        &mut self,
97        v_msg: &Option<F>,
98        comb_fn: impl Fn(&[F]) -> F + Send + Sync,
99        config: &F::Config,
100    ) -> ProverMsg<F> {
101        if let Some(msg) = v_msg {
102            if self.round == 0 {
103                panic!("first round should be prover first.");
104            }
105            self.randomness.push(msg.clone());
106
107            // fix the next variable at the verifier randomness for this round
108            let i = self.round;
109            let r = self.randomness[i - 1].clone();
110
111            cfg_iter_mut!(self.mles).for_each(|multiplicand| {
112                multiplicand.fix_variables_with_config(slice::from_ref(&r), config);
113            });
114        } else if self.round > 0 {
115            panic!("verifier message is empty");
116        }
117
118        self.round += 1;
119
120        if self.round > self.num_vars {
121            panic!("Prover is not active");
122        }
123
124        let i = self.round;
125        let nv = self.num_vars;
126        let degree = self.max_degree;
127
128        let polys = &self.mles;
129
130        struct Scratch<R> {
131            evals: Vec<R>,
132            steps: Vec<R>,
133            vals0: Vec<R>,
134            vals1: Vec<R>,
135            vals: Vec<R>,
136            levals: Vec<R>,
137        }
138        let zero = F::zero_with_cfg(config);
139        let zero_vec_deg = vec![zero.clone(); degree + 1];
140        let zero_vec_poly = vec![zero.clone(); polys.len()];
141        let scratch = || Scratch {
142            evals: zero_vec_deg.clone(),
143            steps: zero_vec_poly.clone(),
144            vals0: zero_vec_poly.clone(),
145            vals1: zero_vec_poly.clone(),
146            vals: zero_vec_poly.clone(),
147            levals: zero_vec_deg.clone(),
148        };
149
150        #[cfg(not(feature = "parallel"))]
151        let zeros = scratch();
152        #[cfg(feature = "parallel")]
153        let zeros = scratch;
154
155        let summer = cfg_into_iter!(0..1 << (nv - i)).fold(zeros, |mut s, b| {
156            let index = b << 1;
157
158            // TODO(Alex): Once you have benches set,
159            //             could please try getting rid of vals0 and vals1 fields in the
160            // structs, replacing them with
161            //
162            //             ```rust
163            //             let vals0: Vec<_> = polys.iter().map(|poly|
164            // poly[index].clone()).collect();             let vals1: Vec<_> =
165            // polys.iter().map(|poly| poly[index + 1].clone()).collect();
166            //             ```
167            //             My bet is that it won't affect running time, but better safe than
168            // sorry.
169
170            s.vals0
171                .iter_mut()
172                .zip(polys.iter())
173                .for_each(|(v0, poly)| *v0.inner_mut() = poly[index].clone());
174            s.levals[0] = comb_fn(&s.vals0);
175
176            if degree > 0 {
177                s.vals1
178                    .iter_mut()
179                    .zip(polys.iter())
180                    .for_each(|(v1, poly)| *v1.inner_mut() = poly[index + 1].clone());
181                s.levals[1] = comb_fn(&s.vals1);
182
183                for (i, (v1, v0)) in s.vals1.iter().zip(s.vals0.iter()).enumerate() {
184                    s.steps[i] = v1.clone() - v0.clone();
185                    s.vals[i] = v1.clone();
186                }
187
188                for eval_point in s.levals.iter_mut().take(degree + 1).skip(2) {
189                    for poly_i in 0..polys.len() {
190                        s.vals[poly_i] += &s.steps[poly_i];
191                    }
192                    *eval_point = comb_fn(&s.vals);
193                }
194            }
195
196            // TODO(Alex): It seems that the only thing
197            //             we pass around meaningfully is evals,
198            //             so this loop could be reworked to map/reduce - maybe even without
199            //             #[cfg(feature = "parallel")]. Would help to get benchmarks up and
200            //             running first though.
201            s.evals
202                .iter_mut()
203                .zip(s.levals.iter())
204                .for_each(|(e, l)| *e += l);
205
206            s
207        });
208
209        // Rayon's fold outputs an iter which still needs to be summed over
210        #[cfg(feature = "parallel")]
211        let evaluations = summer.map(|s| s.evals).reduce(
212            || vec![zero.clone(); degree + 1],
213            |mut evaluations, evals| {
214                evaluations
215                    .iter_mut()
216                    .zip(evals)
217                    .for_each(|(e, l)| *e += &l);
218                evaluations
219            },
220        );
221
222        #[cfg(not(feature = "parallel"))]
223        let evaluations = summer.evals;
224
225        // Record the claimed sum once during the first round.
226        if self.round == 1 {
227            let p0 = evaluations
228                .first()
229                .expect("evaluations should always contain the constant term");
230            let sum = if degree > 0 {
231                p0.clone()
232                    + evaluations
233                        .get(1)
234                        .expect("degree > 0 implies evaluation at 1 is present")
235            } else {
236                p0.clone()
237            };
238            self.asserted_sum = Some(sum);
239        }
240
241        // Strip the constant term before sending, without re-allocating all elements.
242        let mut tail = evaluations;
243        tail.remove(0); // leaves P(0) behind; tail holds P(1..)
244
245        ProverMsg(NatEvaluatedPolyWithoutConstant::new(tail))
246    }
247}