Skip to main content

zinc_protocol/
fold.rs

1use crypto_primitives::{FromWithConfig, PrimeField, boolean::Boolean};
2use itertools::Itertools;
3use num_traits::Zero;
4use zinc_poly::{
5    mle::DenseMultilinearExtension,
6    univariate::{binary::BinaryPoly, binary_ref::BinaryRefPoly, binary_u64::BinaryU64Poly},
7};
8use zinc_utils::{add, mul};
9
10/// Fold a trace of one type into a trace of another (similar but smaller) type.
11///
12/// Note that folding will increase number of variables for MLEs by
13/// `ilog2(Self::FOLDING_FACTOR)`.
14pub trait FoldTrace<From, To> {
15    /// Folding factor, a positive power of 2.
16    const FOLDING_FACTOR: usize;
17
18    fn fold_trace_mle(mle: &DenseMultilinearExtension<From>) -> DenseMultilinearExtension<To>;
19
20    /// Verifier-side: compute one column's contribution to the folded PCS
21    /// eval-claim at the extended evaluation point
22    /// `r_0 || gamma_1 || ... || gamma_k`, where `k = log2(FOLDING_FACTOR)`.
23    ///
24    /// # Inputs:
25    /// - `bar_u_coeffs`: the column's unfolded lifted-eval coefficients in
26    ///   `F_q[X]`, of formal length `D`. May be shorter than `D` if trailing
27    ///   zero coefficients were trimmed.
28    /// - `alphas`: per-poly PCS alphas, of length `D / FOLDING_FACTOR`,
29    ///   matching the folded entry size.
30    /// - `folding_challenges`: `k` field-typed challenges sampled by the
31    ///   verifier, in sampling order (`gamma_1` first, `gamma_k` last).
32    ///
33    /// Returns the value `<alphas, bar_u_folded>` where `bar_u_folded` is the
34    /// column's lifted-eval polynomial after applying `k` chained 2x splits
35    /// and pinning the appended boolean variables to `(gamma_1, ..., gamma_k)`.
36    fn fold_eval_claim<F, A>(
37        bar_u_coeffs: &[F],
38        alphas: &[A],
39        folding_challenges: &[F],
40        field_cfg: &F::Config,
41    ) -> F
42    where
43        F: PrimeField + for<'a> FromWithConfig<&'a A>,
44    {
45        // MSB-first contiguous chained-2x-split chunking.
46        //
47        // `bar_u_coeffs` is partitioned into `FOLDING_FACTOR` contiguous chunks of
48        // length `D / FOLDING_FACTOR`.
49        //
50        // Their inner products with `alphas` form a `k`-variable multilinear polynomial
51        // which is then evaluated at `(gamma_1, ..., gamma_k)` with `gamma_1` paired
52        // with the high bit of the chunk index.
53        //
54        // This covers `NoopFoldTrace` (k = 0, degenerating to a single inner product)
55        // as well as chained 2x folds.
56
57        debug_assert_eq!(
58            1usize << folding_challenges.len(),
59            Self::FOLDING_FACTOR,
60            "fold_eval_claim: 1 << folding_challenges.len() must equal FOLDING_FACTOR",
61        );
62        debug_assert!(
63            bar_u_coeffs.len() <= mul!(alphas.len(), Self::FOLDING_FACTOR),
64            "fold_eval_claim: bar_u_coeffs.len() must not exceed alphas.len() * FOLDING_FACTOR",
65        );
66
67        let chunk_size = alphas.len();
68        let alphas = alphas
69            .iter()
70            .map(|a| F::from_with_cfg(a, field_cfg))
71            .collect_vec();
72
73        let zero = F::zero_with_cfg(field_cfg);
74        let one = F::one_with_cfg(field_cfg);
75
76        // Step 1: Per-chunk inner products
77        //   P_i = sum_{j < chunk_size} alphas[j] * bar_u_coeffs[i*chunk_size + j],
78        // Trimmed (missing-trailing-zero) coefficients are treated as zero.
79        let chunk_evals: Vec<F> = (0..Self::FOLDING_FACTOR)
80            .map(|i| {
81                let start = mul!(i, chunk_size);
82                let mut acc = zero.clone();
83                for (j, alpha) in alphas.iter().enumerate() {
84                    if let Some(coeff) = bar_u_coeffs.get(add!(start, j)) {
85                        acc += alpha.clone() * coeff;
86                    }
87                }
88                acc
89            })
90            .collect();
91
92        // Step 2: MLE-evaluate the per-chunk inner products at folding_challenges,
93        // MSB-first (gamma_1 = high bit, peeled first).
94        mle_eval_msb_first(chunk_evals, folding_challenges, &one)
95    }
96}
97
98//
99// NOOP fold
100//
101
102pub struct NoopFoldTrace;
103
104impl<T: Clone> FoldTrace<T, T> for NoopFoldTrace {
105    const FOLDING_FACTOR: usize = 1;
106
107    fn fold_trace_mle(mle: &DenseMultilinearExtension<T>) -> DenseMultilinearExtension<T> {
108        mle.clone()
109    }
110}
111
112//
113// Binary folds
114//
115
116pub struct FoldBinaryTrace2x<const D: usize, const HALF_D: usize>;
117
118impl<const D: usize, const HALF_D: usize> FoldTrace<BinaryPoly<D>, BinaryPoly<HALF_D>>
119    for FoldBinaryTrace2x<D, HALF_D>
120{
121    const FOLDING_FACTOR: usize = 2;
122
123    fn fold_trace_mle(
124        mle: &DenseMultilinearExtension<BinaryPoly<D>>,
125    ) -> DenseMultilinearExtension<BinaryPoly<HALF_D>> {
126        split_binary_poly_mle(mle)
127    }
128}
129
130pub struct FoldBinaryTrace4x<const D: usize, const HALF_D: usize, const QUARTER_D: usize>;
131
132impl<const D: usize, const HALF_D: usize, const QUARTER_D: usize>
133    FoldTrace<BinaryPoly<D>, BinaryPoly<QUARTER_D>> for FoldBinaryTrace4x<D, HALF_D, QUARTER_D>
134{
135    const FOLDING_FACTOR: usize = 4;
136
137    fn fold_trace_mle(
138        mle: &DenseMultilinearExtension<BinaryPoly<D>>,
139    ) -> DenseMultilinearExtension<BinaryPoly<QUARTER_D>> {
140        let mle = split_binary_poly_mle::<D, HALF_D>(mle);
141        split_binary_poly_mle::<HALF_D, QUARTER_D>(&mle)
142    }
143}
144
145//
146// Helper functions
147//
148
149/// Split a column of `BinaryPoly<D>` entries into a concatenated column
150/// of `BinaryPoly<HALF_D>` entries.
151///
152/// Each entry `v[i]` with `D` binary coefficients is split into:
153/// - `u[i]` = low `HALF_D` coefficients (indices `0..HALF_D`)
154/// - `w[i]` = high `HALF_D` coefficients (indices `HALF_D..D`)
155///
156/// so that `v[i] = u[i] + X^HALF_D ยท w[i]`.
157///
158/// Returns a column of length `2n` where:
159/// - `v'[0..n]   = u[0..n]`  (low halves)
160/// - `v'[n..2n]  = w[0..n]`  (high halves)
161///
162/// The returned MLE has `num_vars + 1` variables, with the last variable
163/// selecting between the low half (0) and high half (1).
164///
165/// Panics at compile-time if `D != 2 * HALF_D`.
166fn split_binary_poly_mle<const D: usize, const HALF_D: usize>(
167    mle: &DenseMultilinearExtension<BinaryPoly<D>>,
168) -> DenseMultilinearExtension<BinaryPoly<HALF_D>> {
169    const {
170        assert!(D == 2 * HALF_D, "split_column: D must equal 2 * HALF_D");
171    }
172
173    #[cfg(not(feature = "simd"))]
174    let res = split_binary_poly_mle_ref(mle);
175
176    #[cfg(feature = "simd")]
177    let res = split_binary_poly_mle_u64(mle);
178
179    res
180}
181
182#[allow(dead_code)]
183fn split_binary_poly_mle_ref<const D: usize, const HALF_D: usize>(
184    mle: &DenseMultilinearExtension<BinaryRefPoly<D>>,
185) -> DenseMultilinearExtension<BinaryRefPoly<HALF_D>> {
186    let n = mle.evaluations.len();
187    let mut lo_evals = Vec::with_capacity(n);
188    let mut hi_evals = Vec::with_capacity(n);
189
190    for entry in &mle.evaluations {
191        let lo_arr: [Boolean; HALF_D] = std::array::from_fn(|i| entry[i]);
192        let hi_arr: [Boolean; HALF_D] = std::array::from_fn(|i| entry[add!(HALF_D, i)]);
193        lo_evals.push(BinaryRefPoly::<HALF_D>::new(lo_arr));
194        hi_evals.push(BinaryRefPoly::<HALF_D>::new(hi_arr));
195    }
196
197    // Concatenate: v' = u || w (low halves first, high halves second).
198    lo_evals.extend(hi_evals);
199
200    DenseMultilinearExtension::from_evaluations_vec(add!(mle.num_vars, 1), lo_evals, Zero::zero())
201}
202
203#[allow(dead_code)]
204fn split_binary_poly_mle_u64<const D: usize, const HALF_D: usize>(
205    mle: &DenseMultilinearExtension<BinaryU64Poly<D>>,
206) -> DenseMultilinearExtension<BinaryU64Poly<HALF_D>> {
207    let n = mle.evaluations.len();
208    let mut lo_evals: Vec<BinaryU64Poly<HALF_D>> = Vec::with_capacity(n);
209    let mut hi_evals: Vec<BinaryU64Poly<HALF_D>> = Vec::with_capacity(n);
210
211    for entry in &mle.evaluations {
212        let bits: u64 = *entry.inner();
213        // `From<u64>` masks off bits at positions `>= HALF_D` so each half
214        // upholds the `BinaryU64Poly<HALF_D>` invariant.
215        lo_evals.push(BinaryU64Poly::<HALF_D>::from(bits));
216        hi_evals.push(BinaryU64Poly::<HALF_D>::from(bits >> HALF_D));
217    }
218
219    // Concatenate: v' = u || w (low halves first, high halves second), matching
220    // the layout produced by `split_binary_poly_mle_ref`.
221    lo_evals.extend(hi_evals);
222
223    DenseMultilinearExtension::from_evaluations_vec(add!(mle.num_vars, 1), lo_evals, Zero::zero())
224}
225
226/// Multilinear evaluation of `values` (length `2^gammas.len()`, MSB-first
227/// indexed) at point `gammas`. Peels `gammas[0]` (the high bit, equivalently
228/// the first sampled challenge) at each recursive step, splitting `values`
229/// into a lower half (high bit = 0) and an upper half (high bit = 1).
230fn mle_eval_msb_first<F: PrimeField>(values: Vec<F>, gammas: &[F], one: &F) -> F {
231    if gammas.is_empty() {
232        debug_assert_eq!(values.len(), 1);
233        return values.into_iter().next().expect("non-empty values");
234    }
235    debug_assert_eq!(values.len(), 1usize << gammas.len());
236
237    let half = values.len() >> 1;
238    let g = &gammas[0];
239    let one_minus_g = one.clone() - g;
240
241    let mut next: Vec<F> = Vec::with_capacity(half);
242    for i in 0..half {
243        let mut lo = one_minus_g.clone();
244        lo *= &values[i];
245        let mut hi = g.clone();
246        hi *= &values[add!(i, half)];
247        lo += &hi;
248        next.push(lo);
249    }
250    mle_eval_msb_first(next, &gammas[1..], one)
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, rng};
    use zinc_transcript::traits::GenTranscribable;

    /// Draw `n` random `u64` words, each masked with `mask`.
    fn random_bits(rng: &mut impl Rng, n: usize, mask: u64) -> Vec<u64> {
        (0..n).map(|_| rng.random::<u64>() & mask).collect()
    }

    /// Build a `BinaryRefPoly<D>` MLE and a `BinaryU64Poly<D>` MLE that encode
    /// the same coefficients, one entry per bit-packed word in `bits_list`.
    fn build_matched_mles<const D: usize>(
        bits_list: &[u64],
    ) -> (
        DenseMultilinearExtension<BinaryRefPoly<D>>,
        DenseMultilinearExtension<BinaryU64Poly<D>>,
    ) {
        let n = bits_list.len();
        assert!(n.is_power_of_two(), "n must be a power of two");
        let num_vars = n.trailing_zeros() as usize;

        let (ref_entries, u64_entries): (Vec<BinaryRefPoly<D>>, Vec<BinaryU64Poly<D>>) = bits_list
            .iter()
            .map(|&bits| {
                (
                    BinaryRefPoly::<D>::read_transcription_bytes_exact(&bits.to_le_bytes()),
                    BinaryU64Poly::<D>::from(bits),
                )
            })
            .unzip();

        (
            DenseMultilinearExtension::from_evaluations_vec(num_vars, ref_entries, Zero::zero()),
            DenseMultilinearExtension::from_evaluations_vec(num_vars, u64_entries, Zero::zero()),
        )
    }

    /// Split matched inputs with both implementations and require every output
    /// coefficient to agree bit-for-bit.
    fn assert_split_matches<const D: usize, const HALF_D: usize>(bits_list: Vec<u64>) {
        let (ref_mle, u64_mle) = build_matched_mles::<D>(&bits_list);

        let split_ref = split_binary_poly_mle_ref::<D, HALF_D>(&ref_mle);
        let split_u64 = split_binary_poly_mle_u64::<D, HALF_D>(&u64_mle);

        assert_eq!(split_ref.num_vars, split_u64.num_vars);
        assert_eq!(split_ref.evaluations.len(), split_u64.evaluations.len());

        let pairs = split_ref
            .evaluations
            .iter()
            .zip(split_u64.evaluations.iter());
        for (idx, (r, u)) in pairs.enumerate() {
            // Check 1: index the reference entry and extract bits from the
            // packed word.
            let packed = *u.inner();
            for i in 0..HALF_D {
                let u_bit = (packed >> i) & 1 != 0;
                assert_eq!(
                    r[i].inner(),
                    u_bit,
                    "mismatch at output entry {idx}, coefficient bit {i}",
                );
            }

            // Check 2: walk both coefficient iterators in lockstep; they must
            // also yield the same number of coefficients.
            for (i, pair) in r.iter().zip_longest(u.iter()).enumerate() {
                let itertools::EitherOrBoth::Both(r_bit, u_bit) = pair else {
                    panic!("mismatch in number of coefficients at output entry {idx}");
                };
                assert_eq!(
                    r_bit.inner(),
                    *u_bit,
                    "mismatch at output entry {idx}, coefficient bit {i}",
                );
            }
        }
    }

    #[test]
    fn split_ref_and_u64_match_d4_exhaustive() {
        // Single entry: every one of the 16 possible 4-bit patterns.
        (0u64..16).for_each(|bits| assert_split_matches::<4, 2>(vec![bits]));

        // Four entries: the 16^4 = 65536 combinations are too many to
        // enumerate here, so sample 32 of them instead.
        let mut rng = rng();
        for _ in 0..32 {
            assert_split_matches::<4, 2>(random_bits(&mut rng, 4, 0xF));
        }
    }

    #[test]
    fn split_ref_and_u64_match_random() {
        let mut rng = rng();

        for n_log in 0..=3 {
            assert_split_matches::<4, 2>(random_bits(&mut rng, 1usize << n_log, 0xF));
        }

        for n_log in 0..=4 {
            assert_split_matches::<8, 4>(random_bits(&mut rng, 1usize << n_log, 0xFF));
        }

        for n_log in 0..=5 {
            assert_split_matches::<32, 16>(random_bits(&mut rng, 1usize << n_log, 0xFFFF_FFFF));
        }

        for n_log in 0..=6 {
            assert_split_matches::<64, 32>(random_bits(&mut rng, 1usize << n_log, u64::MAX));
        }
    }

    #[test]
    fn split_u64_pins_all_zero_entries() {
        // All-zero input: both splitters should agree (on all-zero halves)
        // for every D.
        let zeros = vec![0u64; 8];
        assert_split_matches::<8, 4>(zeros.clone());
        assert_split_matches::<32, 16>(zeros.clone());
        assert_split_matches::<64, 32>(zeros);
    }

    #[test]
    fn split_u64_handles_all_ones_d64() {
        // Every bit set is the edge case for D = 64: a `(1 << 64) - 1` mask
        // cannot be formed in a u64. Both halves are expected to be
        // 2^32 - 1 for D = 64, HALF_D = 32.
        assert_split_matches::<64, 32>(vec![u64::MAX; 4]);
    }
}
389        // 2^32 - 1 for D=64, HALF_D=32.
390        let bits_list = vec![u64::MAX; 4];
391        assert_split_matches::<64, 32>(bits_list);
392    }
393}