1use crypto_primitives::PrimeField;
7use zinc_poly::utils::next_mle_eval;
8
9#[allow(clippy::arithmetic_side_effects)]
19pub fn eval_shift_predicate<F: PrimeField>(x: &[F], y: &[F], c: usize, cfg: &F::Config) -> F {
20 let m = x.len();
21 assert_eq!(y.len(), m);
22 let zero = F::zero_with_cfg(cfg);
23 let one = F::one_with_cfg(cfg);
24
25 if c == 0 {
27 return eval_eq_poly(x, y, &one);
28 }
29
30 if c == 1 {
32 return next_mle_eval(x, y, zero, one);
33 }
34
35 assert!(c < (1usize << m), "shift c must satisfy c < 2^m");
36 let k = (2 * c).next_power_of_two().trailing_zeros() as usize;
38 if k >= m {
39 return eval_shift_small(x, y, c, m, &zero, &one);
40 }
41
42 let (x_lo, x_hi) = x.split_at(k);
44 let (y_lo, y_hi) = y.split_at(k);
45
46 let l0 = eval_l0(x_lo, y_lo, c, k, &zero, &one);
47 let l1 = eval_l1(x_lo, y_lo, c, k, &zero, &one);
48 let eq = eval_eq_poly(x_hi, y_hi, &one);
49 let next = next_mle_eval(x_hi, y_hi, zero, one);
50 l0 * eq + l1 * next
51}
52
53pub(crate) fn eval_eq_poly<F: PrimeField>(u: &[F], v: &[F], one: &F) -> F {
57 u.iter()
58 .zip(v.iter())
59 .map(|(u_i, v_i)| u_i.clone() * v_i + (one.clone() - u_i) * (one.clone() - v_i))
60 .fold(one.clone(), |acc, term| acc * term)
61}
62
63pub(crate) fn eval_delta<F: PrimeField>(u: &[F], a: usize, k: usize, one: &F) -> F {
70 let mut result = one.clone();
71 for (i, u) in u.iter().take(k).enumerate() {
72 let bit = (a >> i) & 1;
73 if bit == 1 {
74 result *= u;
75 } else {
76 result *= one.clone() - u
77 }
78 }
79 result
80}
81
82#[allow(clippy::arithmetic_side_effects)]
89pub(crate) fn eval_l0<F: PrimeField>(
90 x_lo: &[F],
91 y_lo: &[F],
92 c: usize,
93 k: usize,
94 zero: &F,
95 one: &F,
96) -> F {
97 let upper = (1 << k) - c;
98 (0..upper).fold(zero.clone(), |acc, a| {
99 acc + eval_delta(x_lo, a, k, one) * eval_delta(y_lo, a + c, k, one)
100 })
101}
102
103#[allow(clippy::arithmetic_side_effects)]
109pub(crate) fn eval_l1<F: PrimeField>(
110 x_lo: &[F],
111 y_lo: &[F],
112 c: usize,
113 k: usize,
114 zero: &F,
115 one: &F,
116) -> F {
117 let two_k = 1 << k;
118 ((two_k - c)..two_k).fold(zero.clone(), |acc, a| {
119 acc + eval_delta(x_lo, a, k, one) * eval_delta(y_lo, a + c - two_k, k, one)
120 })
121}
122
123#[allow(clippy::arithmetic_side_effects)]
127fn eval_shift_small<F: PrimeField>(x: &[F], y: &[F], c: usize, m: usize, zero: &F, one: &F) -> F {
128 let upper = (1 << m) - c;
129 (0..upper).fold(zero.clone(), |acc, a| {
130 acc + eval_delta(x, a, m, one) * eval_delta(y, a + c, m, one)
131 })
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::test_config;
    use crypto_primitives::{Field, FromWithConfig, crypto_bigint_monty::MontyField};
    use rand::Rng;
    use zinc_poly::utils::{build_eq_x_r_inner, build_next_c_r_mle};

    type F = MontyField<4>;
    type Cfg = <F as PrimeField>::Config;

    /// Returns bit `bit` of `val` as the field element `0` or `1`.
    fn to_bin(val: usize, bit: usize, cfg: &Cfg) -> F {
        if (val >> bit) & 1 == 1 {
            F::one_with_cfg(cfg)
        } else {
            F::zero_with_cfg(cfg)
        }
    }

    /// Little-endian encoding of `val` as `m` Boolean field elements (`out[i]` is bit `i`).
    fn bin_vec(val: usize, m: usize, cfg: &Cfg) -> Vec<F> {
        (0..m).map(|i| to_bin(val, i, cfg)).collect()
    }

    /// Wraps a raw `Inner` value (as stored in the prover-side tables) into a field element.
    fn from_inner(inner: <F as Field>::Inner, cfg: &Cfg) -> F {
        let mut f = F::zero_with_cfg(cfg);
        *f.inner_mut() = inner;
        f
    }

    /// A random field element drawn from a `u32`.
    fn rand_field<R: Rng>(rng: &mut R, cfg: &Cfg) -> F {
        F::from_with_cfg(rng.random::<u32>(), cfg)
    }

    /// `len` independent random field elements.
    fn rand_vec<R: Rng>(rng: &mut R, len: usize, cfg: &Cfg) -> Vec<F> {
        (0..len).map(|_| rand_field(rng, cfg)).collect()
    }

    /// Asserts that `S_c` equals the indicator `[b == a + c]` on every pair of Boolean points of
    /// the `m`-dimensional cube. Because `b < 2^m`, this also encodes the no-overflow condition.
    fn assert_boolean_cube(m: usize, c: usize, cfg: &Cfg) {
        let n = 1usize << m;
        for a in 0..n {
            let x = bin_vec(a, m, cfg);
            for b in 0..n {
                let y = bin_vec(b, m, cfg);
                let expected = if b == a + c {
                    F::one_with_cfg(cfg)
                } else {
                    F::zero_with_cfg(cfg)
                };
                assert_eq!(
                    eval_shift_predicate(&x, &y, c, cfg),
                    expected,
                    "S_{c}(bin({a}), bin({b})) wrong for m={m}"
                );
            }
        }
    }

    /// `Σ_b table[b] · eq(y, b)`: the multilinear extension of a prover-side table at `y`.
    fn table_mle_at(table: &[<F as Field>::Inner], y: &[F], cfg: &Cfg) -> F {
        let eq_y = build_eq_x_r_inner(y, cfg).unwrap();
        table
            .iter()
            .zip(eq_y.evaluations.iter())
            .fold(F::zero_with_cfg(cfg), |acc, (ti, ei)| {
                acc + from_inner(*ti, cfg) * from_inner(*ei, cfg)
            })
    }

    /// Checks `S_c(r, bin(b))` for every Boolean `b` against `build_next_c_r_mle(r, c)` at a
    /// fresh random `r`.
    fn assert_matches_prover_table<R: Rng>(m: usize, c: usize, rng: &mut R, cfg: &Cfg) {
        let r = rand_vec(rng, m, cfg);
        let next_c = build_next_c_r_mle(&r, c, cfg).unwrap();
        for b in 0..(1usize << m) {
            let val = eval_shift_predicate(&r, &bin_vec(b, m, cfg), c, cfg);
            let expected = from_inner(next_c.evaluations[b], cfg);
            assert_eq!(
                val, expected,
                "S_{c}(r, bin({b})) mismatch with prover MLE (m={m})"
            );
        }
    }

    /// Checks `S_c(r, y)` at `trials` random non-Boolean pairs `(r, y)` against the multilinear
    /// extension of the prover table built for `r`.
    fn assert_matches_prover_mle<R: Rng>(
        m: usize,
        c: usize,
        trials: usize,
        rng: &mut R,
        cfg: &Cfg,
    ) {
        for _ in 0..trials {
            let r = rand_vec(rng, m, cfg);
            let y = rand_vec(rng, m, cfg);
            let next_c = build_next_c_r_mle(&r, c, cfg).unwrap();
            let rhs = table_mle_at(&next_c.evaluations, &y, cfg);
            let lhs = eval_shift_predicate(&r, &y, c, cfg);
            assert_eq!(lhs, rhs, "random-point MLE mismatch for m={m}, c={c}");
        }
    }

    /// `S_c` is the shift indicator on the Boolean cube.
    #[test]
    fn test_shift_predicate_boolean() {
        let cfg = test_config();
        for c in [1, 2, 5] {
            assert_boolean_cube(4, c, &cfg);
        }
    }

    /// `next_mle_eval` is the successor indicator on the Boolean cube.
    #[test]
    fn test_next_boolean() {
        let cfg = test_config();
        let m = 4;
        let n = 1usize << m;
        for a in 0..n {
            let u = bin_vec(a, m, &cfg);
            for b in 0..n {
                let v = bin_vec(b, m, &cfg);
                let val = next_mle_eval(&u, &v, F::zero_with_cfg(&cfg), F::one_with_cfg(&cfg));
                let expected = if b == a + 1 {
                    F::one_with_cfg(&cfg)
                } else {
                    F::zero_with_cfg(&cfg)
                };
                assert_eq!(val, expected, "Next({a},{b}) wrong");
            }
        }
    }

    /// The verifier-side `S_c(r, ·)` matches the prover table on Boolean points.
    #[test]
    fn test_shift_predicate_vs_prover_mle() {
        let cfg = test_config();
        assert_matches_prover_table(4, 3, &mut rand::rng(), &cfg);
    }

    /// The verifier-side `S_c(r, y)` matches the prover table's MLE at random `y`.
    #[test]
    fn test_shift_predicate_random_points() {
        let cfg = test_config();
        assert_matches_prover_mle(4, 3, 8, &mut rand::rng(), &cfg);
    }

    /// Covers the `c == 0` / `c == 1` fast paths alongside the general split and small paths.
    #[test]
    fn test_fast_paths_and_multi_c() {
        let cfg = test_config();
        let mut rng = rand::rng();
        for c in [0, 1, 2, 5, 7] {
            assert_matches_prover_table(4, c, &mut rng, &cfg);
            assert_matches_prover_mle(4, c, 4, &mut rng, &cfg);
        }
    }

    /// Shifts at the top of the range (`c = 2^m / 2` and `c = 2^m - 1`).
    #[test]
    fn test_shift_predicate_boundary() {
        let cfg = test_config();
        let mut rng = rand::rng();
        let m = 3;
        let n = 1usize << m;
        let zero_inner = *F::zero_with_cfg(&cfg).inner();

        for c in [n / 2, n - 1] {
            assert_boolean_cube(m, c, &cfg);

            // Prover table: entries b < c must vanish. The remaining n - c entries are
            // eq(r, b - c), which is non-zero for random r with overwhelming probability.
            let r = rand_vec(&mut rng, m, &cfg);
            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
            for b in 0..c {
                assert_eq!(
                    next_c.evaluations[b], zero_inner,
                    "next_c[{b}] should be zero for c={c}"
                );
            }
            let nonzero_count = next_c.evaluations[c..]
                .iter()
                .filter(|e| **e != zero_inner)
                .count();
            assert_eq!(
                nonzero_count,
                n - c,
                "expected {} nonzero entries for c={c}",
                n - c
            );
        }
    }

    /// The prover table encodes the shifted inner product `Σ_{b ≥ c} eq(r, b - c) · v[b]`.
    #[test]
    fn test_prover_mle_inner_product() {
        let cfg = test_config();
        let mut rng = rand::rng();
        let m = 4;
        let n = 1usize << m;

        for c in [1, 2, 3, 7] {
            let v = rand_vec(&mut rng, n, &cfg);
            let r = rand_vec(&mut rng, m, &cfg);

            let eq_r = build_eq_x_r_inner(&r, &cfg).unwrap();
            let zero = F::zero_with_cfg(&cfg);
            let expected = (c..n).fold(zero.clone(), |acc, b| {
                acc + from_inner(eq_r.evaluations[b - c], &cfg) * &v[b]
            });

            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
            let got = next_c
                .evaluations
                .iter()
                .zip(v.iter())
                .fold(zero, |acc, (ni, vi)| {
                    acc + vi.clone() * from_inner(*ni, &cfg)
                });

            assert_eq!(got, expected, "prover MLE inner product mismatch for c={c}");
        }
    }

    /// Exercises the low/high split with several high bits, with carries and with shifts that are
    /// not powers of two.
    #[test]
    fn test_split_path_larger_dimension() {
        let cfg = test_config();
        let mut rng = rand::rng();

        for c in [3, 6, 13] {
            assert_boolean_cube(5, c, &cfg);
        }

        let m = 6;
        for c in [2, 3, 5, 6, 11, 16, 17, 31, 32, 33, 63] {
            assert_matches_prover_table(m, c, &mut rng, &cfg);
            assert_matches_prover_mle(m, c, 2, &mut rng, &cfg);
        }
    }
}