1mod try_collect_dense_mle;
2
3use core::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
4use num_traits::Zero;
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7use std::{
8 ops::{Deref, DerefMut},
9 slice::SliceIndex,
10};
11
12use crate::{
13 EvaluationError,
14 mle::{MultilinearExtension, MultilinearExtensionRand},
15};
16use crypto_primitives::{Matrix, PrimeField, Ring, Semiring};
17use rand::{distr::StandardUniform, prelude::*};
18use rand_core::RngCore;
19use zinc_utils::{
20 CHECKED, add, cfg_into_iter, inner_transparent_field::InnerTransparentField,
21 mul_by_scalar::MulByScalar, projectable_to_field::ProjectableToField, sub,
22};
23
24use super::MultilinearExtensionWithConfig;
25
26pub use try_collect_dense_mle::*;
27
/// A multilinear polynomial stored densely as its table of evaluations.
///
/// The constructors in this module pad `evaluations` so that it holds exactly
/// `2^num_vars` entries. Both fields are public, however, so that invariant is only
/// upheld by convention when the struct is built directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseMultilinearExtension<T> {
    /// Evaluation table. The first variable selects the least-significant bit of the
    /// index: variable fixing folds the pairs `(2b, 2b + 1)`.
    pub evaluations: Vec<T>,
    /// Number of variables of the polynomial.
    pub num_vars: usize,
}
35
36impl<R> DenseMultilinearExtension<R> {
37 pub fn zero_vars(evaluation: R) -> Self {
38 Self {
39 evaluations: vec![evaluation],
40 num_vars: 0,
41 }
42 }
43}
44
45impl<R: Clone> DenseMultilinearExtension<R> {
46 pub fn from_evaluations_slice(num_vars: usize, evaluations: &[R], zero: R) -> Self {
47 Self::from_evaluations_vec(num_vars, evaluations.to_vec(), zero)
48 }
49
50 pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<R>, zero: R) -> Self {
51 assert!(
53 evaluations.len() <= 1 << num_vars,
54 "The size of evaluations should not exceed 2^num_vars. \n eval len: {:?}. num vars: {num_vars}",
55 evaluations.len()
56 );
57
58 if evaluations.len() != 1 << num_vars {
59 let mut evaluations = evaluations;
60 evaluations.resize(1 << num_vars, zero);
61 return Self {
62 num_vars,
63 evaluations,
64 };
65 }
66
67 Self {
68 num_vars,
69 evaluations,
70 }
71 }
72
73 #[allow(clippy::arithmetic_side_effects)]
76 pub fn from_matrix<M: Matrix<R>>(matrix: &M, zero: R) -> Self {
77 let n_vars: usize = (zinc_utils::log2(matrix.num_rows()) + zinc_utils::log2(matrix.num_cols())) as usize;
79
80 let padded_rows = matrix.num_rows().next_power_of_two();
82 let padded_cols = matrix.num_cols().next_power_of_two();
83
84 let mut v = vec![zero.clone(); padded_rows * padded_cols];
86
87 for (row_i, row) in matrix.cells().enumerate() {
88 for (col_i, val) in row {
89 v[(padded_cols * row_i) + col_i] = val.clone();
90 }
91 }
92
93 Self::from_evaluations_slice(n_vars, &v, zero)
95 }
96}
97
98impl<R: Default> DenseMultilinearExtension<R> {
99 pub fn from_evaluations_vec_pad(mut evaluations: Vec<R>) -> Self {
100 let len = evaluations.len();
101
102 evaluations.resize_with(len.next_power_of_two(), Default::default);
103
104 let num_vars = zinc_utils::log2(evaluations.len()) as usize;
105
106 Self {
107 evaluations,
108 num_vars,
109 }
110 }
111}
112
113impl<R: Clone> DenseMultilinearExtension<R> {
114 pub fn from_evaluations_vec_pad_with_zero(mut evaluations: Vec<R>, zero: &R) -> Self {
115 let len = evaluations.len();
116
117 evaluations.resize(len.next_power_of_two(), zero.clone());
118
119 let num_vars = zinc_utils::log2(evaluations.len()) as usize;
120
121 Self {
122 evaluations,
123 num_vars,
124 }
125 }
126}
127
128impl<R: Send + Default> FromIterator<R> for DenseMultilinearExtension<R> {
130 fn from_iter<T: IntoIterator<Item = R>>(iter: T) -> Self {
131 Self::from_evaluations_vec_pad(iter.into_iter().collect())
132 }
133}
134
135impl<R> Deref for DenseMultilinearExtension<R> {
136 type Target = [R];
137
138 fn deref(&self) -> &Self::Target {
139 &self.evaluations
140 }
141}
142
143impl<R> DerefMut for DenseMultilinearExtension<R> {
144 fn deref_mut(&mut self) -> &mut Self::Target {
145 &mut self.evaluations
146 }
147}
148
149impl<R> IntoIterator for DenseMultilinearExtension<R> {
150 type Item = R;
151
152 type IntoIter = std::vec::IntoIter<R>;
153
154 fn into_iter(self) -> Self::IntoIter {
155 self.evaluations.into_iter()
156 }
157}
158
#[cfg(feature = "parallel")]
impl<R: Send + Default> FromParallelIterator<R> for DenseMultilinearExtension<R> {
    /// Collects the items, preserving order, then pads with `R::default()` to a
    /// power-of-two length.
    fn from_par_iter<I>(par_iter: I) -> Self
    where
        I: IntoParallelIterator<Item = R>,
    {
        let collected: Vec<R> = par_iter.into_par_iter().collect();
        Self::from_evaluations_vec_pad(collected)
    }
}
168
#[cfg(feature = "parallel")]
impl<R: Send + Sync> IntoParallelIterator for DenseMultilinearExtension<R> {
    type Iter = rayon::vec::IntoIter<R>;

    type Item = R;

    /// Consumes the MLE into a parallel iterator over its evaluations.
    fn into_par_iter(self) -> Self::Iter {
        let Self { evaluations, .. } = self;
        evaluations.into_par_iter()
    }
}
179
#[cfg(feature = "parallel")]
impl<'data, R: Send + Sync> IntoParallelRefIterator<'data> for &'data DenseMultilinearExtension<R> {
    type Iter = rayon::slice::Iter<'data, R>;

    type Item = &'data R;

    /// Borrows the evaluation table as a parallel iterator of shared references.
    fn par_iter(&'data self) -> Self::Iter {
        self.evaluations.as_slice().par_iter()
    }
}
190
#[cfg(feature = "parallel")]
impl<'data, R: Send + Sync> IntoParallelRefMutIterator<'data>
    for &'data mut DenseMultilinearExtension<R>
{
    type Iter = rayon::slice::IterMut<'data, R>;

    type Item = &'data mut R;

    /// Borrows the evaluation table as a parallel iterator of mutable references.
    fn par_iter_mut(&'data mut self) -> Self::Iter {
        self.evaluations.as_mut_slice().par_iter_mut()
    }
}
203
204impl<R: Semiring> DenseMultilinearExtension<R> {
205 pub fn evaluate<S>(&self, point: &[S], zero: R) -> Result<R, EvaluationError>
206 where
207 R: for<'a> MulByScalar<&'a S>,
208 {
209 if point.len() == self.num_vars {
210 Ok(self
211 .fixed_variables(point, zero)
212 .into_iter()
213 .next()
214 .expect("Evaluations should not be empty"))
215 } else {
216 Err(EvaluationError::WrongPointWidth {
217 expected: self.num_vars,
218 actual: point.len(),
219 })
220 }
221 }
222
223 fn unary<G>(&mut self, f: G)
224 where
225 G: FnMut(&mut R),
226 {
227 self.iter_mut().for_each(f);
228 }
229
230 fn binary<G>(&mut self, other: &Self, mut f: G)
231 where
232 G: FnMut(&mut R, &R),
233 {
234 self.iter_mut().zip(other.iter()).for_each(|(a, b)| f(a, b));
235 }
236}
237
238impl<F> MultilinearExtensionWithConfig<F> for DenseMultilinearExtension<F::Inner>
239where
240 F: InnerTransparentField,
241{
242 #[allow(clippy::arithmetic_side_effects)]
243 fn fix_variables_with_config(
244 &mut self,
245 partial_point: &[F],
246 config: &<F as PrimeField>::Config,
247 ) {
248 assert!(
249 partial_point.len() <= self.num_vars,
250 "too many partial points"
251 );
252
253 if partial_point.len().is_zero() {
254 return;
255 }
256
257 let nv = self.num_vars;
258 let dim = partial_point.len();
259
260 let mut r = partial_point[0].clone();
261 for i in 1..dim + 1 {
262 for b in 0..1 << (nv - i) {
263 *r.inner_mut() = partial_point[i - 1].inner().clone();
264 if self[2 * b + 1] != self[2 * b] {
265 let a = F::sub_inner(&self[2 * b + 1], &self[2 * b], config);
267
268 r.mul_assign_by_inner(&a);
270 self[b] = F::add_inner(&self[2 * b], r.inner(), config);
271 } else {
272 self[b] = self[2 * b].clone();
273 };
274 }
275 }
276
277 self.evaluations.truncate(1 << (nv - dim));
278 self.num_vars = sub!(nv, dim);
279 }
280
281 fn fixed_variables_with_config(
282 &self,
283 partial_point: &[F],
284 config: &<F as PrimeField>::Config,
285 ) -> Self {
286 let mut res = self.clone();
287 res.fix_variables_with_config(partial_point, config);
288 res
289 }
290
291 fn evaluate_with_config(
292 mut self,
293 point: &[F],
294 config: &<F as PrimeField>::Config,
295 ) -> Result<F, EvaluationError> {
296 if point.len() == self.num_vars {
297 self.fix_variables_with_config(point, config);
298 Ok(F::new_unchecked_with_cfg(
299 self.into_iter()
300 .next()
301 .expect("Evaluations should not be empty"),
302 config,
303 ))
304 } else {
305 Err(EvaluationError::WrongPointWidth {
306 expected: point.len(),
307 actual: self.num_vars,
308 })
309 }
310 }
311}
312
/// Config-free variable fixing for dense MLEs whose entries live in a semiring `R`.
///
/// The challenges have type `S`, and `R` must support checked multiplication by `&S`.
impl<R> MultilinearExtension<R> for DenseMultilinearExtension<R>
where
    R: Semiring,
{
    /// Fixes the first `partial_point.len()` variables of `self` in place.
    ///
    /// Round `i` binds variable `i - 1` to `partial_point[i - 1]`. It folds each pair
    /// `(self[2b], self[2b + 1])`, which differ only in the least-significant index bit,
    /// into `self[b] = left + (right - left) * r`. The table is then truncated to
    /// `2^(num_vars - partial_point.len())` entries.
    ///
    /// `zero` is only used to skip the multiplication when `right - left == zero`.
    ///
    /// # Panics
    /// - "too many partial points" if `partial_point.len() > self.num_vars`;
    /// - "Multiplication overflow" if the checked scalar multiplication fails;
    /// - presumably also on overflow inside the `sub!`/`add!` helper macros
    ///   (NOTE(review): assumed to be checked arithmetic from `zinc_utils` — confirm).
    #[allow(clippy::arithmetic_side_effects)]
    fn fix_variables<S>(&mut self, partial_point: &[S], zero: R)
    where
        R: for<'a> MulByScalar<&'a S>,
    {
        assert!(
            partial_point.len() <= self.num_vars,
            "too many partial points"
        );

        let nv = self.num_vars;
        let dim = partial_point.len();

        for i in 1..dim + 1 {
            let r = &partial_point[i - 1];
            // After round `i - 1` only the first `2^(nv - i + 1)` entries are live; this
            // round halves that. Writing `self[b]` is safe because `b <= 2b` and the slots
            // `2b`, `2b + 1` for this `b` are read before the write.
            for b in 0..1 << (nv - i) {
                let left = &self[2 * b];
                let right = &self[2 * b + 1];
                let a = sub!(*right, left);
                if a != zero {
                    let ar = a
                        .mul_by_scalar::<CHECKED>(r)
                        .expect("Multiplication overflow");
                    self[b] = add!(*left, ar);
                } else {
                    // Equal neighbours: the interpolant is constant in this variable.
                    self[b] = left.clone();
                };
            }
        }

        // Drop the now-dead upper part of the table.
        self.evaluations.truncate(1 << (nv - dim));
        self.num_vars = sub!(nv, dim);
    }

    /// Returns a copy of `self` with the first `partial_point.len()` variables fixed.
    fn fixed_variables<S>(&self, partial_point: &[S], zero: R) -> Self
    where
        R: for<'a> MulByScalar<&'a S>,
    {
        let mut res = self.clone();
        res.fix_variables(partial_point, zero);
        res
    }
}
362
363impl<R> MultilinearExtensionRand<R> for DenseMultilinearExtension<R>
364where
365 R: Send + Clone + Default,
366 StandardUniform: Distribution<R>,
367{
368 fn rand<Rng: RngCore + ?Sized>(num_vars: usize, rng: &mut Rng) -> Self {
369 (0..1 << num_vars).map(|_| rng.random::<R>()).collect()
370 }
371}
372
373impl<T, I: SliceIndex<[T]>> Index<I> for DenseMultilinearExtension<T> {
374 type Output = I::Output;
375
376 fn index(&self, index: I) -> &Self::Output {
377 &self.evaluations[index]
378 }
379}
380
381impl<T, I: SliceIndex<[T]>> IndexMut<I> for DenseMultilinearExtension<T> {
382 fn index_mut(&mut self, index: I) -> &mut Self::Output {
383 &mut self.evaluations[index]
384 }
385}
386
387impl<R: Ring> Neg for DenseMultilinearExtension<R> {
388 type Output = Self;
389
390 fn neg(mut self) -> Self::Output {
391 self.unary(|v| *v = v.checked_neg().expect("Negation overflow"));
392 self
393 }
394}
395
396impl<R: Semiring> Add for DenseMultilinearExtension<R> {
397 type Output = Self;
398
399 #[allow(clippy::arithmetic_side_effects)]
400 fn add(self, rhs: Self) -> Self::Output {
401 self + &rhs
402 }
403}
404
405impl<R: Semiring> Add<&Self> for DenseMultilinearExtension<R> {
406 type Output = Self;
407
408 #[allow(clippy::arithmetic_side_effects)]
409 fn add(mut self, rhs: &Self) -> Self::Output {
410 self.binary(rhs, |a, b| *a += b);
411 self
412 }
413}
414
415impl<R: Semiring> Sub<&Self> for DenseMultilinearExtension<R> {
416 type Output = Self;
417
418 #[allow(clippy::arithmetic_side_effects)]
419 fn sub(mut self, rhs: &Self) -> Self::Output {
420 self.binary(rhs, |a, b| *a -= b);
421 self
422 }
423}
424
425impl<R: Semiring> Mul<&Self> for DenseMultilinearExtension<R> {
426 type Output = Self;
427
428 #[allow(clippy::arithmetic_side_effects)]
429 fn mul(mut self, rhs: &Self) -> Self::Output {
430 self.binary(rhs, |a, b| *a *= b);
431 self
432 }
433}
434
435impl<R: Semiring> Mul<R> for DenseMultilinearExtension<R> {
436 type Output = Self;
437
438 #[allow(clippy::arithmetic_side_effects)]
439 fn mul(mut self, rhs: R) -> Self::Output {
440 self.unary(|v| *v *= &rhs);
441 self
442 }
443}
444
445impl<R: Semiring> AddAssign<&Self> for DenseMultilinearExtension<R> {
446 #[allow(clippy::arithmetic_side_effects)]
447 fn add_assign(&mut self, rhs: &Self) {
448 self.binary(rhs, |a, b| *a += b);
449 }
450}
451
452impl<R: Semiring> SubAssign<&Self> for DenseMultilinearExtension<R> {
453 #[allow(clippy::arithmetic_side_effects)]
454 fn sub_assign(&mut self, rhs: &Self) {
455 self.binary(rhs, |a, b| *a -= b);
456 }
457}
458
459impl<R: Semiring> MulAssign<&Self> for DenseMultilinearExtension<R> {
460 #[allow(clippy::arithmetic_side_effects)]
461 fn mul_assign(&mut self, rhs: &Self) {
462 self.binary(rhs, |a, b| *a *= b);
463 }
464}
465
466impl<R: Semiring> AddAssign<(R, &Self)> for DenseMultilinearExtension<R> {
467 #[allow(clippy::arithmetic_side_effects)]
468 fn add_assign(&mut self, rhs: (R, &Self)) {
469 let coeff = rhs.0;
470 self.binary(rhs.1, |a, b| *a += b.clone() * &coeff);
471 }
472}
473
474pub fn project_coeffs<F: PrimeField, R: ProjectableToField<F> + Send + Sync>(
475 mle: DenseMultilinearExtension<R>,
476 sampled_value: &F,
477) -> DenseMultilinearExtension<F::Inner> {
478 let projection = R::prepare_projection(sampled_value);
479
480 DenseMultilinearExtension {
481 evaluations: cfg_into_iter!(mle.evaluations)
482 .map(|x| projection(&x).into_inner())
483 .collect(),
484 num_vars: mle.num_vars,
485 }
486}
487
/// Extension trait for collecting an iterator directly into a [`DenseMultilinearExtension`].
///
/// Padding uses an explicit, caller-supplied `zero` value.
pub trait CollectDenseMleWithZero: Iterator {
    /// Collects all items into an MLE, padding with `zero` as needed.
    ///
    /// The blanket implementation in this module pads with clones of `zero` up to the
    /// next power-of-two length.
    fn collect_dense_mle_with_zero(
        self,
        zero: &Self::Item,
    ) -> DenseMultilinearExtension<Self::Item>;
}
494
495impl<T> CollectDenseMleWithZero for T
496where
497 T: Iterator,
498 T::Item: Clone,
499{
500 fn collect_dense_mle_with_zero(
501 self,
502 zero: &Self::Item,
503 ) -> DenseMultilinearExtension<Self::Item> {
504 let evaluations = self.collect();
505
506 DenseMultilinearExtension::from_evaluations_vec_pad_with_zero(evaluations, zero)
507 }
508}
509
#[cfg(test)]
#[allow(
    clippy::arithmetic_side_effects,
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss
)]
mod tests {
    // Unit and property tests for `DenseMultilinearExtension`, run over
    // `MontyField<LIMBS>` configured with the fixed test modulus `MODULUS`.
    use crate::utils::{build_eq_x_r, build_eq_x_r_vec};

    use super::*;

    use crypto_primitives::{
        DenseRowMatrix, IntoWithConfig, PrimeField, crypto_bigint_monty::MontyField,
        crypto_bigint_uint::Uint,
    };
    use proptest::prelude::*;

    // Number of limbs of the big-integer type backing the test field.
    const LIMBS: usize = 4;

    // Parses `hex_modulus` (base 16) and builds the matching field config.
    // Panics if the string is not valid hex or the config cannot be created.
    fn get_dyn_config(hex_modulus: &str) -> <MontyField<LIMBS> as PrimeField>::Config {
        let modulus = Uint::new(
            crypto_bigint::Uint::from_str_radix_vartime(hex_modulus, 16)
                .expect("Invalid modulus hex string"),
        );
        MontyField::make_cfg(&modulus).expect("Failed to create field config")
    }

    // Modulus shared by every test in this module, and the field type built from it.
    const MODULUS: &str = "0076F668F4274572E39A3EA8285319B5";
    type F = MontyField<LIMBS>;

    // Strategy: an arbitrary `u128` converted into `F` under `cfg`.
    fn any_f(cfg: <F as PrimeField>::Config) -> impl Strategy<Value = F> + 'static {
        any::<u128>().prop_map(move |v| v.into_with_cfg(&cfg))
    }

    // Strategy: a dense MLE with `n` in 0..=5 variables and exactly `2^n` random entries.
    fn any_dme() -> impl Strategy<Value = DenseMultilinearExtension<F>> {
        let cfg = get_dyn_config(MODULUS);
        (0usize..=5).prop_flat_map(move |n| {
            let len = 1usize << n;
            let cfg = cfg;
            prop::collection::vec(any_f(cfg), len).prop_map(move |evals| {
                DenseMultilinearExtension::from_evaluations_vec(n, evals, F::zero_with_cfg(&cfg))
            })
        })
    }

    // One variable: the eq table is [1 - r, r].
    #[test]
    fn test_build_eq_x_r_vec_basic() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [3_u64.into_with_cfg(&cfg)];
        let evals = build_eq_x_r_vec(&r, &cfg).unwrap();
        assert_eq!(
            evals,
            vec![F::one_with_cfg(&cfg) - r[0].clone(), r[0].clone()]
        );
    }

    // Two variables: entries are products of (1 - r_i) or r_i, with r[0] on the low
    // index bit (index 1 = r0 * (1 - r1)).
    #[test]
    fn test_build_eq_x_r_vec_two_vars() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [2u64.into_with_cfg(&cfg), 5u64.into_with_cfg(&cfg)];
        let evals = build_eq_x_r_vec(&r, &cfg).unwrap();
        let e00 = (F::one_with_cfg(&cfg) - r[0].clone()) * (F::one_with_cfg(&cfg) - r[1].clone());
        let e01 = r[0].clone() * (F::one_with_cfg(&cfg) - r[1].clone());
        let e10 = (F::one_with_cfg(&cfg) - r[0].clone()) * r[1].clone();
        let e11 = r[0].clone() * r[1].clone();
        assert_eq!(evals, vec![e00, e01, e10, e11]);
    }

    // An empty point is rejected with an error whose message mentions "Invalid parameters".
    #[test]
    fn test_build_eq_x_r_error_on_empty() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; 0] = [];
        let err = build_eq_x_r_vec(&r, &cfg).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Invalid parameters"));
    }

    // `build_eq_x_r` wraps the same table as `build_eq_x_r_vec`, with `num_vars == r.len()`.
    #[test]
    fn test_build_eq_x_r_mle_properties() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [
            7u64.into_with_cfg(&cfg),
            11u64.into_with_cfg(&cfg),
            13u64.into_with_cfg(&cfg),
        ];
        let mle = build_eq_x_r(&r, &cfg).unwrap();
        assert_eq!(mle.num_vars, r.len());
        let evals = mle.evaluations;
        let direct = build_eq_x_r_vec(&r, &cfg).unwrap();
        assert_eq!(evals, direct);
    }

    // A short slice is zero-padded to 2^n_vars; Index/IndexMut read and write entries.
    #[test]
    fn test_dense_from_slice_and_indexing() {
        let cfg = get_dyn_config(MODULUS);
        let n_vars = 3usize;
        let v = vec![
            1u64.into_with_cfg(&cfg),
            2u64.into_with_cfg(&cfg),
            3u64.into_with_cfg(&cfg),
        ];
        let dense =
            DenseMultilinearExtension::from_evaluations_slice(n_vars, &v, F::zero_with_cfg(&cfg));
        assert_eq!(dense.num_vars, n_vars);
        let mut expected = v.clone();
        expected.resize(1 << n_vars, F::zero_with_cfg(&cfg));
        assert_eq!(dense.evaluations, expected);
        assert_eq!(dense[0], 1u64.into_with_cfg(&cfg));
        let mut d2 = dense.clone();
        d2[1] = 99u64.into_with_cfg(&cfg);
        assert_eq!(d2[1], 99u64.into_with_cfg(&cfg));
    }

    // Evaluating at the hypercube corners recovers the table (x0 is the low index bit),
    // and fixing x0 = 1 keeps the odd-indexed entries.
    #[test]
    fn test_fix_variables_and_evaluate() {
        let cfg = get_dyn_config(MODULUS);
        let evals = vec![
            10u64.into_with_cfg(&cfg),
            20u64.into_with_cfg(&cfg),
            30u64.into_with_cfg(&cfg),
            40u64.into_with_cfg(&cfg),
        ];
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            2,
            evals.clone(),
            F::zero_with_cfg(&cfg),
        );
        for (idx, (x0, x1)) in [
            (F::zero_with_cfg(&cfg), F::zero_with_cfg(&cfg)),
            (F::one_with_cfg(&cfg), F::zero_with_cfg(&cfg)),
            (F::zero_with_cfg(&cfg), F::one_with_cfg(&cfg)),
            (F::one_with_cfg(&cfg), F::one_with_cfg(&cfg)),
        ]
        .iter()
        .enumerate()
        {
            let val = mle
                .evaluate(&[x0.clone(), x1.clone()], F::zero_with_cfg(&cfg))
                .unwrap();
            assert_eq!(val, evals[idx]);
        }
        let mut m2 = mle.clone();
        m2.fix_variables(&[F::one_with_cfg(&cfg)], F::zero_with_cfg(&cfg));
        assert_eq!(m2.num_vars, 1);
        assert_eq!(
            m2.evaluations,
            vec![20u64.into_with_cfg(&cfg), 40u64.into_with_cfg(&cfg)]
        );
    }

    // A 3x2 matrix is padded to 4x2 (2 + 1 = 3 variables) and laid out row-major, so
    // cell (2, 1) lands at index 2 * 2 + 1 = 5.
    #[test]
    fn test_from_matrix_padding_and_conversion() {
        let cfg = get_dyn_config(MODULUS);
        let m: DenseRowMatrix<F> = DenseRowMatrix::from(vec![
            vec![5u64.into_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            vec![F::zero_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            vec![F::zero_with_cfg(&cfg), 7u64.into_with_cfg(&cfg)],
        ]);
        let dense = DenseMultilinearExtension::from_matrix(&m, F::zero_with_cfg(&cfg));
        assert_eq!(dense.num_vars, 3);
        assert_eq!(dense[0], 5u64.into_with_cfg(&cfg));
        assert_eq!(dense[5], 7u64.into_with_cfg(&cfg));
        assert!(dense.iter().enumerate().all(|(i, v)| if i == 0 || i == 5 {
            true
        } else {
            F::is_zero(v)
        }));
    }

    // Vec and slice constructors pad identically when fewer than 2^n entries are given.
    #[test]
    fn test_from_evaluations_vec_padding_branch_and_slice() {
        let cfg = get_dyn_config(MODULUS);
        let evals = vec![1u64.into_with_cfg(&cfg), 2u64.into_with_cfg(&cfg)];
        let n = 2usize;
        let d1 = DenseMultilinearExtension::from_evaluations_vec(
            n,
            evals.clone(),
            F::zero_with_cfg(&cfg),
        );
        let mut expected = evals.clone();
        expected.resize(1 << n, F::zero_with_cfg(&cfg));
        assert_eq!(d1.evaluations, expected);
        let d2 =
            DenseMultilinearExtension::from_evaluations_slice(n, &evals, F::zero_with_cfg(&cfg));
        assert_eq!(d2.evaluations, expected);
    }

    // Fixing zero variables is a no-op; fixing every variable leaves a single entry.
    #[test]
    fn test_fix_variables_edge_cases_and_full_truncate() {
        let cfg = get_dyn_config(MODULUS);
        let d = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        let d_fixed = d.fixed_variables(&[], F::zero_with_cfg(&cfg));
        assert_eq!(d_fixed.num_vars, 2);
        assert_eq!(
            d_fixed.evaluations,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg)
            ]
        );
        let mut d2 = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                10.into_with_cfg(&cfg),
                20.into_with_cfg(&cfg),
                30.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        d2.fix_variables(
            &[F::one_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            F::zero_with_cfg(&cfg),
        );
        assert_eq!(d2.num_vars, 0);
        assert_eq!(d2.evaluations, vec![20.into_with_cfg(&cfg)]);
    }

    // Points that are too short or too long are rejected with an error, not a panic.
    #[test]
    fn test_evaluate_length_mismatch_returns_error() {
        let cfg = get_dyn_config(MODULUS);
        let d = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        assert!(
            d.evaluate(&[F::one_with_cfg(&cfg)], F::zero_with_cfg(&cfg))
                .is_err()
        );
        assert!(
            d.evaluate(
                &[
                    F::one_with_cfg(&cfg),
                    F::one_with_cfg(&cfg),
                    F::zero_with_cfg(&cfg)
                ],
                F::zero_with_cfg(&cfg)
            )
            .is_err()
        );
    }

    // `zero_vars` yields a 0-variable MLE holding exactly the given value.
    #[test]
    fn test_zero_impl_for_dense_mle() {
        let cfg = get_dyn_config(MODULUS);
        let z: DenseMultilinearExtension<F> =
            DenseMultilinearExtension::zero_vars(F::zero_with_cfg(&cfg));
        assert_eq!(z.num_vars, 0);
        assert_eq!(z.evaluations, vec![F::zero_with_cfg(&cfg)]);
    }

    // `+`, `-`, `*` act entry by entry; unary `-` negates every entry.
    #[test]
    fn test_arithmetic_ops_elementwise_add_sub_mul_and_neg() {
        let cfg = get_dyn_config(MODULUS);
        let a = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        let b = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                5.into_with_cfg(&cfg),
                6.into_with_cfg(&cfg),
                7.into_with_cfg(&cfg),
                8.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );

        let sum = a.clone() + &b;
        assert_eq!(
            sum.evaluations,
            vec![
                6.into_with_cfg(&cfg),
                8.into_with_cfg(&cfg),
                10.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg)
            ]
        );

        let diff = b.clone() - &a;
        assert_eq!(
            diff.evaluations,
            vec![
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg)
            ]
        );

        let prod = a.clone() * &b;
        assert_eq!(
            prod.evaluations,
            vec![
                5.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg),
                21.into_with_cfg(&cfg),
                32.into_with_cfg(&cfg)
            ]
        );

        let neg_a = -a.clone();
        let mut expected = vec![];
        for v in a.evaluations {
            expected.push(-v);
        }
        assert_eq!(neg_a.evaluations, expected);
    }

    // Scalar `*`, the `op=` forms, and the scaled accumulate `+= (coeff, &mle)`.
    #[test]
    fn test_scalar_mul_and_assign_variants() {
        let cfg = get_dyn_config(MODULUS);
        let a = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        let b = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                10.into_with_cfg(&cfg),
                20.into_with_cfg(&cfg),
                30.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );

        let three: F = 3u64.into_with_cfg(&cfg);
        let scaled = a.clone() * three;
        assert_eq!(
            scaled.evaluations,
            vec![
                3.into_with_cfg(&cfg),
                6.into_with_cfg(&cfg),
                9.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg)
            ]
        );

        let mut c = a.clone();
        c += &b;
        assert_eq!(
            c.evaluations,
            vec![
                11.into_with_cfg(&cfg),
                22.into_with_cfg(&cfg),
                33.into_with_cfg(&cfg),
                44.into_with_cfg(&cfg)
            ]
        );

        c -= &b;
        assert_eq!(c.evaluations, a.evaluations);

        let mut d = a.clone();
        d *= &b;
        assert_eq!(
            d.evaluations,
            vec![
                10.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
                90.into_with_cfg(&cfg),
                160.into_with_cfg(&cfg)
            ]
        );

        let mut e = a.clone();
        let two = 2u64.into_with_cfg(&cfg);
        e += (two, &b);
        assert_eq!(
            e.evaluations,
            vec![
                21.into_with_cfg(&cfg),
                42.into_with_cfg(&cfg),
                63.into_with_cfg(&cfg),
                84.into_with_cfg(&cfg)
            ]
        );
    }

    // Strategy: two MLEs sharing the same `n` (0..=5) variables, plus a random point of
    // width `n`.
    fn any_aligned_pair_with_point() -> impl Strategy<
        Value = (
            DenseMultilinearExtension<F>,
            DenseMultilinearExtension<F>,
            Vec<F>,
        ),
    > {
        let cfg = get_dyn_config(MODULUS);
        (0usize..=5).prop_flat_map(move |n| {
            let cfg = cfg;
            let len = 1usize << n;
            prop::collection::vec(any_f(cfg), len).prop_flat_map(move |e1| {
                let cfg = cfg;
                let n2 = n;
                prop::collection::vec(any_f(cfg), len).prop_flat_map(move |e2| {
                    let cfg = cfg;
                    let n3 = n2;
                    point_n(n3).prop_map({
                        let e1v = e1.clone();
                        let e2v = e2.clone();
                        move |r| {
                            (
                                DenseMultilinearExtension::from_evaluations_vec(
                                    n3,
                                    e1v.clone(),
                                    F::zero_with_cfg(&cfg),
                                ),
                                DenseMultilinearExtension::from_evaluations_vec(
                                    n3,
                                    e2v.clone(),
                                    F::zero_with_cfg(&cfg),
                                ),
                                r,
                            )
                        }
                    })
                })
            })
        })
    }

    // Strategy: `n` random field elements, i.e. a point with `n` coordinates.
    fn point_n(n: usize) -> impl Strategy<Value = Vec<F>> {
        prop::collection::vec(any_f(get_dyn_config(MODULUS)), n)
    }

    proptest! {
        // Evaluation is additive: (p1 + p2)(r) == p1(r) + p2(r).
        #[test]
        fn prop_eval_add_is_linear((p1, p2, r) in any_aligned_pair_with_point()) {
            let cfg = get_dyn_config(MODULUS);
            let lhs = (p1.clone() + &p2).evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p1.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap() + p2.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            prop_assert_eq!(lhs, rhs);
        }

        // Fixing the first k coordinates and then evaluating at the rest equals evaluating
        // at the full point.
        #[test]
        fn prop_fix_vars_commutes_with_eval((p, r, k) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let point = point_n(n);
            let ks = 0usize..=n;
            (Just(p), point, ks)
        })) {
            let cfg = get_dyn_config(MODULUS);
            let mut pfixed = p.clone();
            pfixed.fix_variables(&r[..k], F::zero_with_cfg(&cfg));
            let lhs = pfixed.evaluate(&r[k..], F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            prop_assert_eq!(lhs, rhs);
        }

        // NOTE(review): despite the name, this checks composition rather than idempotence.
        // Fixing a prefix r1 and then a prefix r2 must match fixing the concatenation
        // r1 ++ r2 in one call. The slices are clamped to the generated vector lengths,
        // and k2 is bounded by n - k1, so the total never exceeds `num_vars`.
        #[test]
        fn prop_fix_vars_is_idempotent((p, k1, k2) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let ks1 = 0usize..=n;
            (Just(p), ks1).prop_flat_map(move |(p, k1)| {
                let ks2 = 0usize..=n.saturating_sub(k1);
                (Just(p), Just(k1), ks2)
            })
        }), r1 in prop::collection::vec(any_f(get_dyn_config(MODULUS)), 0..=8usize), r2 in prop::collection::vec(any_f(get_dyn_config(MODULUS)), 0..=8usize)) {
            let cfg = get_dyn_config(MODULUS);
            let mut p_step = p.clone();
            p_step.fix_variables(&r1[..k1.min(r1.len())], F::zero_with_cfg(&cfg));
            p_step.fix_variables(&r2[..k2.min(r2.len())], F::zero_with_cfg(&cfg));

            let mut p_once = p.clone();
            let mut concat = r1[..k1.min(r1.len())].to_vec();
            concat.extend_from_slice(&r2[..k2.min(r2.len())]);
            p_once.fix_variables(&concat, F::zero_with_cfg(&cfg));

            prop_assert_eq!(p_step.evaluations, p_once.evaluations);
            prop_assert_eq!(p_step.num_vars, p_once.num_vars);
        }

        // The generic `evaluate` over `F` agrees with `evaluate_with_config` over the
        // inner representation of the same table.
        #[test]
        fn prop_mle_eval_eq_eval_with_config((p, r) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let point = point_n(n);
            (Just(p), point)
        })) {
            let cfg = get_dyn_config(MODULUS);

            let p_inner = DenseMultilinearExtension {
                num_vars: p.num_vars,
                evaluations: p.evaluations.iter().map(|x| *x.inner()).collect()
            };

            let lhs = p.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p_inner.evaluate_with_config(&r, &cfg).unwrap();
            prop_assert_eq!(lhs, rhs);
        }
    }
}