Skip to main content

zinc_poly/
utils.rs

1use crypto_primitives::{Field, PrimeField, Semiring};
2use num_traits::Zero;
3use thiserror::Error;
4use zinc_utils::{cfg_iter_mut, inner_transparent_field::InnerTransparentField, sub};
5
6#[cfg(feature = "parallel")]
7use rayon::prelude::*;
8
9use crate::mle::{DenseMultilinearExtension, dense::CollectDenseMleWithZero};
10
11/// A `enum` specifying the possible failure modes of the arithmetics.
12#[derive(Debug, Clone, Error)]
13pub enum ArithErrors {
14    #[error("Invalid parameters: {0}")]
15    InvalidParameters(String),
16}
17
18/// This function build the eq(x, r) polynomial for any given r.
19///
20/// Evaluate
21///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
22/// over r, which is
23///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
24pub fn build_eq_x_r<F>(
25    r: &[F],
26    cfg: &F::Config,
27) -> Result<DenseMultilinearExtension<F>, ArithErrors>
28where
29    F: PrimeField,
30{
31    let evals = build_eq_x_r_vec(r, cfg)?;
32    let mle =
33        DenseMultilinearExtension::from_evaluations_vec(r.len(), evals, F::zero_with_cfg(cfg));
34
35    Ok(mle)
36}
37
38/// This function build the eq(x, r) polynomial for any given r, and output the
39/// evaluation of eq(x, r) in its vector form.
40///
41/// Evaluate
42///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
43/// over r, which is
44///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
45pub fn build_eq_x_r_vec<F>(r: &[F], cfg: &F::Config) -> Result<Vec<F>, ArithErrors>
46where
47    F: PrimeField,
48{
49    // we build eq(x,r) from its evaluations
50    // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars
51    // for example, with num_vars = 4, x is a binary vector of 4, then
52    //  0 0 0 0 -> (1-r0)   * (1-r1)    * (1-r2)    * (1-r3)
53    //  1 0 0 0 -> r0       * (1-r1)    * (1-r2)    * (1-r3)
54    //  0 1 0 0 -> (1-r0)   * r1        * (1-r2)    * (1-r3)
55    //  1 1 0 0 -> r0       * r1        * (1-r2)    * (1-r3)
56    //  ....
57    //  1 1 1 1 -> r0       * r1        * r2        * r3
58    // we will need 2^num_var evaluations
59
60    let mut eval = Vec::new();
61    build_eq_x_r_helper(r, &mut eval, cfg)?;
62
63    Ok(eval)
64}
65
66/// A helper function to build eq(x, r) recursively.
67/// This function takes `r.len()` steps, and for each step it requires a maximum
68/// `r.len()-1` multiplications.
69fn build_eq_x_r_helper<F>(r: &[F], buf: &mut Vec<F>, cfg: &F::Config) -> Result<(), ArithErrors>
70where
71    F: PrimeField,
72{
73    if r.is_empty() {
74        return Err(ArithErrors::InvalidParameters("r length is 0".into()));
75    } else if r.len() == 1 {
76        // initializing the buffer with [1-r_0, r_0]
77        buf.push(F::one_with_cfg(cfg) - &r[0]);
78        buf.push(r[0].clone());
79    } else {
80        build_eq_x_r_helper(&r[1..], buf, cfg)?;
81
82        // suppose at the previous step we received [b_1, ..., b_k]
83        // for the current step we will need
84        // if x_0 = 0:   (1-r0) * [b_1, ..., b_k]
85        // if x_0 = 1:   r0 * [b_1, ..., b_k]
86
87        let mut res = vec![F::zero_with_cfg(cfg); buf.len() << 1];
88        cfg_iter_mut!(res).enumerate().for_each(|(i, val)| {
89            let bi = buf[i >> 1].clone();
90            let tmp = r[0].clone() * &bi;
91            if (i & 1) == 0 {
92                *val = bi - tmp;
93            } else {
94                *val = tmp;
95            }
96        });
97        *buf = res;
98    }
99
100    Ok(())
101}
102
103/// This function build the eq(x, r) polynomial for any given r.
104///
105/// Evaluate
106///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
107/// over r, which is
108///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
109pub fn build_eq_x_r_inner<F>(
110    r: &[F],
111    cfg: &F::Config,
112) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors>
113where
114    F: PrimeField,
115    F::Inner: Zero,
116{
117    let evals = build_eq_x_r_inner_vec(r, cfg)?;
118    let mle = DenseMultilinearExtension {
119        num_vars: r.len(),
120        evaluations: evals,
121    };
122
123    Ok(mle)
124}
125
126/// This function build the eq(x, r) polynomial for any given r, and output the
127/// evaluation of eq(x, r) in its vector form.
128///
129/// Evaluate
130///      eq(x,y) = \prod_i=1^num_var (x_i * y_i + (1-x_i)*(1-y_i))
131/// over r, which is
132///      eq(x,y) = \prod_i=1^num_var (x_i * r_i + (1-x_i)*(1-r_i))
133fn build_eq_x_r_inner_vec<F>(r: &[F], cfg: &F::Config) -> Result<Vec<F::Inner>, ArithErrors>
134where
135    F: PrimeField,
136    F::Inner: Zero,
137{
138    // we build eq(x,r) from its evaluations
139    // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars
140    // for example, with num_vars = 4, x is a binary vector of 4, then
141    //  0 0 0 0 -> (1-r0)   * (1-r1)    * (1-r2)    * (1-r3)
142    //  1 0 0 0 -> r0       * (1-r1)    * (1-r2)    * (1-r3)
143    //  0 1 0 0 -> (1-r0)   * r1        * (1-r2)    * (1-r3)
144    //  1 1 0 0 -> r0       * r1        * (1-r2)    * (1-r3)
145    //  ....
146    //  1 1 1 1 -> r0       * r1        * r2        * r3
147    // we will need 2^num_var evaluations
148
149    let mut eval = Vec::new();
150    build_eq_x_r_inner_helper(r, &mut eval, cfg)?;
151
152    Ok(eval)
153}
154
155/// A helper function to build eq(x, r) recursively.
156/// This function takes `r.len()` steps, and for each step it requires a maximum
157/// `r.len()-1` multiplications.
158fn build_eq_x_r_inner_helper<F>(
159    r: &[F],
160    buf: &mut Vec<F::Inner>,
161    cfg: &F::Config,
162) -> Result<(), ArithErrors>
163where
164    F: PrimeField,
165    F::Inner: Zero,
166{
167    let one = F::one_with_cfg(cfg);
168    if r.is_empty() {
169        return Err(ArithErrors::InvalidParameters("r length is 0".into()));
170    } else if r.len() == 1 {
171        // initializing the buffer with [1-r_0, r_0]
172        buf.push((one - &r[0]).into_inner());
173        buf.push(r[0].inner().clone());
174    } else {
175        build_eq_x_r_inner_helper(&r[1..], buf, cfg)?;
176
177        // suppose at the previous step we received [b_1, ..., b_k]
178        // for the current step we will need
179        // if x_0 = 0:   (1-r0) * [b_1, ..., b_k]
180        // if x_0 = 1:   r0 * [b_1, ..., b_k]
181
182        let mut res = vec![F::Inner::zero(); buf.len() << 1];
183        cfg_iter_mut!(res).enumerate().for_each(|(i, val)| {
184            let bi = F::new_unchecked_with_cfg(buf[i >> 1].clone(), cfg);
185            let tmp = r[0].clone() * &bi;
186            if (i & 1) == 0 {
187                *val = (bi - tmp).into_inner();
188            } else {
189                *val = tmp.into_inner();
190            }
191        });
192        *buf = res;
193    }
194
195    Ok(())
196}
197
198/// Build the shift selector MLE `next_c_mle(r, *)` with the first `num_vars`
199/// variables fixed to `r`.
200///
201/// For each `b in {0,1}^{num_vars}`:
202///   next_c_mle(b) = eq(r, b - c)   if b >= c
203///   next_c_mle(b) = 0              if b < c
204///
205/// Uses the identity `next_c_mle(r, b) = eq(r, b - c)` for `b >= c` and
206/// `0` for `b < c`.
207pub fn build_next_c_r_mle<F>(
208    r: &[F],
209    c: usize,
210    field_cfg: &F::Config,
211) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors>
212where
213    F: PrimeField,
214    F::Inner: Zero,
215{
216    let num_vars = r.len();
217    let n = 1 << num_vars;
218    assert!(c < n, "shift c={c} must be < domain size {n}");
219    let zero_inner = F::zero_with_cfg(field_cfg).into_inner();
220
221    let eq_r = build_eq_x_r_inner(r, field_cfg)?;
222    if c == 0 {
223        return Ok(eq_r);
224    }
225
226    // next_c_mle(r, 0) = 0 for b < c
227    // next_c_mle(r, b - c) = eq(r, b - c) for b >= c
228    let mut evaluations = Vec::with_capacity(n);
229    evaluations.resize(c, zero_inner);
230    evaluations.extend_from_slice(&eq_r.evaluations[..sub!(n, c)]);
231
232    Ok(DenseMultilinearExtension {
233        num_vars,
234        evaluations,
235    })
236}
237
238/// Evaluate eq polynomial.
239#[allow(clippy::arithmetic_side_effects)]
240pub fn eq_eval<R: Semiring>(x: &[R], y: &[R], one: R) -> Result<R, ArithErrors> {
241    if x.len() != y.len() {
242        return Err(ArithErrors::InvalidParameters(
243            "x and y have different length".to_string(),
244        ));
245    }
246
247    let mut res = one.clone();
248    for (xi, yi) in x.iter().zip(y.iter()) {
249        let xi_yi = xi.clone() * yi;
250        res *= xi_yi.clone() + xi_yi - xi - yi + one.clone();
251    }
252
253    Ok(res)
254}
255
256/// Evaluate an MLE at a point using a precomputed eq table.
257///
258/// Given `evaluations[b]` (in `F::Inner` form) and `eq_table[b] = eq(b, r)`
259/// (precomputed via [`build_eq_x_r_vec`]), returns `\sum_{b} eq_table[b] *
260/// evaluations[b]`.
261///
262/// This is equivalent to `DenseMultilinearExtension::evaluate_with_config`
263/// but avoids cloning the evaluation vector (the fix-variables algorithm is
264/// destructive). When multiple MLEs share the same evaluation point, build the
265/// eq table once and call this function for each MLE.
266#[allow(clippy::arithmetic_side_effects)]
267pub fn mle_eval_with_eq_table<F: InnerTransparentField>(
268    evaluations: &[F::Inner],
269    eq_table: &[F],
270    cfg: &F::Config,
271) -> F {
272    let mut acc = F::zero_with_cfg(cfg);
273    assert_eq!(
274        evaluations.len(),
275        eq_table.len(),
276        "evaluations and eq_table must have the same length"
277    );
278    for (eval, eq_val) in evaluations.iter().zip(eq_table.iter()) {
279        let mut term = eq_val.clone();
280        term.mul_assign_by_inner(eval);
281        acc += &term;
282    }
283    acc
284}
285
286/// Returns a multilinear polynomial in 2n variables that evaluates to 1
287/// if and only if the second n-bit vector is equal to the first vector plus one
288#[allow(clippy::arithmetic_side_effects)]
289pub fn next_mle_inner<F: Field>(
290    num_vars: u32,
291    zero: F,
292    one: F,
293) -> Result<DenseMultilinearExtension<F::Inner>, ArithErrors> {
294    if !num_vars.is_multiple_of(2) {
295        return Err(ArithErrors::InvalidParameters(
296            "num_vars must be even".to_string(),
297        ));
298    }
299
300    let mut mle = (0..1 << num_vars)
301        .map(|_| zero.inner().clone())
302        .collect_dense_mle_with_zero(zero.inner());
303
304    let half_vars = num_vars / 2;
305
306    for i in 0usize..(1 << half_vars) - 1 {
307        let next = i + 1;
308
309        let i_concat_next = (next << half_vars) | i;
310
311        mle[i_concat_next] = one.inner().clone();
312    }
313
314    Ok(mle)
315}
316
317/// Evaluates the next MLE in O(n), by reusing suffix equality and prefix carry
318/// products across carry positions.
319///
320/// Improved from O(n²) approach here: https://github.com/TomWambsgans/Whirlaway/blob/9e3592b/crates/air/src/utils.rs#L92
321///
322/// `next_mle(u, v) = 1` iff `Val(v) = Val(u) + 1` and `Val(u) < 2^n - 1`.
323///
324/// # Arguments
325/// - `u`: first n-bit vector (LE convention: index 0 = LSB).
326/// - `v`: second n-bit vector. Must have `v.len() == u.len()`.
327///
328/// # Algorithm
329/// Uses prefix/suffix products for O(n) evaluation:
330///   `next_mle(u, v) = sum_{j=0}^{n-1}
331///       [prod_{i<j} u_i * (1 - v_i)]      -- bits below j: were 1, flip to 0
332///     * (1 - u_j) * v_j                   -- bit j: 0 → 1
333///     * [prod_{i>j} eq(u_i, v_i)]`        -- bits above j: unchanged
334///
335/// # Panics
336/// Panics if `u.len() != v.len()`.
337#[allow(clippy::arithmetic_side_effects)]
338pub fn next_mle_eval<R: Semiring>(u: &[R], v: &[R], zero: R, one: R) -> R {
339    let n = u.len();
340    assert_eq!(n, v.len(), "u and v must have the same length");
341    if n == 0 {
342        return zero;
343    }
344
345    // suffix_eq[j] = prod_{i=j}^{n-1} eq(u_i, v_i)
346    let mut suffix_eq = vec![one.clone(); n + 1];
347    for i in (0..n).rev() {
348        suffix_eq[i] = suffix_eq[i + 1].clone()
349            * (u[i].clone() * &v[i] + (one.clone() - &u[i]) * (one.clone() - &v[i]));
350    }
351
352    // prefix_carry accumulates prod_{i<j} u_i * (1 - v_i)
353    let mut prefix_carry = one.clone();
354    let mut result = zero;
355    for j in 0..n {
356        result += prefix_carry.clone() * (one.clone() - &u[j]) * &v[j] * &suffix_eq[j + 1];
357        prefix_carry *= u[j].clone() * (one.clone() - &v[j]);
358    }
359    result
360}
361
#[cfg(test)]
#[allow(clippy::arithmetic_side_effects, clippy::cast_possible_truncation)]
mod tests {
    use crypto_bigint::{U128, const_monty_params};
    use crypto_primitives::{IntoWithConfig, crypto_bigint_const_monty::ConstMontyField};
    use num_traits::One;
    use proptest::{prelude::*, proptest};

    use crate::mle::MultilinearExtensionWithConfig;

    use super::*;

    const_monty_params!(Params, U128, "00000000b933426489189cb5b47d567f");

    type F = ConstMontyField<Params, { U128::LIMBS }>;

    const NUM_VARS: u32 = 8;

    /// Number of variables in each half (`u` and `v`) of a `next_mle` point.
    const HALF_VARS: usize = NUM_VARS as usize / 2;

    /// Number of `u` in `{0,1}^HALF_VARS` that have an in-range successor: every
    /// `u` except the all-ones vector `2^HALF_VARS - 1`.
    const NUM_SUCCESSORS: usize = (1 << HALF_VARS) - 1;

    /// Encodes `value` as `width` boolean field elements, least significant bit
    /// first.
    fn bits_le(value: usize, width: usize) -> Vec<F> {
        (0..width)
            .map(|j| {
                if value & (1 << j) == 0 {
                    F::zero()
                } else {
                    F::one()
                }
            })
            .collect()
    }

    /// The boolean point `(u, u + 1)`, on which `next_mle` must evaluate to one.
    fn successor_point(u: usize) -> Vec<F> {
        let mut point = bits_le(u, HALF_VARS);
        point.extend(bits_le(u + 1, HALF_VARS));
        point
    }

    #[test]
    fn next_mle_is_one_on_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        // The previous bound `1 << (HALF_VARS - 1)` only covered half of the
        // successors. Check all of them.
        for u in 0..NUM_SUCCESSORS {
            let point = successor_point(u);
            assert_eq!(
                next_mle.clone().evaluate_with_config(&point, &()),
                Ok(F::one()),
                "u={u}"
            );
        }
    }

    #[test]
    fn next_mle_is_one_only_on_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        // The mle is one on all NUM_SUCCESSORS successor points (checked above),
        // so it must be nonzero on exactly that many table entries.
        assert_eq!(
            next_mle.evaluations.iter().filter(|x| !x.is_zero()).count(),
            NUM_SUCCESSORS
        );
    }

    fn any_f(cfg: <F as PrimeField>::Config) -> impl Strategy<Value = F> + 'static {
        any::<u128>().prop_map(move |v| v.into_with_cfg(&cfg))
    }

    fn point_n(n: usize) -> impl Strategy<Value = Vec<F>> {
        prop::collection::vec(any_f(()), n)
    }

    #[test]
    fn next_mle_eval_coincides_with_next_mle_evaluated_at_successors() {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        for u_val in 0..NUM_SUCCESSORS {
            let point = successor_point(u_val);
            let (u, v) = point.split_at(HALF_VARS);
            assert_eq!(
                next_mle.clone().evaluate_with_config(&point, &()),
                Ok(next_mle_eval(u, v, F::zero(), F::one())),
                "u={u_val}"
            );
        }
    }

    proptest! {
    #[test]
    fn prop_next_mle_eval_coincides_with_next_mle_evaluate_at_point(r in point_n(NUM_VARS as usize)) {
        let next_mle = next_mle_inner(NUM_VARS, F::zero(), F::one()).unwrap();

        let (u, v) = r.split_at(HALF_VARS);
        prop_assert_eq!(
            next_mle.evaluate_with_config(&r, &()),
            Ok(next_mle_eval(u, v, F::zero(), F::one()))
        );
    }
    }

    #[test]
    fn next_c_r_mle_c1_matches_shift_by_1() {
        // c = 1 is the plain shift-by-one of the eq table.
        let num_vars: usize = 4;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 3) as u32)).collect();

        let next_1 = build_next_c_r_mle(&r, 1, &()).unwrap();

        // Manually build shift-by-1: evaluations[0] = 0, evaluations[b] = eq(r, b-1).
        let eq_r = build_eq_x_r_inner(&r, &()).unwrap();
        let n = 1 << num_vars;
        let mut expected = vec![F::zero().into_inner(); 1];
        expected.extend_from_slice(&eq_r.evaluations[..n - 1]);

        assert_eq!(next_1.evaluations, expected);
    }

    #[test]
    fn next_c_r_mle_c0_is_eq() {
        // c = 0 should return eq(r, b) unchanged.
        let num_vars: usize = 4;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 7) as u32)).collect();

        let next_0 = build_next_c_r_mle(&r, 0, &()).unwrap();
        let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

        assert_eq!(next_0.evaluations, eq_r.evaluations);
    }

    #[test]
    fn next_c_r_mle_has_correct_structure() {
        // For any c, evaluations[b] should be:
        //   0 for b < c
        //   eq(r, b-c) for b >= c
        let num_vars: usize = 4;
        let n = 1 << num_vars;
        let r: Vec<F> = (0..num_vars).map(|i| F::from((i + 5) as u32)).collect();

        for c in [2, 3, 5, 7] {
            let next_c = build_next_c_r_mle(&r, c, &()).unwrap();
            let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

            // First c entries should be zero.
            for b in 0..c {
                assert!(
                    next_c.evaluations[b].is_zero(),
                    "c={c}, b={b}: expected zero"
                );
            }
            // Remaining entries should match eq(r, b-c).
            for b in c..n {
                assert_eq!(
                    next_c.evaluations[b],
                    eq_r.evaluations[b - c],
                    "c={c}, b={b}: mismatch"
                );
            }
        }
    }

    proptest! {
    #[test]
    fn prop_next_c_r_mle_evaluates_correctly(r in point_n(4), c in 1..16usize) {
        // Checks the table layout for a random r and every nonzero shift valid for
        // 4 variables (1..=15): zeros below c, then eq(r, b - c) for b >= c.
        let next_c = build_next_c_r_mle(&r, c, &()).unwrap();
        let eq_r = build_eq_x_r_inner(&r, &()).unwrap();

        let n = 1 << r.len();
        for b in 0..c {
            prop_assert!(next_c.evaluations[b].is_zero());
        }
        for b in c..n {
            prop_assert_eq!(&next_c.evaluations[b], &eq_r.evaluations[b - c]);
        }
    }
    }
}