zinc_piop/sumcheck/
prover.rs1use std::slice;
4
5use crypto_primitives::PrimeField;
6#[cfg(feature = "parallel")]
7use rayon::iter::*;
8use zinc_poly::mle::{DenseMultilinearExtension, MultilinearExtensionWithConfig};
9use zinc_transcript::{delegate_transcribable, traits::ConstTranscribable};
10use zinc_utils::{cfg_into_iter, cfg_iter_mut, inner_transparent_field::InnerTransparentField};
11
/// Evaluations of a sumcheck round polynomial with the entry at `X = 0`
/// (referred to as the "constant term" by the prover) left out.
///
/// As produced by `ProverState::prove_round`, the vector holds
/// `p(1), p(2), ..., p(max_degree)` in that order.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct NatEvaluatedPolyWithoutConstant<F> {
    /// The retained evaluations.
    pub tail_evaluations: Vec<F>,
}

impl<F> NatEvaluatedPolyWithoutConstant<F> {
    /// Wraps an already-computed evaluation vector without copying it.
    pub fn new(tail_evaluations: Vec<F>) -> Self {
        NatEvaluatedPolyWithoutConstant { tail_evaluations }
    }
}

/// Shared slice view of the evaluations, so the wrapper can be indexed,
/// measured and iterated like `&[F]`.
impl<F> std::ops::Deref for NatEvaluatedPolyWithoutConstant<F> {
    type Target = [F];

    fn deref(&self) -> &[F] {
        self.tail_evaluations.as_slice()
    }
}

/// Mutable slice view of the evaluations.
impl<F> std::ops::DerefMut for NatEvaluatedPolyWithoutConstant<F> {
    fn deref_mut(&mut self) -> &mut [F] {
        self.tail_evaluations.as_mut_slice()
    }
}
39
40delegate_transcribable!(NatEvaluatedPolyWithoutConstant<F> { tail_evaluations: Vec<F> }
41 where F: PrimeField, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable);
42
43#[repr(transparent)]
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct ProverMsg<F>(pub NatEvaluatedPolyWithoutConstant<F>);
46
47delegate_transcribable!(ProverMsg<F>(NatEvaluatedPolyWithoutConstant<F>)
48 where F: PrimeField, F::Inner: ConstTranscribable, F::Modulus: ConstTranscribable);
49
50pub struct ProverState<F: PrimeField> {
52 pub randomness: Vec<F>,
54 pub mles: Vec<DenseMultilinearExtension<F::Inner>>,
57 pub num_vars: usize,
59 pub max_degree: usize,
61 pub round: usize,
63 pub asserted_sum: Option<F>,
65}
66
67impl<F: PrimeField> ProverState<F> {
68 pub fn new(
71 mles: Vec<DenseMultilinearExtension<F::Inner>>,
72 nvars: usize,
73 degree: usize,
74 ) -> Self {
75 Self {
76 randomness: Vec::with_capacity(nvars),
77 mles,
78 num_vars: nvars,
79 max_degree: degree,
80 round: 0,
81 asserted_sum: None,
82 }
83 }
84}
85
86impl<F> ProverState<F>
87where
88 F: InnerTransparentField,
89{
90 #[allow(clippy::arithmetic_side_effects)]
95 pub fn prove_round(
96 &mut self,
97 v_msg: &Option<F>,
98 comb_fn: impl Fn(&[F]) -> F + Send + Sync,
99 config: &F::Config,
100 ) -> ProverMsg<F> {
101 if let Some(msg) = v_msg {
102 if self.round == 0 {
103 panic!("first round should be prover first.");
104 }
105 self.randomness.push(msg.clone());
106
107 let i = self.round;
109 let r = self.randomness[i - 1].clone();
110
111 cfg_iter_mut!(self.mles).for_each(|multiplicand| {
112 multiplicand.fix_variables_with_config(slice::from_ref(&r), config);
113 });
114 } else if self.round > 0 {
115 panic!("verifier message is empty");
116 }
117
118 self.round += 1;
119
120 if self.round > self.num_vars {
121 panic!("Prover is not active");
122 }
123
124 let i = self.round;
125 let nv = self.num_vars;
126 let degree = self.max_degree;
127
128 let polys = &self.mles;
129
130 struct Scratch<R> {
131 evals: Vec<R>,
132 steps: Vec<R>,
133 vals0: Vec<R>,
134 vals1: Vec<R>,
135 vals: Vec<R>,
136 levals: Vec<R>,
137 }
138 let zero = F::zero_with_cfg(config);
139 let zero_vec_deg = vec![zero.clone(); degree + 1];
140 let zero_vec_poly = vec![zero.clone(); polys.len()];
141 let scratch = || Scratch {
142 evals: zero_vec_deg.clone(),
143 steps: zero_vec_poly.clone(),
144 vals0: zero_vec_poly.clone(),
145 vals1: zero_vec_poly.clone(),
146 vals: zero_vec_poly.clone(),
147 levals: zero_vec_deg.clone(),
148 };
149
150 #[cfg(not(feature = "parallel"))]
151 let zeros = scratch();
152 #[cfg(feature = "parallel")]
153 let zeros = scratch;
154
155 let summer = cfg_into_iter!(0..1 << (nv - i)).fold(zeros, |mut s, b| {
156 let index = b << 1;
157
158 s.vals0
171 .iter_mut()
172 .zip(polys.iter())
173 .for_each(|(v0, poly)| *v0.inner_mut() = poly[index].clone());
174 s.levals[0] = comb_fn(&s.vals0);
175
176 if degree > 0 {
177 s.vals1
178 .iter_mut()
179 .zip(polys.iter())
180 .for_each(|(v1, poly)| *v1.inner_mut() = poly[index + 1].clone());
181 s.levals[1] = comb_fn(&s.vals1);
182
183 for (i, (v1, v0)) in s.vals1.iter().zip(s.vals0.iter()).enumerate() {
184 s.steps[i] = v1.clone() - v0.clone();
185 s.vals[i] = v1.clone();
186 }
187
188 for eval_point in s.levals.iter_mut().take(degree + 1).skip(2) {
189 for poly_i in 0..polys.len() {
190 s.vals[poly_i] += &s.steps[poly_i];
191 }
192 *eval_point = comb_fn(&s.vals);
193 }
194 }
195
196 s.evals
202 .iter_mut()
203 .zip(s.levals.iter())
204 .for_each(|(e, l)| *e += l);
205
206 s
207 });
208
209 #[cfg(feature = "parallel")]
211 let evaluations = summer.map(|s| s.evals).reduce(
212 || vec![zero.clone(); degree + 1],
213 |mut evaluations, evals| {
214 evaluations
215 .iter_mut()
216 .zip(evals)
217 .for_each(|(e, l)| *e += &l);
218 evaluations
219 },
220 );
221
222 #[cfg(not(feature = "parallel"))]
223 let evaluations = summer.evals;
224
225 if self.round == 1 {
227 let p0 = evaluations
228 .first()
229 .expect("evaluations should always contain the constant term");
230 let sum = if degree > 0 {
231 p0.clone()
232 + evaluations
233 .get(1)
234 .expect("degree > 0 implies evaluation at 1 is present")
235 } else {
236 p0.clone()
237 };
238 self.asserted_sum = Some(sum);
239 }
240
241 let mut tail = evaluations;
243 tail.remove(0); ProverMsg(NatEvaluatedPolyWithoutConstant::new(tail))
246 }
247}