1use crypto_primitives::{Field, PrimeField, Semiring};
2use num_traits::Zero;
3use thiserror::Error;
4use zinc_utils::{cfg_iter_mut, inner_transparent_field::InnerTransparentField, sub};
5
6#[cfg(feature = "parallel")]
7use rayon::prelude::*;
8
9use crate::mle::{DenseMultilinearExtension, dense::CollectDenseMleWithZero};
10
11#[derive(Debug, Clone, Error)]
13pub enum ArithErrors {
14 #[error("Invalid parameters: {0}")]
15 InvalidParameters(String),
16}
17
18pub fn build_eq_x_r<F>(
25 r: &[F],
26 cfg: &F::Config,
27) -> Result<DenseMultilinearExtension<F>, ArithErrors>
28where
29 F: PrimeField,
30{
31 let evals = build_eq_x_r_vec(r, cfg)?;
32 let mle =
33 DenseMultilinearExtension::from_evaluations_vec(r.len(), evals, F::zero_with_cfg(cfg));
34
35 Ok(mle)
36}
37
38pub fn build_eq_x_r_vec<F>(r: &[F], cfg: &F::Config) -> Result<Vec<F>, ArithErrors>
46where
47 F: PrimeField,
48{
49 let mut eval = Vec::new();
61 build_eq_x_r_helper(r, &mut eval, cfg)?;
62
63 Ok(eval)
64}
65
66fn build_eq_x_r_helper<F>(r: &[F], buf: &mut Vec<F>, cfg: &F::Config) -> Result<(), ArithErrors>
70where
71 F: PrimeField,
72{
73 if r.is_empty() {
74 return Err(ArithErrors::InvalidParameters("r length is 0".into()));
75 } else if r.len() == 1 {
76 buf.push(F::one_with_cfg(cfg) - &r[0]);
78 buf.push(r[0].clone());
79 } else {
80 build_eq_x_r_helper(&r[1..], buf, cfg)?;
81
82 let mut res = vec![F::zero_with_cfg(cfg); buf.len() << 1];
88 cfg_iter_mut!(res).enumerate().for_each(|(i, val)| {
89 let bi = buf[i >> 1].clone();
90 let tmp = r[0].clone() * &bi;
91 if (i & 1) == 0 {
92 *val = bi - tmp;
93 } else {
94 *val = tmp;
95 }
96 });
97 *buf = res;
98 }
99
100 Ok(())
101}
102
103pub fn build_eq_x_r_inner<F>(
110 r: &[F],
111 cfg: &F::Config,
112) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors>
113where
114 F: PrimeField,
115 F::Inner: Zero,
116{
117 let evals = build_eq_x_r_inner_vec(r, cfg)?;
118 let mle = DenseMultilinearExtension {
119 num_vars: r.len(),
120 evaluations: evals,
121 };
122
123 Ok(mle)
124}
125
126fn build_eq_x_r_inner_vec<F>(r: &[F], cfg: &F::Config) -> Result<Vec<F::Inner>, ArithErrors>
134where
135 F: PrimeField,
136 F::Inner: Zero,
137{
138 let mut eval = Vec::new();
150 build_eq_x_r_inner_helper(r, &mut eval, cfg)?;
151
152 Ok(eval)
153}
154
155fn build_eq_x_r_inner_helper<F>(
159 r: &[F],
160 buf: &mut Vec<F::Inner>,
161 cfg: &F::Config,
162) -> Result<(), ArithErrors>
163where
164 F: PrimeField,
165 F::Inner: Zero,
166{
167 let one = F::one_with_cfg(cfg);
168 if r.is_empty() {
169 return Err(ArithErrors::InvalidParameters("r length is 0".into()));
170 } else if r.len() == 1 {
171 buf.push((one - &r[0]).into_inner());
173 buf.push(r[0].inner().clone());
174 } else {
175 build_eq_x_r_inner_helper(&r[1..], buf, cfg)?;
176
177 let mut res = vec![F::Inner::zero(); buf.len() << 1];
183 cfg_iter_mut!(res).enumerate().for_each(|(i, val)| {
184 let bi = F::new_unchecked_with_cfg(buf[i >> 1].clone(), cfg);
185 let tmp = r[0].clone() * &bi;
186 if (i & 1) == 0 {
187 *val = (bi - tmp).into_inner();
188 } else {
189 *val = tmp.into_inner();
190 }
191 });
192 *buf = res;
193 }
194
195 Ok(())
196}
197
198pub fn build_next_c_r_mle<F>(
208 r: &[F],
209 c: usize,
210 field_cfg: &F::Config,
211) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors>
212where
213 F: PrimeField,
214 F::Inner: Zero,
215{
216 let num_vars = r.len();
217 let n = 1 << num_vars;
218 assert!(c < n, "shift c={c} must be < domain size {n}");
219 let zero_inner = F::zero_with_cfg(field_cfg).into_inner();
220
221 let eq_r = build_eq_x_r_inner(r, field_cfg)?;
222 if c == 0 {
223 return Ok(eq_r);
224 }
225
226 let mut evaluations = Vec::with_capacity(n);
229 evaluations.resize(c, zero_inner);
230 evaluations.extend_from_slice(&eq_r.evaluations[..sub!(n, c)]);
231
232 Ok(DenseMultilinearExtension {
233 num_vars,
234 evaluations,
235 })
236}
237
238#[allow(clippy::arithmetic_side_effects)]
240pub fn eq_eval<R: Semiring>(x: &[R], y: &[R], one: R) -> Result<R, ArithErrors> {
241 if x.len() != y.len() {
242 return Err(ArithErrors::InvalidParameters(
243 "x and y have different length".to_string(),
244 ));
245 }
246
247 let mut res = one.clone();
248 for (xi, yi) in x.iter().zip(y.iter()) {
249 let xi_yi = xi.clone() * yi;
250 res *= xi_yi.clone() + xi_yi - xi - yi + one.clone();
251 }
252
253 Ok(res)
254}
255
256#[allow(clippy::arithmetic_side_effects)]
267pub fn mle_eval_with_eq_table<F: InnerTransparentField>(
268 evaluations: &[F::Inner],
269 eq_table: &[F],
270 cfg: &F::Config,
271) -> F {
272 let mut acc = F::zero_with_cfg(cfg);
273 assert_eq!(
274 evaluations.len(),
275 eq_table.len(),
276 "evaluations and eq_table must have the same length"
277 );
278 for (eval, eq_val) in evaluations.iter().zip(eq_table.iter()) {
279 let mut term = eq_val.clone();
280 term.mul_assign_by_inner(eval);
281 acc += &term;
282 }
283 acc
284}
285
286#[allow(clippy::arithmetic_side_effects)]
289pub fn next_mle_inner<F: Field>(
290 num_vars: u32,
291 zero: F,
292 one: F,
293) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors> {
294 if !num_vars.is_multiple_of(2) {
295 return Err(ArithErrors::InvalidParameters(
296 "num_vars must be even".to_string(),
297 ));
298 }
299
300 let mut mle = (0..1 << num_vars)
301 .map(|_| zero.inner().clone())
302 .collect_dense_mle_with_zero(zero.inner());
303
304 let half_vars = num_vars / 2;
305
306 for i in 0usize..(1 << half_vars) - 1 {
307 let next = i + 1;
308
309 let i_concat_next = (next << half_vars) | i;
310
311 mle[i_concat_next] = one.inner().clone();
312 }
313
314 Ok(mle)
315}
316
317#[allow(clippy::arithmetic_side_effects)]
338pub fn next_mle_eval<R: Semiring>(u: &[R], v: &[R], zero: R, one: R) -> R {
339 let n = u.len();
340 assert_eq!(n, v.len(), "u and v must have the same length");
341 if n == 0 {
342 return zero;
343 }
344
345 let mut suffix_eq = vec![one.clone(); n + 1];
347 for i in (0..n).rev() {
348 suffix_eq[i] = suffix_eq[i + 1].clone()
349 * (u[i].clone() * &v[i] + (one.clone() - &u[i]) * (one.clone() - &v[i]));
350 }
351
352 let mut prefix_carry = one.clone();
354 let mut result = zero;
355 for j in 0..n {
356 result += prefix_carry.clone() * (one.clone() - &u[j]) * &v[j] * &suffix_eq[j + 1];
357 prefix_carry *= u[j].clone() * (one.clone() - &v[j]);
358 }
359 result
360}
361
#[cfg(test)]
#[allow(clippy::arithmetic_side_effects, clippy::cast_possible_truncation)]
mod tests {
    use crypto_bigint::{U128, const_monty_params};
    use crypto_primitives::{IntoWithConfig, crypto_bigint_const_monty::ConstMontyField};
    use num_traits::One;
    use proptest::{prelude::*, proptest};

    use crate::mle::MultilinearExtensionWithConfig;

    use super::*;

    const_monty_params!(Params, U128, "00000000b933426489189cb5b47d567f");

    type F = ConstMontyField<Params, { U128::LIMBS }>;

    /// Total number of variables of the `next` table used by the tests.
    const NUM_VARS: u32 = 8;

    /// Number of variables in each half (`u` and `v`) of a point.
    const HALF_VARS: u32 = NUM_VARS / 2;

    /// Number of indices `i` whose successor `i + 1` still fits in
    /// `HALF_VARS` bits, i.e. `2^HALF_VARS - 1`.
    const NUM_SUCCESSOR_PAIRS: u32 = (1 << HALF_VARS) - 1;

    /// Yields the `HALF_VARS` lowest bits of `value` as field elements,
    /// least significant bit first.
    fn bits_le(value: u32) -> impl Iterator<Item = F> {
        (0..HALF_VARS).map(move |j| {
            if value & (1 << j) == 0 {
                F::zero()
            } else {
                F::one()
            }
        })
    }

    /// Boolean point made of the bits of `i` followed by the bits of `i + 1`.
    fn successor_point(i: u32) -> Vec<F> {
        bits_le(i).chain(bits_le(i + 1)).collect()
    }

    #[test]
    fn next_mle_is_one_on_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        // Cover every `i` that has a successor, not just the lower half.
        for i in 0..NUM_SUCCESSOR_PAIRS {
            let point = successor_point(i);

            assert_eq!(
                next_mle.clone().evaluate_with_config(&point, &()),
                Ok(F::one()),
                "i={i}"
            );
        }
    }

    #[test]
    fn next_mle_is_one_only_on_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        assert_eq!(
            next_mle.evaluations.iter().filter(|x| !x.is_zero()).count(),
            NUM_SUCCESSOR_PAIRS as usize
        );
    }

    fn any_f(cfg: <F as PrimeField>::Config) -> impl Strategy<Value = F> + 'static {
        any::<u128>().prop_map(move |v| v.into_with_cfg(&cfg))
    }

    fn point_n(n: usize) -> impl Strategy<Value = Vec<F>> {
        prop::collection::vec(any_f(()), n)
    }

    #[test]
    fn next_mle_eval_coincides_with_next_mle_evaluated_at_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        // Cover every `i` that has a successor, not just the lower half.
        for i in 0..NUM_SUCCESSOR_PAIRS {
            let point = successor_point(i);

            let (u, v) = point.split_at(HALF_VARS as usize);
            assert_eq!(
                next_mle.clone().evaluate_with_config(&point, &()),
                Ok(next_mle_eval(u, v, F::zero(), F::one())),
                "i={i}"
            );
        }
    }

    proptest! {
        #[test]
        fn prop_next_mle_eval_coincides_with_next_mle_evaluate_at_point(r in point_n(NUM_VARS as usize)) {
            let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

            let (u, v) = r.split_at(HALF_VARS as usize);
            prop_assert_eq!(
                next_mle.evaluate_with_config(&r, &()),
                Ok(next_mle_eval(u, v, F::zero(), F::one()))
            );
        }
    }

    #[test]
    fn next_c_r_mle_c1_matches_shift_by_1() {
        let num_vars: usize = 4;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 3) as u32)).collect();

        let next_1 = build_next_c_r_mle(&r, 1, &()).unwrap();

        let eq_r = build_eq_x_r_inner(&r, &()).unwrap();
        let n = 1 << num_vars;
        // One leading zero followed by the first `n - 1` entries of the eq table.
        let mut expected = vec![F::zero().into_inner(); 1];
        expected.extend_from_slice(&eq_r.evaluations[..n - 1]);

        assert_eq!(next_1.evaluations, expected);
    }

    #[test]
    fn next_c_r_mle_c0_is_eq() {
        let num_vars: usize = 4;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 7) as u32)).collect();

        let next_0 = build_next_c_r_mle(&r, 0, &()).unwrap();
        let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

        assert_eq!(next_0.evaluations, eq_r.evaluations);
    }

    #[test]
    fn next_c_r_mle_has_correct_structure() {
        let num_vars: usize = 4;
        let n = 1 << num_vars;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 5) as u32)).collect();

        for c in [2, 3, 5, 7] {
            let next_c = build_next_c_r_mle(&r, c, &()).unwrap();
            let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

            // The first `c` entries are zero ...
            for b in 0..c {
                assert!(
                    next_c.evaluations[b].is_zero(),
                    "c={c}, b={b}: expected zero"
                );
            }
            // ... and the rest is the eq table shifted up by `c`.
            for b in c..n {
                assert_eq!(
                    next_c.evaluations[b],
                    eq_r.evaluations[b - c],
                    "c={c}, b={b}: mismatch"
                );
            }
        }
    }

    proptest! {
        #[test]
        // With 4 variables the domain size is 16, so every non-zero shift up to
        // and including 15 is valid.
        fn prop_next_c_r_mle_evaluates_correctly(r in point_n(4), c in 1..16usize) {
            let next_c = build_next_c_r_mle(&r, c, &()).unwrap();
            let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

            let n = 1 << r.len();
            for b in 0..c.min(n) {
                prop_assert!(next_c.evaluations[b].is_zero());
            }
            for b in c..n {
                prop_assert_eq!(&next_c.evaluations[b], &eq_r.evaluations[b - c]);
            }
        }
    }
}