Skip to main content

zinc_piop/
shift_predicate.rs

1//! Shift predicate evaluation.
2//!
3//! Evaluates `S_c(x, y)` — the multilinear extension of the shift-by-c
4//! indicator — at arbitrary field points.
5
6use crypto_primitives::PrimeField;
7use zinc_poly::utils::next_mle_eval;
8
9/// Evaluate the shift predicate `S_c(x, y)` at arbitrary field points.
10///
11/// Uses the high/low decomposition:
12///   `S_c(x, y) = L_0(x_lo, y_lo) · eq(x_hi, y_hi)
13///              + L_1(x_lo, y_lo) · next_mle(x_hi, y_hi)`
14///
15/// where `k = ceil(log2(2c))` determines the split point.
16///
17/// Cost: O(m + c · log c) field operations.
18#[allow(clippy::arithmetic_side_effects)]
19pub fn eval_shift_predicate<F: PrimeField>(x: &[F], y: &[F], c: usize, cfg: &F::Config) -> F {
20    let m = x.len();
21    assert_eq!(y.len(), m);
22    let zero = F::zero_with_cfg(cfg);
23    let one = F::one_with_cfg(cfg);
24
25    // S_0(x, y) = eq(x, y): identity shift.
26    if c == 0 {
27        return eval_eq_poly(x, y, &one);
28    }
29
30    // S_1(x, y) = next_mle(x, y): the successor predicate is exactly shift-by-1.
31    if c == 1 {
32        return next_mle_eval(x, y, zero, one);
33    }
34
35    assert!(c < (1usize << m), "shift c must satisfy c < 2^m");
36    // k = ceil(log2(2*c))
37    let k = (2 * c).next_power_of_two().trailing_zeros() as usize;
38    if k >= m {
39        return eval_shift_small(x, y, c, m, &zero, &one);
40    }
41
42    // LE convention: x[0..k] are the low bits, x[k..] are the high bits.
43    let (x_lo, x_hi) = x.split_at(k);
44    let (y_lo, y_hi) = y.split_at(k);
45
46    let l0 = eval_l0(x_lo, y_lo, c, k, &zero, &one);
47    let l1 = eval_l1(x_lo, y_lo, c, k, &zero, &one);
48    let eq = eval_eq_poly(x_hi, y_hi, &one);
49    let next = next_mle_eval(x_hi, y_hi, zero, one);
50    l0 * eq + l1 * next
51}
52
53/// `eq(u, v) = prod_i (u_i * v_i + (1 - u_i)(1 - v_i))`
54///
55/// Evaluates the Multilinear polynomial for eq polynomial
56pub(crate) fn eval_eq_poly<F: PrimeField>(u: &[F], v: &[F], one: &F) -> F {
57    u.iter()
58        .zip(v.iter())
59        .map(|(u_i, v_i)| u_i.clone() * v_i + (one.clone() - u_i) * (one.clone() - v_i))
60        .fold(one.clone(), |acc, term| acc * term)
61}
62
63/// `delta_{bin_k(a)}(u) = eq(u, bin_k(a))`.
64///
65/// Evaluates the Lagrange basis polynomial for the binary encoding of `a`
66/// with `k` bits at the point `u`.
67///
68/// LE convention: `u[i]` corresponds to bit `i` (LSB = index 0).
69pub(crate) fn eval_delta<F: PrimeField>(u: &[F], a: usize, k: usize, one: &F) -> F {
70    let mut result = one.clone();
71    for (i, u) in u.iter().take(k).enumerate() {
72        let bit = (a >> i) & 1;
73        if bit == 1 {
74            result *= u;
75        } else {
76            result *= one.clone() - u
77        }
78    }
79    result
80}
81
82/// `L_0^{(c)}(x_lo, y_lo)` — no-carry component.
83///
84/// `sum_{a=0}^{2^k - 1 - c} delta(x_lo, a) * delta(y_lo, a + c)`
85///
86/// On Booleans: 1 iff `Val(y_lo) = Val(x_lo) + c` with no carry into the high
87/// block.
88#[allow(clippy::arithmetic_side_effects)]
89pub(crate) fn eval_l0<F: PrimeField>(
90    x_lo: &[F],
91    y_lo: &[F],
92    c: usize,
93    k: usize,
94    zero: &F,
95    one: &F,
96) -> F {
97    let upper = (1 << k) - c;
98    (0..upper).fold(zero.clone(), |acc, a| {
99        acc + eval_delta(x_lo, a, k, one) * eval_delta(y_lo, a + c, k, one)
100    })
101}
102
103/// `L_1^{(c)}(x_lo, y_lo)` — carry component.
104///
105/// `sum_{a=2^k-c}^{2^k-1} delta(x_lo, a) * delta(y_lo, a + c - 2^k)`
106///
107/// On Booleans: 1 iff the addition carries into the high block.
108#[allow(clippy::arithmetic_side_effects)]
109pub(crate) fn eval_l1<F: PrimeField>(
110    x_lo: &[F],
111    y_lo: &[F],
112    c: usize,
113    k: usize,
114    zero: &F,
115    one: &F,
116) -> F {
117    let two_k = 1 << k;
118    ((two_k - c)..two_k).fold(zero.clone(), |acc, a| {
119        acc + eval_delta(x_lo, a, k, one) * eval_delta(y_lo, a + c - two_k, k, one)
120    })
121}
122
123/// Special case when `k >= m`: no high block, direct evaluation.
124///
125/// `sum_{a=0}^{n-1-c} delta(x, a, m) * delta(y, a+c, m)`
126#[allow(clippy::arithmetic_side_effects)]
127fn eval_shift_small<F: PrimeField>(x: &[F], y: &[F], c: usize, m: usize, zero: &F, one: &F) -> F {
128    let upper = (1 << m) - c;
129    (0..upper).fold(zero.clone(), |acc, a| {
130        acc + eval_delta(x, a, m, one) * eval_delta(y, a + c, m, one)
131    })
132}
133
134#[cfg(test)]
135mod tests {
136    use super::*;
137    use crate::test_utils::test_config;
138    use crypto_primitives::{Field, FromWithConfig, crypto_bigint_monty::MontyField};
139    use rand::Rng;
140    use zinc_poly::utils::{build_eq_x_r_inner, build_next_c_r_mle};
141
142    type F = MontyField<4>;
143
144    /// LE convention: to_bin(val, i) = bit i of val (LSB = index 0).
145    fn to_bin(val: usize, bit: usize, cfg: &<F as PrimeField>::Config) -> F {
146        if (val >> bit) & 1 == 1 {
147            F::one_with_cfg(cfg)
148        } else {
149            F::zero_with_cfg(cfg)
150        }
151    }
152
153    /// Convert F::Inner back to F.
154    fn from_inner(inner: <F as Field>::Inner, cfg: &<F as PrimeField>::Config) -> F {
155        let mut f = F::zero_with_cfg(cfg);
156        *f.inner_mut() = inner;
157        f
158    }
159
160    fn rand_field(rng: &mut impl Rng, cfg: &<F as PrimeField>::Config) -> F {
161        F::from_with_cfg(rng.random::<u32>(), cfg)
162    }
163
164    /// Check S_c on Boolean inputs: S_c(bin(a), bin(a+c)) = 1,
165    /// and S_c(bin(a), bin(b)) = 0 for b != a+c.
166    #[test]
167    fn test_shift_predicate_boolean() {
168        let cfg = test_config();
169        let m = 4;
170        let n = 1usize << m;
171
172        for c in [1, 2, 5] {
173            for a in 0..n {
174                for b in 0..n {
175                    let x: Vec<F> = (0..m).map(|i| to_bin(a, i, &cfg)).collect();
176                    let y: Vec<F> = (0..m).map(|i| to_bin(b, i, &cfg)).collect();
177                    let val = eval_shift_predicate(&x, &y, c, &cfg);
178
179                    if b == a + c && a + c < n {
180                        assert_eq!(
181                            val,
182                            F::one_with_cfg(&cfg),
183                            "S_c({a},{b}) should be 1 for c={c}"
184                        );
185                    } else {
186                        assert_eq!(
187                            val,
188                            F::zero_with_cfg(&cfg),
189                            "S_c({a},{b}) should be 0 for c={c}"
190                        );
191                    }
192                }
193            }
194        }
195    }
196
197    /// Verify the next_mle on all Boolean inputs.
198    #[test]
199    fn test_next_boolean() {
200        let cfg = test_config();
201        let m = 4;
202        let n = 1usize << m;
203        for a in 0..n {
204            for b in 0..n {
205                let u: Vec<F> = (0..m).map(|i| to_bin(a, i, &cfg)).collect();
206                let v: Vec<F> = (0..m).map(|i| to_bin(b, i, &cfg)).collect();
207                let val = next_mle_eval(&u, &v, F::zero_with_cfg(&cfg), F::one_with_cfg(&cfg));
208
209                if b == a + 1 && a + 1 < n {
210                    assert_eq!(val, F::one_with_cfg(&cfg), "Next({a},{b}) should be 1");
211                } else {
212                    assert_eq!(val, F::zero_with_cfg(&cfg), "Next({a},{b}) should be 0");
213                }
214            }
215        }
216    }
217
218    /// Check verifier (`eval_shift_predicate`) against prover
219    /// (`build_next_c_r_mle`) at Boolean points:
220    ///   eval_shift_predicate(r, bin(b), c) == build_next_c_r_mle(r, c)[b]
221    #[test]
222    fn test_shift_predicate_vs_prover_mle() {
223        let cfg = test_config();
224        let mut rng = rand::rng();
225        let m = 4;
226        let n = 1usize << m;
227        let c = 3;
228
229        let r: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
230        let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
231
232        for b in 0..n {
233            let b_bin: Vec<F> = (0..m).map(|i| to_bin(b, i, &cfg)).collect();
234            let val = eval_shift_predicate(&r, &b_bin, c, &cfg);
235            let expected = from_inner(next_c.evaluations[b], &cfg);
236            assert_eq!(val, expected, "S_{c}(r, bin({b})) mismatch with prover MLE");
237        }
238    }
239
240    /// Check at random field points via MLE summation:
241    ///   eval_shift_predicate(r, y, c) == sum_b build_next_c_r_mle(r, c)[b] *
242    /// eq(b, y)
243    #[test]
244    fn test_shift_predicate_random_points() {
245        let cfg = test_config();
246        let mut rng = rand::rng();
247        let m = 4;
248        let c = 3;
249
250        for _ in 0..8 {
251            let r: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
252            let y: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
253
254            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
255            let eq_y = build_eq_x_r_inner(&y, &cfg).unwrap();
256            let zero = F::zero_with_cfg(&cfg);
257            let rhs = next_c
258                .evaluations
259                .iter()
260                .zip(eq_y.evaluations.iter())
261                .fold(zero, |acc, (ni, ei)| {
262                    acc + from_inner(*ni, &cfg) * from_inner(*ei, &cfg)
263                });
264            let lhs = eval_shift_predicate(&r, &y, c, &cfg);
265
266            assert_eq!(lhs, rhs, "random-point MLE mismatch");
267        }
268    }
269
270    /// Test c=0 (identity) and c=1 (successor) fast paths at random points,
271    /// and verify predicate vs prover MLE consistency across multiple shift
272    /// amounts.
273    #[test]
274    fn test_fast_paths_and_multi_c() {
275        let cfg = test_config();
276        let mut rng = rand::rng();
277        let m = 4;
278        let n = 1usize << m;
279
280        for c in [0, 1, 2, 5, 7] {
281            let r: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
282            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
283
284            // Predicate vs prover MLE at Boolean y
285            for b in 0..n {
286                let b_bin: Vec<F> = (0..m).map(|i| to_bin(b, i, &cfg)).collect();
287                let val = eval_shift_predicate(&r, &b_bin, c, &cfg);
288                let expected = from_inner(next_c.evaluations[b], &cfg);
289                assert_eq!(val, expected, "S_{c}(r, bin({b})) mismatch with prover MLE");
290            }
291
292            // Predicate vs prover MLE at random y (MLE consistency)
293            for _ in 0..4 {
294                let y: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
295                let eq_y = build_eq_x_r_inner(&y, &cfg).unwrap();
296                let zero = F::zero_with_cfg(&cfg);
297                let rhs = next_c
298                    .evaluations
299                    .iter()
300                    .zip(eq_y.evaluations.iter())
301                    .fold(zero.clone(), |acc, (ni, ei)| {
302                        acc + from_inner(*ni, &cfg) * from_inner(*ei, &cfg)
303                    });
304                let lhs = eval_shift_predicate(&r, &y, c, &cfg);
305                assert_eq!(lhs, rhs, "random-point MLE mismatch for c={c}");
306            }
307        }
308    }
309
310    /// Boundary test: large c values where most rows shift beyond the domain.
311    #[test]
312    fn test_shift_predicate_boundary() {
313        let cfg = test_config();
314        let m = 3;
315        let n = 1usize << m; // 8
316
317        for c in [n / 2, n - 1] {
318            // Boolean correctness: S_c(bin(a), bin(b)) = 1 iff b == a+c < n
319            for a in 0..n {
320                for b in 0..n {
321                    let x: Vec<F> = (0..m).map(|i| to_bin(a, i, &cfg)).collect();
322                    let y: Vec<F> = (0..m).map(|i| to_bin(b, i, &cfg)).collect();
323                    let val = eval_shift_predicate(&x, &y, c, &cfg);
324
325                    if b == a + c && a + c < n {
326                        assert_eq!(
327                            val,
328                            F::one_with_cfg(&cfg),
329                            "S_{c}(bin({a}), bin({b})) should be 1"
330                        );
331                    } else {
332                        assert_eq!(
333                            val,
334                            F::zero_with_cfg(&cfg),
335                            "S_{c}(bin({a}), bin({b})) should be 0"
336                        );
337                    }
338                }
339            }
340
341            // Prover MLE: first c entries zero, rest match eq(r, b-c)
342            let mut rng = rand::rng();
343            let r: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
344            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
345            let zero_inner = *F::zero_with_cfg(&cfg).inner();
346
347            // First c entries must be zero
348            for b in 0..c {
349                assert_eq!(
350                    next_c.evaluations[b], zero_inner,
351                    "next_c[{b}] should be zero for c={c}"
352                );
353            }
354            // Remaining entries should be nonzero (with overwhelming probability)
355            let nonzero_count = next_c.evaluations[c..]
356                .iter()
357                .filter(|e| **e != zero_inner)
358                .count();
359            assert_eq!(
360                nonzero_count,
361                n - c,
362                "expected {} nonzero entries for c={c}",
363                n - c
364            );
365        }
366    }
367
368    /// Check that build_next_c_r_mle correctly reproduces MLE[shift_c(v)](r)
369    /// via inner product: sum_b next_c(b) * v[b] == sum_b eq(r, b-c) * v[b].
370    #[test]
371    fn test_prover_mle_inner_product() {
372        let cfg = test_config();
373        let mut rng = rand::rng();
374        let m = 4;
375        let n = 1usize << m;
376
377        for c in [1, 2, 3, 7] {
378            let v: Vec<F> = (0..n).map(|_| rand_field(&mut rng, &cfg)).collect();
379            let r: Vec<F> = (0..m).map(|_| rand_field(&mut rng, &cfg)).collect();
380
381            // Ground truth: sum_{b>=c} eq(r, b-c) * v[b]
382            let eq_r = build_eq_x_r_inner(&r, &cfg).unwrap();
383            let zero = F::zero_with_cfg(&cfg);
384            let expected = (c..n).fold(zero.clone(), |acc, b| {
385                acc + from_inner(eq_r.evaluations[b - c], &cfg) * &v[b]
386            });
387
388            // Via prover MLE: sum_b next_c[b] * v[b]
389            let next_c = build_next_c_r_mle(&r, c, &cfg).unwrap();
390            let got = next_c
391                .evaluations
392                .iter()
393                .zip(v.iter())
394                .fold(zero, |acc, (ni, vi)| {
395                    acc + vi.clone() * from_inner(*ni, &cfg)
396                });
397
398            assert_eq!(got, expected, "prover MLE inner product mismatch for c={c}");
399        }
400    }
401}