Skip to main content

zinc_poly/mle/
dense.rs

1mod try_collect_dense_mle;
2
3use core::ops::{Add, AddAssign, Index, IndexMut, Mul, MulAssign, Neg, Sub, SubAssign};
4use num_traits::Zero;
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7use std::{
8    ops::{Deref, DerefMut},
9    slice::SliceIndex,
10};
11
12use crate::{
13    EvaluationError,
14    mle::{MultilinearExtension, MultilinearExtensionRand},
15};
16use crypto_primitives::{Matrix, PrimeField, Ring, Semiring};
17use rand::{distr::StandardUniform, prelude::*};
18use rand_core::RngCore;
19use zinc_utils::{
20    CHECKED, add, cfg_into_iter, inner_transparent_field::InnerTransparentField,
21    mul_by_scalar::MulByScalar, projectable_to_field::ProjectableToField, sub,
22};
23
24use super::MultilinearExtensionWithConfig;
25
26pub use try_collect_dense_mle::*;
27
/// A multilinear extension stored densely as its table of evaluations over
/// the boolean hypercube.
///
/// The variable-fixing routines in this file pair entries `2 * b` and
/// `2 * b + 1` as the `0`/`1` values of the first variable, i.e. the first
/// variable corresponds to the least significant bit of the table index.
///
/// The constructors in this file produce tables of length `2^num_vars`; both
/// fields are public, so that relation is not enforced by the type itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseMultilinearExtension<T> {
    /// The evaluation over {0,1}^`num_vars`
    pub evaluations: Vec<T>,
    /// Number of variables
    pub num_vars: usize,
}
35
36impl<R> DenseMultilinearExtension<R> {
37    pub fn zero_vars(evaluation: R) -> Self {
38        Self {
39            evaluations: vec![evaluation],
40            num_vars: 0,
41        }
42    }
43}
44
45impl<R: Clone> DenseMultilinearExtension<R> {
46    pub fn from_evaluations_slice(num_vars: usize, evaluations: &[R], zero: R) -> Self {
47        Self::from_evaluations_vec(num_vars, evaluations.to_vec(), zero)
48    }
49
50    pub fn from_evaluations_vec(num_vars: usize, evaluations: Vec<R>, zero: R) -> Self {
51        // assert that the number of variables matches the size of evaluations
52        assert!(
53            evaluations.len() <= 1 << num_vars,
54            "The size of evaluations should not exceed 2^num_vars. \n eval len: {:?}. num vars: {num_vars}",
55            evaluations.len()
56        );
57
58        if evaluations.len() != 1 << num_vars {
59            let mut evaluations = evaluations;
60            evaluations.resize(1 << num_vars, zero);
61            return Self {
62                num_vars,
63                evaluations,
64            };
65        }
66
67        Self {
68            num_vars,
69            evaluations,
70        }
71    }
72
73    /// Returns the dense MLE from the given matrix, without modifying the
74    /// original matrix.
75    #[allow(clippy::arithmetic_side_effects)]
76    pub fn from_matrix<M: Matrix<R>>(matrix: &M, zero: R) -> Self {
77        let n_vars: usize = // n_vars = s + s'
78            (zinc_utils::log2(matrix.num_rows()) + zinc_utils::log2(matrix.num_cols())) as usize;
79
80        // Matrices might need to get padded before turned into an MLE
81        let padded_rows = matrix.num_rows().next_power_of_two();
82        let padded_cols = matrix.num_cols().next_power_of_two();
83
84        // build dense vector representing the sparse padded matrix
85        let mut v = vec![zero.clone(); padded_rows * padded_cols];
86
87        for (row_i, row) in matrix.cells().enumerate() {
88            for (col_i, val) in row {
89                v[(padded_cols * row_i) + col_i] = val.clone();
90            }
91        }
92
93        // convert the dense vector into a mle
94        Self::from_evaluations_slice(n_vars, &v, zero)
95    }
96}
97
98impl<R: Default> DenseMultilinearExtension<R> {
99    pub fn from_evaluations_vec_pad(mut evaluations: Vec<R>) -> Self {
100        let len = evaluations.len();
101
102        evaluations.resize_with(len.next_power_of_two(), Default::default);
103
104        let num_vars = zinc_utils::log2(evaluations.len()) as usize;
105
106        Self {
107            evaluations,
108            num_vars,
109        }
110    }
111}
112
113impl<R: Clone> DenseMultilinearExtension<R> {
114    pub fn from_evaluations_vec_pad_with_zero(mut evaluations: Vec<R>, zero: &R) -> Self {
115        let len = evaluations.len();
116
117        evaluations.resize(len.next_power_of_two(), zero.clone());
118
119        let num_vars = zinc_utils::log2(evaluations.len()) as usize;
120
121        Self {
122            evaluations,
123            num_vars,
124        }
125    }
126}
127
128// Keeping Send bound here to match the FromParallelIterator impl
129impl<R: Send + Default> FromIterator<R> for DenseMultilinearExtension<R> {
130    fn from_iter<T: IntoIterator<Item = R>>(iter: T) -> Self {
131        Self::from_evaluations_vec_pad(iter.into_iter().collect())
132    }
133}
134
135impl<R> Deref for DenseMultilinearExtension<R> {
136    type Target = [R];
137
138    fn deref(&self) -> &Self::Target {
139        &self.evaluations
140    }
141}
142
143impl<R> DerefMut for DenseMultilinearExtension<R> {
144    fn deref_mut(&mut self) -> &mut Self::Target {
145        &mut self.evaluations
146    }
147}
148
149impl<R> IntoIterator for DenseMultilinearExtension<R> {
150    type Item = R;
151
152    type IntoIter = std::vec::IntoIter<R>;
153
154    fn into_iter(self) -> Self::IntoIter {
155        self.evaluations.into_iter()
156    }
157}
158
/// Parallel counterpart of the `FromIterator` impl: collects the items and
/// pads the table with `R::default()` to a power-of-two length.
#[cfg(feature = "parallel")]
impl<R: Send + Default> FromParallelIterator<R> for DenseMultilinearExtension<R> {
    fn from_par_iter<I>(par_iter: I) -> Self
    where
        I: IntoParallelIterator<Item = R>,
    {
        let evaluations: Vec<R> = par_iter.into_par_iter().collect();
        Self::from_evaluations_vec_pad(evaluations)
    }
}
168
/// Consumes the extension and iterates over its evaluations in parallel.
#[cfg(feature = "parallel")]
impl<R: Send + Sync> IntoParallelIterator for DenseMultilinearExtension<R> {
    type Iter = rayon::vec::IntoIter<R>;

    type Item = R;

    fn into_par_iter(self) -> Self::Iter {
        let Self { evaluations, .. } = self;
        evaluations.into_par_iter()
    }
}
179
/// Parallel iteration over shared references to the evaluations.
#[cfg(feature = "parallel")]
impl<'data, R: Send + Sync> IntoParallelRefIterator<'data> for &'data DenseMultilinearExtension<R> {
    type Iter = rayon::slice::Iter<'data, R>;

    type Item = &'data R;

    fn par_iter(&'data self) -> Self::Iter {
        let evaluations = &self.evaluations;
        evaluations.par_iter()
    }
}
190
/// Parallel iteration over mutable references to the evaluations.
#[cfg(feature = "parallel")]
impl<'data, R: Send + Sync> IntoParallelRefMutIterator<'data>
    for &'data mut DenseMultilinearExtension<R>
{
    type Iter = rayon::slice::IterMut<'data, R>;

    type Item = &'data mut R;

    fn par_iter_mut(&'data mut self) -> Self::Iter {
        let evaluations = &mut self.evaluations;
        evaluations.par_iter_mut()
    }
}
203
204impl<R: Semiring> DenseMultilinearExtension<R> {
205    pub fn evaluate<S>(&self, point: &[S], zero: R) -> Result<R, EvaluationError>
206    where
207        R: for<'a> MulByScalar<&'a S>,
208    {
209        if point.len() == self.num_vars {
210            Ok(self
211                .fixed_variables(point, zero)
212                .into_iter()
213                .next()
214                .expect("Evaluations should not be empty"))
215        } else {
216            Err(EvaluationError::WrongPointWidth {
217                expected: self.num_vars,
218                actual: point.len(),
219            })
220        }
221    }
222
223    fn unary<G>(&mut self, f: G)
224    where
225        G: FnMut(&mut R),
226    {
227        self.iter_mut().for_each(f);
228    }
229
230    fn binary<G>(&mut self, other: &Self, mut f: G)
231    where
232        G: FnMut(&mut R, &R),
233    {
234        self.iter_mut().zip(other.iter()).for_each(|(a, b)| f(a, b));
235    }
236}
237
238impl<F> MultilinearExtensionWithConfig<F> for DenseMultilinearExtension<F::Inner>
239where
240    F: InnerTransparentField,
241{
242    #[allow(clippy::arithmetic_side_effects)]
243    fn fix_variables_with_config(
244        &mut self,
245        partial_point: &[F],
246        config: &<F as PrimeField>::Config,
247    ) {
248        assert!(
249            partial_point.len() <= self.num_vars,
250            "too many partial points"
251        );
252
253        if partial_point.len().is_zero() {
254            return;
255        }
256
257        let nv = self.num_vars;
258        let dim = partial_point.len();
259
260        let mut r = partial_point[0].clone();
261        for i in 1..dim + 1 {
262            for b in 0..1 << (nv - i) {
263                *r.inner_mut() = partial_point[i - 1].inner().clone();
264                if self[2 * b + 1] != self[2 * b] {
265                    // a = f(1) - f(0)
266                    let a = F::sub_inner(&self[2 * b + 1], &self[2 * b], config);
267
268                    // self[b] = f(0) + r * a
269                    r.mul_assign_by_inner(&a);
270                    self[b] = F::add_inner(&self[2 * b], r.inner(), config);
271                } else {
272                    self[b] = self[2 * b].clone();
273                };
274            }
275        }
276
277        self.evaluations.truncate(1 << (nv - dim));
278        self.num_vars = sub!(nv, dim);
279    }
280
281    fn fixed_variables_with_config(
282        &self,
283        partial_point: &[F],
284        config: &<F as PrimeField>::Config,
285    ) -> Self {
286        let mut res = self.clone();
287        res.fix_variables_with_config(partial_point, config);
288        res
289    }
290
291    fn evaluate_with_config(
292        mut self,
293        point: &[F],
294        config: &<F as PrimeField>::Config,
295    ) -> Result<F, EvaluationError> {
296        if point.len() == self.num_vars {
297            self.fix_variables_with_config(point, config);
298            Ok(F::new_unchecked_with_cfg(
299                self.into_iter()
300                    .next()
301                    .expect("Evaluations should not be empty"),
302                config,
303            ))
304        } else {
305            Err(EvaluationError::WrongPointWidth {
306                expected: point.len(),
307                actual: self.num_vars,
308            })
309        }
310    }
311}
312
impl<R> MultilinearExtension<R> for DenseMultilinearExtension<R>
where
    R: Semiring,
{
    /// Fixes the first `partial_point.len()` variables in place.
    ///
    /// Afterwards the table has `2^(num_vars - partial_point.len())` entries
    /// and `num_vars` is reduced accordingly. `zero` is only used to detect a
    /// vanishing difference so that the scalar multiplication can be skipped.
    ///
    /// # Panics
    ///
    /// Panics if `partial_point` has more coordinates than there are
    /// variables, or if the checked scalar multiplication reports an overflow.
    #[allow(clippy::arithmetic_side_effects)]
    fn fix_variables<S>(&mut self, partial_point: &[S], zero: R)
    where
        R: for<'a> MulByScalar<&'a S>,
    {
        assert!(
            partial_point.len() <= self.num_vars,
            "too many partial points"
        );

        let nv = self.num_vars;
        let dim = partial_point.len();

        for i in 1..dim + 1 {
            let r = &partial_point[i - 1];
            // Round `i` folds pairs (2b, 2b + 1) into slot `b`; slot `b` is
            // never part of a pair that is still to be read.
            for b in 0..1 << (nv - i) {
                let left = &self[2 * b];
                let right = &self[2 * b + 1];
                // a = f(1) - f(0)
                let a = sub!(*right, left);
                if a != zero {
                    // self[b] = f(0) + r * a
                    let ar = a
                        .mul_by_scalar::<CHECKED>(r)
                        .expect("Multiplication overflow");
                    self[b] = add!(*left, ar);
                } else {
                    // Zero difference: the folded value is just f(0).
                    self[b] = left.clone();
                };
            }
        }

        self.evaluations.truncate(1 << (nv - dim));
        self.num_vars = sub!(nv, dim);
    }

    /// Non-mutating variant of [`Self::fix_variables`]: performs the same
    /// operation on a clone and returns it.
    fn fixed_variables<S>(&self, partial_point: &[S], zero: R) -> Self
    where
        R: for<'a> MulByScalar<&'a S>,
    {
        let mut res = self.clone();
        res.fix_variables(partial_point, zero);
        res
    }
}
362
363impl<R> MultilinearExtensionRand<R> for DenseMultilinearExtension<R>
364where
365    R: Send + Clone + Default,
366    StandardUniform: Distribution<R>,
367{
368    fn rand<Rng: RngCore + ?Sized>(num_vars: usize, rng: &mut Rng) -> Self {
369        (0..1 << num_vars).map(|_| rng.random::<R>()).collect()
370    }
371}
372
373impl<T, I: SliceIndex<[T]>> Index<I> for DenseMultilinearExtension<T> {
374    type Output = I::Output;
375
376    fn index(&self, index: I) -> &Self::Output {
377        &self.evaluations[index]
378    }
379}
380
381impl<T, I: SliceIndex<[T]>> IndexMut<I> for DenseMultilinearExtension<T> {
382    fn index_mut(&mut self, index: I) -> &mut Self::Output {
383        &mut self.evaluations[index]
384    }
385}
386
387impl<R: Ring> Neg for DenseMultilinearExtension<R> {
388    type Output = Self;
389
390    fn neg(mut self) -> Self::Output {
391        self.unary(|v| *v = v.checked_neg().expect("Negation overflow"));
392        self
393    }
394}
395
396impl<R: Semiring> Add for DenseMultilinearExtension<R> {
397    type Output = Self;
398
399    #[allow(clippy::arithmetic_side_effects)]
400    fn add(self, rhs: Self) -> Self::Output {
401        self + &rhs
402    }
403}
404
405impl<R: Semiring> Add<&Self> for DenseMultilinearExtension<R> {
406    type Output = Self;
407
408    #[allow(clippy::arithmetic_side_effects)]
409    fn add(mut self, rhs: &Self) -> Self::Output {
410        self.binary(rhs, |a, b| *a += b);
411        self
412    }
413}
414
415impl<R: Semiring> Sub<&Self> for DenseMultilinearExtension<R> {
416    type Output = Self;
417
418    #[allow(clippy::arithmetic_side_effects)]
419    fn sub(mut self, rhs: &Self) -> Self::Output {
420        self.binary(rhs, |a, b| *a -= b);
421        self
422    }
423}
424
425impl<R: Semiring> Mul<&Self> for DenseMultilinearExtension<R> {
426    type Output = Self;
427
428    #[allow(clippy::arithmetic_side_effects)]
429    fn mul(mut self, rhs: &Self) -> Self::Output {
430        self.binary(rhs, |a, b| *a *= b);
431        self
432    }
433}
434
435impl<R: Semiring> Mul<R> for DenseMultilinearExtension<R> {
436    type Output = Self;
437
438    #[allow(clippy::arithmetic_side_effects)]
439    fn mul(mut self, rhs: R) -> Self::Output {
440        self.unary(|v| *v *= &rhs);
441        self
442    }
443}
444
445impl<R: Semiring> AddAssign<&Self> for DenseMultilinearExtension<R> {
446    #[allow(clippy::arithmetic_side_effects)]
447    fn add_assign(&mut self, rhs: &Self) {
448        self.binary(rhs, |a, b| *a += b);
449    }
450}
451
452impl<R: Semiring> SubAssign<&Self> for DenseMultilinearExtension<R> {
453    #[allow(clippy::arithmetic_side_effects)]
454    fn sub_assign(&mut self, rhs: &Self) {
455        self.binary(rhs, |a, b| *a -= b);
456    }
457}
458
459impl<R: Semiring> MulAssign<&Self> for DenseMultilinearExtension<R> {
460    #[allow(clippy::arithmetic_side_effects)]
461    fn mul_assign(&mut self, rhs: &Self) {
462        self.binary(rhs, |a, b| *a *= b);
463    }
464}
465
466impl<R: Semiring> AddAssign<(R, &Self)> for DenseMultilinearExtension<R> {
467    #[allow(clippy::arithmetic_side_effects)]
468    fn add_assign(&mut self, rhs: (R, &Self)) {
469        let coeff = rhs.0;
470        self.binary(rhs.1, |a, b| *a += b.clone() * &coeff);
471    }
472}
473
474pub fn project_coeffs<F: PrimeField, R: ProjectableToField<F> + Send + Sync>(
475    mle: DenseMultilinearExtension<R>,
476    sampled_value: &F,
477) -> DenseMultilinearExtension<F::Inner> {
478    let projection = R::prepare_projection(sampled_value);
479
480    DenseMultilinearExtension {
481        evaluations: cfg_into_iter!(mle.evaluations)
482            .map(|x| projection(&x).into_inner())
483            .collect(),
484        num_vars: mle.num_vars,
485    }
486}
487
/// Iterator extension for collecting items straight into a
/// [`DenseMultilinearExtension`] when the padding value has to be supplied
/// explicitly.
pub trait CollectDenseMleWithZero: Iterator {
    /// Consumes the iterator and builds a dense extension from its items,
    /// using `zero` as the padding value.
    fn collect_dense_mle_with_zero(
        self,
        zero: &Self::Item,
    ) -> DenseMultilinearExtension<Self::Item>;
}
494
495impl<T> CollectDenseMleWithZero for T
496where
497    T: Iterator,
498    T::Item: Clone,
499{
500    fn collect_dense_mle_with_zero(
501        self,
502        zero: &Self::Item,
503    ) -> DenseMultilinearExtension<Self::Item> {
504        let evaluations = self.collect();
505
506        DenseMultilinearExtension::from_evaluations_vec_pad_with_zero(evaluations, zero)
507    }
508}
509
510#[cfg(test)]
511#[allow(
512    clippy::arithmetic_side_effects,
513    clippy::cast_possible_truncation,
514    clippy::cast_possible_wrap,
515    clippy::cast_sign_loss
516)]
517mod tests {
518    use crate::utils::{build_eq_x_r, build_eq_x_r_vec};
519
520    use super::*;
521
522    use crypto_primitives::{
523        DenseRowMatrix, IntoWithConfig, PrimeField, crypto_bigint_monty::MontyField,
524        crypto_bigint_uint::Uint,
525    };
526    use proptest::prelude::*;
527
528    const LIMBS: usize = 4;
529
    /// Builds a field configuration from a hexadecimal modulus string.
    ///
    /// Panics if the string cannot be parsed or the configuration cannot be
    /// created.
    fn get_dyn_config(hex_modulus: &str) -> <MontyField<LIMBS> as PrimeField>::Config {
        let modulus = Uint::new(
            crypto_bigint::Uint::from_str_radix_vartime(hex_modulus, 16)
                .expect("Invalid modulus hex string"),
        );
        MontyField::make_cfg(&modulus).expect("Failed to create field config")
    }
537
538    const MODULUS: &str = "0076F668F4274572E39A3EA8285319B5";
539    type F = MontyField<LIMBS>;
540
    /// Strategy producing field elements by converting arbitrary `u128`
    /// values under the given configuration.
    fn any_f(cfg: <F as PrimeField>::Config) -> impl Strategy<Value = F> + 'static {
        any::<u128>().prop_map(move |v| v.into_with_cfg(&cfg))
    }
544
    /// Strategy producing dense extensions in 0 to 5 variables with a full
    /// table of arbitrary field elements.
    fn any_dme() -> impl Strategy<Value = DenseMultilinearExtension<F>> {
        let cfg = get_dyn_config(MODULUS);
        (0usize..=5).prop_flat_map(move |n| {
            // The table length depends on the drawn variable count.
            let len = 1usize << n;
            let cfg = cfg;
            prop::collection::vec(any_f(cfg), len).prop_map(move |evals| {
                DenseMultilinearExtension::from_evaluations_vec(n, evals, F::zero_with_cfg(&cfg))
            })
        })
    }
555
    /// For a single coordinate `r`, the table is expected to be `[1 - r, r]`.
    #[test]
    fn test_build_eq_x_r_vec_basic() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [3_u64.into_with_cfg(&cfg)];
        let evals = build_eq_x_r_vec(&r, &cfg).unwrap();
        assert_eq!(
            evals,
            vec![F::one_with_cfg(&cfg) - r[0].clone(), r[0].clone()]
        );
    }
566
    /// For two coordinates, the table is expected to hold the four products
    /// of `r_i` / `1 - r_i`, with `r[0]` selecting the low index bit.
    #[test]
    fn test_build_eq_x_r_vec_two_vars() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [2u64.into_with_cfg(&cfg), 5u64.into_with_cfg(&cfg)];
        let evals = build_eq_x_r_vec(&r, &cfg).unwrap();
        let e00 = (F::one_with_cfg(&cfg) - r[0].clone()) * (F::one_with_cfg(&cfg) - r[1].clone());
        let e01 = r[0].clone() * (F::one_with_cfg(&cfg) - r[1].clone());
        let e10 = (F::one_with_cfg(&cfg) - r[0].clone()) * r[1].clone();
        let e11 = r[0].clone() * r[1].clone();
        assert_eq!(evals, vec![e00, e01, e10, e11]);
    }
578
    /// An empty point must be rejected with an "Invalid parameters" error.
    #[test]
    fn test_build_eq_x_r_error_on_empty() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; 0] = [];
        let err = build_eq_x_r_vec(&r, &cfg).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("Invalid parameters"));
    }
587
    /// The extension returned by `build_eq_x_r` must have one variable per
    /// coordinate and the same table as `build_eq_x_r_vec`.
    #[test]
    fn test_build_eq_x_r_mle_properties() {
        let cfg = get_dyn_config(MODULUS);
        let r: [F; _] = [
            7u64.into_with_cfg(&cfg),
            11u64.into_with_cfg(&cfg),
            13u64.into_with_cfg(&cfg),
        ];
        let mle = build_eq_x_r(&r, &cfg).unwrap();
        assert_eq!(mle.num_vars, r.len());
        let evals = mle.evaluations;
        let direct = build_eq_x_r_vec(&r, &cfg).unwrap();
        assert_eq!(evals, direct);
    }
602
    /// A short slice is padded with zeros up to `2^n_vars`, and the result
    /// supports reading and writing single entries by index.
    #[test]
    fn test_dense_from_slice_and_indexing() {
        let cfg = get_dyn_config(MODULUS);
        let n_vars = 3usize;
        let v = vec![
            1u64.into_with_cfg(&cfg),
            2u64.into_with_cfg(&cfg),
            3u64.into_with_cfg(&cfg),
        ];
        let dense =
            DenseMultilinearExtension::from_evaluations_slice(n_vars, &v, F::zero_with_cfg(&cfg));
        assert_eq!(dense.num_vars, n_vars);
        let mut expected = v.clone();
        expected.resize(1 << n_vars, F::zero_with_cfg(&cfg));
        assert_eq!(dense.evaluations, expected);
        assert_eq!(dense[0], 1u64.into_with_cfg(&cfg));
        // Mutable indexing on a clone.
        let mut d2 = dense.clone();
        d2[1] = 99u64.into_with_cfg(&cfg);
        assert_eq!(d2[1], 99u64.into_with_cfg(&cfg));
    }
623
    /// Evaluating at each hypercube vertex returns the matching table entry,
    /// and fixing the first variable to one keeps the odd-indexed entries.
    #[test]
    fn test_fix_variables_and_evaluate() {
        let cfg = get_dyn_config(MODULUS);
        let evals = vec![
            10u64.into_with_cfg(&cfg),
            20u64.into_with_cfg(&cfg),
            30u64.into_with_cfg(&cfg),
            40u64.into_with_cfg(&cfg),
        ];
        let mle = DenseMultilinearExtension::from_evaluations_vec(
            2,
            evals.clone(),
            F::zero_with_cfg(&cfg),
        );
        // Vertices are listed in table order: x0 is the low index bit.
        for (idx, (x0, x1)) in [
            (F::zero_with_cfg(&cfg), F::zero_with_cfg(&cfg)),
            (F::one_with_cfg(&cfg), F::zero_with_cfg(&cfg)),
            (F::zero_with_cfg(&cfg), F::one_with_cfg(&cfg)),
            (F::one_with_cfg(&cfg), F::one_with_cfg(&cfg)),
        ]
        .iter()
        .enumerate()
        {
            let val = mle
                .evaluate(&[x0.clone(), x1.clone()], F::zero_with_cfg(&cfg))
                .unwrap();
            assert_eq!(val, evals[idx]);
        }
        let mut m2 = mle.clone();
        m2.fix_variables(&[F::one_with_cfg(&cfg)], F::zero_with_cfg(&cfg));
        assert_eq!(m2.num_vars, 1);
        assert_eq!(
            m2.evaluations,
            vec![20u64.into_with_cfg(&cfg), 40u64.into_with_cfg(&cfg)]
        );
    }
660
    /// A 3x2 matrix is padded to 4x2 (three variables) and laid out row by
    /// row: the two non-zero cells land at indices 0 and 5.
    #[test]
    fn test_from_matrix_padding_and_conversion() {
        let cfg = get_dyn_config(MODULUS);
        let m: DenseRowMatrix<F> = DenseRowMatrix::from(vec![
            vec![5u64.into_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            vec![F::zero_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            vec![F::zero_with_cfg(&cfg), 7u64.into_with_cfg(&cfg)],
        ]);
        let dense = DenseMultilinearExtension::from_matrix(&m, F::zero_with_cfg(&cfg));
        assert_eq!(dense.num_vars, 3);
        assert_eq!(dense[0], 5u64.into_with_cfg(&cfg));
        assert_eq!(dense[5], 7u64.into_with_cfg(&cfg));
        // Every other position must be zero.
        assert!(dense.iter().enumerate().all(|(i, v)| if i == 0 || i == 5 {
            true
        } else {
            F::is_zero(v)
        }));
    }
679
    /// The vector-based and slice-based constructors pad a short table with
    /// zeros in the same way.
    #[test]
    fn test_from_evaluations_vec_padding_branch_and_slice() {
        let cfg = get_dyn_config(MODULUS);
        // len < 2^n triggers padding branch
        let evals = vec![1u64.into_with_cfg(&cfg), 2u64.into_with_cfg(&cfg)];
        let n = 2usize; // 4 expected
        let d1 = DenseMultilinearExtension::from_evaluations_vec(
            n,
            evals.clone(),
            F::zero_with_cfg(&cfg),
        );
        let mut expected = evals.clone();
        expected.resize(1 << n, F::zero_with_cfg(&cfg));
        assert_eq!(d1.evaluations, expected);
        let d2 =
            DenseMultilinearExtension::from_evaluations_slice(n, &evals, F::zero_with_cfg(&cfg));
        assert_eq!(d2.evaluations, expected);
    }
698
    /// Fixing zero variables leaves the extension unchanged; fixing all of
    /// them leaves a single entry and zero variables.
    #[test]
    fn test_fix_variables_edge_cases_and_full_truncate() {
        let cfg = get_dyn_config(MODULUS);
        let d = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        // Empty partial point: nothing changes.
        let d_fixed = d.fixed_variables(&[], F::zero_with_cfg(&cfg));
        assert_eq!(d_fixed.num_vars, 2);
        assert_eq!(
            d_fixed.evaluations,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg)
            ]
        );
        let mut d2 = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                10.into_with_cfg(&cfg),
                20.into_with_cfg(&cfg),
                30.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        // Full point (1, 0) selects table index 1.
        d2.fix_variables(
            &[F::one_with_cfg(&cfg), F::zero_with_cfg(&cfg)],
            F::zero_with_cfg(&cfg),
        );
        assert_eq!(d2.num_vars, 0);
        assert_eq!(d2.evaluations, vec![20.into_with_cfg(&cfg)]);
    }
740
    /// `evaluate` must return an error for points that are too short or too
    /// long.
    #[test]
    fn test_evaluate_length_mismatch_returns_error() {
        let cfg = get_dyn_config(MODULUS);
        let d = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        // One coordinate for a two-variable extension.
        assert!(
            d.evaluate(&[F::one_with_cfg(&cfg)], F::zero_with_cfg(&cfg))
                .is_err()
        );
        // Three coordinates for a two-variable extension.
        assert!(
            d.evaluate(
                &[
                    F::one_with_cfg(&cfg),
                    F::one_with_cfg(&cfg),
                    F::zero_with_cfg(&cfg)
                ],
                F::zero_with_cfg(&cfg)
            )
            .is_err()
        );
    }
770
    /// `zero_vars` yields a zero-variable extension holding exactly the given
    /// value.
    #[test]
    fn test_zero_impl_for_dense_mle() {
        let cfg = get_dyn_config(MODULUS);
        let z: DenseMultilinearExtension<F> =
            DenseMultilinearExtension::zero_vars(F::zero_with_cfg(&cfg));
        assert_eq!(z.num_vars, 0);
        assert_eq!(z.evaluations, vec![F::zero_with_cfg(&cfg)]);
    }
779
    /// `+`, `-`, `*` with a borrowed right-hand side and unary `-` all act
    /// entry by entry on the tables.
    #[test]
    fn test_arithmetic_ops_elementwise_add_sub_mul_and_neg() {
        let cfg = get_dyn_config(MODULUS);
        let a = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        let b = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                5.into_with_cfg(&cfg),
                6.into_with_cfg(&cfg),
                7.into_with_cfg(&cfg),
                8.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );

        let sum = a.clone() + &b;
        assert_eq!(
            sum.evaluations,
            vec![
                6.into_with_cfg(&cfg),
                8.into_with_cfg(&cfg),
                10.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg)
            ]
        );

        let diff = b.clone() - &a;
        assert_eq!(
            diff.evaluations,
            vec![
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg)
            ]
        );

        let prod = a.clone() * &b;
        assert_eq!(
            prod.evaluations,
            vec![
                5.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg),
                21.into_with_cfg(&cfg),
                32.into_with_cfg(&cfg)
            ]
        );

        // Neg
        let neg_a = -a.clone();
        // Expected values are obtained by negating each entry individually.
        let mut expected = vec![];
        for v in a.evaluations {
            expected.push(-v);
        }
        assert_eq!(neg_a.evaluations, expected);
    }
845
    /// Covers scalar multiplication and the in-place operators `+=`, `-=`,
    /// `*=` as well as the scaled addition `+= (coeff, &other)`.
    #[test]
    fn test_scalar_mul_and_assign_variants() {
        let cfg = get_dyn_config(MODULUS);
        let a = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                1.into_with_cfg(&cfg),
                2.into_with_cfg(&cfg),
                3.into_with_cfg(&cfg),
                4.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );
        let b = DenseMultilinearExtension::from_evaluations_vec(
            2,
            vec![
                10.into_with_cfg(&cfg),
                20.into_with_cfg(&cfg),
                30.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
            ],
            F::zero_with_cfg(&cfg),
        );

        // Scalar multiplication.
        let three: F = 3u64.into_with_cfg(&cfg);
        let scaled = a.clone() * three;
        assert_eq!(
            scaled.evaluations,
            vec![
                3.into_with_cfg(&cfg),
                6.into_with_cfg(&cfg),
                9.into_with_cfg(&cfg),
                12.into_with_cfg(&cfg)
            ]
        );

        // `+=` followed by `-=` must round-trip back to `a`.
        let mut c = a.clone();
        c += &b;
        assert_eq!(
            c.evaluations,
            vec![
                11.into_with_cfg(&cfg),
                22.into_with_cfg(&cfg),
                33.into_with_cfg(&cfg),
                44.into_with_cfg(&cfg)
            ]
        );

        c -= &b;
        assert_eq!(c.evaluations, a.evaluations);

        let mut d = a.clone();
        d *= &b;
        assert_eq!(
            d.evaluations,
            vec![
                10.into_with_cfg(&cfg),
                40.into_with_cfg(&cfg),
                90.into_with_cfg(&cfg),
                160.into_with_cfg(&cfg)
            ]
        );

        // Scaled addition: e = a + 2 * b.
        let mut e = a.clone();
        let two = 2u64.into_with_cfg(&cfg);
        e += (two, &b);
        assert_eq!(
            e.evaluations,
            vec![
                21.into_with_cfg(&cfg),
                42.into_with_cfg(&cfg),
                63.into_with_cfg(&cfg),
                84.into_with_cfg(&cfg)
            ]
        );
    }
922
923    fn any_aligned_pair_with_point() -> impl Strategy<
924        Value = (
925            DenseMultilinearExtension<F>,
926            DenseMultilinearExtension<F>,
927            Vec<F>,
928        ),
929    > {
930        let cfg = get_dyn_config(MODULUS);
931        (0usize..=5).prop_flat_map(move |n| {
932            let cfg = cfg;
933            let len = 1usize << n;
934            prop::collection::vec(any_f(cfg), len).prop_flat_map(move |e1| {
935                let cfg = cfg;
936                let n2 = n;
937                prop::collection::vec(any_f(cfg), len).prop_flat_map(move |e2| {
938                    let cfg = cfg;
939                    let n3 = n2;
940                    point_n(n3).prop_map({
941                        let e1v = e1.clone();
942                        let e2v = e2.clone();
943                        move |r| {
944                            (
945                                DenseMultilinearExtension::from_evaluations_vec(
946                                    n3,
947                                    e1v.clone(),
948                                    F::zero_with_cfg(&cfg),
949                                ),
950                                DenseMultilinearExtension::from_evaluations_vec(
951                                    n3,
952                                    e2v.clone(),
953                                    F::zero_with_cfg(&cfg),
954                                ),
955                                r,
956                            )
957                        }
958                    })
959                })
960            })
961        })
962    }
    /// Strategy producing a point with exactly `n` arbitrary coordinates.
    fn point_n(n: usize) -> impl Strategy<Value = Vec<F>> {
        prop::collection::vec(any_f(get_dyn_config(MODULUS)), n)
    }
966
    proptest! {
        // Evaluation is additive: (p1 + p2)(r) == p1(r) + p2(r).
        #[test]
        fn prop_eval_add_is_linear((p1, p2, r) in any_aligned_pair_with_point()) {
            let cfg = get_dyn_config(MODULUS);
            let lhs = (p1.clone() + &p2).evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p1.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap() + p2.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            prop_assert_eq!(lhs, rhs);
        }

        // Fixing the first `k` coordinates and evaluating at the remaining
        // ones gives the same value as evaluating at the full point.
        #[test]
        fn prop_fix_vars_commutes_with_eval((p, r, k) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let point = point_n(n);
            let ks = 0usize..=n;
            (Just(p), point, ks)
        })) {
            let cfg = get_dyn_config(MODULUS);
            let mut pfixed = p.clone();
            pfixed.fix_variables(&r[..k], F::zero_with_cfg(&cfg));
            let lhs = pfixed.evaluate(&r[k..], F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            prop_assert_eq!(lhs, rhs);
        }

        // Fixing variables in two consecutive steps gives the same result as
        // fixing them once with the concatenated partial point. `k1 + k2`
        // never exceeds the variable count; `r1` and `r2` are drawn
        // independently, so the prefixes are clamped to their lengths.
        #[test]
        fn prop_fix_vars_is_idempotent((p, k1, k2) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let ks1 = 0usize..=n;
            (Just(p), ks1).prop_flat_map(move |(p, k1)| {
                let ks2 = 0usize..=n.saturating_sub(k1);
                (Just(p), Just(k1), ks2)
            })
        }), r1 in prop::collection::vec(any_f(get_dyn_config(MODULUS)), 0..=8usize), r2 in prop::collection::vec(any_f(get_dyn_config(MODULUS)), 0..=8usize)) {
            let cfg = get_dyn_config(MODULUS);
            let mut p_step = p.clone();
            p_step.fix_variables(&r1[..k1.min(r1.len())], F::zero_with_cfg(&cfg));
            p_step.fix_variables(&r2[..k2.min(r2.len())], F::zero_with_cfg(&cfg));

            let mut p_once = p.clone();
            let mut concat = r1[..k1.min(r1.len())].to_vec();
            concat.extend_from_slice(&r2[..k2.min(r2.len())]);
            p_once.fix_variables(&concat, F::zero_with_cfg(&cfg));

            prop_assert_eq!(p_step.evaluations, p_once.evaluations);
            prop_assert_eq!(p_step.num_vars, p_once.num_vars);
        }

        // The generic `evaluate` on field elements and `evaluate_with_config`
        // on their inner representations must agree.
        #[test]
        fn prop_mle_eval_eq_eval_with_config((p, r) in any_dme().prop_flat_map(|p| {
            let n = p.num_vars;
            let point = point_n(n);
            (Just(p), point)
        })) {
            let cfg = get_dyn_config(MODULUS);

            // Same table, stored as inner values.
            let p_inner = DenseMultilinearExtension {
                num_vars: p.num_vars,
                evaluations: p.evaluations.iter().map(|x| *x.inner()).collect()
            };

            let lhs = p.evaluate(&r, F::zero_with_cfg(&cfg)).unwrap();
            let rhs = p_inner.evaluate_with_config(&r, &cfg).unwrap();
            prop_assert_eq!(lhs, rhs);
        }
    }
1032}