1use crypto_primitives::{FromWithConfig, PrimeField, boolean::Boolean};
2use itertools::Itertools;
3use num_traits::Zero;
4use zinc_poly::{
5 mle::DenseMultilinearExtension,
6 univariate::{binary::BinaryPoly, binary_ref::BinaryRefPoly, binary_u64::BinaryU64Poly},
7};
8use zinc_utils::{add, mul};
9
10pub trait FoldTrace<From, To> {
15 const FOLDING_FACTOR: usize;
17
18 fn fold_trace_mle(mle: &DenseMultilinearExtension<From>) -> DenseMultilinearExtension<To>;
19
20 fn fold_eval_claim<F, A>(
37 bar_u_coeffs: &[F],
38 alphas: &[A],
39 folding_challenges: &[F],
40 field_cfg: &F::Config,
41 ) -> F
42 where
43 F: PrimeField + for<'a> FromWithConfig<&'a A>,
44 {
45 debug_assert_eq!(
58 1usize << folding_challenges.len(),
59 Self::FOLDING_FACTOR,
60 "fold_eval_claim: 1 << folding_challenges.len() must equal FOLDING_FACTOR",
61 );
62 debug_assert!(
63 bar_u_coeffs.len() <= mul!(alphas.len(), Self::FOLDING_FACTOR),
64 "fold_eval_claim: bar_u_coeffs.len() must not exceed alphas.len() * FOLDING_FACTOR",
65 );
66
67 let chunk_size = alphas.len();
68 let alphas = alphas
69 .iter()
70 .map(|a| F::from_with_cfg(a, field_cfg))
71 .collect_vec();
72
73 let zero = F::zero_with_cfg(field_cfg);
74 let one = F::one_with_cfg(field_cfg);
75
76 let chunk_evals: Vec<F> = (0..Self::FOLDING_FACTOR)
80 .map(|i| {
81 let start = mul!(i, chunk_size);
82 let mut acc = zero.clone();
83 for (j, alpha) in alphas.iter().enumerate() {
84 if let Some(coeff) = bar_u_coeffs.get(add!(start, j)) {
85 acc += alpha.clone() * coeff;
86 }
87 }
88 acc
89 })
90 .collect();
91
92 mle_eval_msb_first(chunk_evals, folding_challenges, &one)
95 }
96}
97
98pub struct NoopFoldTrace;
103
104impl<T: Clone> FoldTrace<T, T> for NoopFoldTrace {
105 const FOLDING_FACTOR: usize = 1;
106
107 fn fold_trace_mle(mle: &DenseMultilinearExtension<T>) -> DenseMultilinearExtension<T> {
108 mle.clone()
109 }
110}
111
112pub struct FoldBinaryTrace2x<const D: usize, const HALF_D: usize>;
117
118impl<const D: usize, const HALF_D: usize> FoldTrace<BinaryPoly<D>, BinaryPoly<HALF_D>>
119 for FoldBinaryTrace2x<D, HALF_D>
120{
121 const FOLDING_FACTOR: usize = 2;
122
123 fn fold_trace_mle(
124 mle: &DenseMultilinearExtension<BinaryPoly<D>>,
125 ) -> DenseMultilinearExtension<BinaryPoly<HALF_D>> {
126 split_binary_poly_mle(mle)
127 }
128}
129
130pub struct FoldBinaryTrace4x<const D: usize, const HALF_D: usize, const QUARTER_D: usize>;
131
132impl<const D: usize, const HALF_D: usize, const QUARTER_D: usize>
133 FoldTrace<BinaryPoly<D>, BinaryPoly<QUARTER_D>> for FoldBinaryTrace4x<D, HALF_D, QUARTER_D>
134{
135 const FOLDING_FACTOR: usize = 4;
136
137 fn fold_trace_mle(
138 mle: &DenseMultilinearExtension<BinaryPoly<D>>,
139 ) -> DenseMultilinearExtension<BinaryPoly<QUARTER_D>> {
140 let mle = split_binary_poly_mle::<D, HALF_D>(mle);
141 split_binary_poly_mle::<HALF_D, QUARTER_D>(&mle)
142 }
143}
144
145fn split_binary_poly_mle<const D: usize, const HALF_D: usize>(
167 mle: &DenseMultilinearExtension<BinaryPoly<D>>,
168) -> DenseMultilinearExtension<BinaryPoly<HALF_D>> {
169 const {
170 assert!(D == 2 * HALF_D, "split_column: D must equal 2 * HALF_D");
171 }
172
173 #[cfg(not(feature = "simd"))]
174 let res = split_binary_poly_mle_ref(mle);
175
176 #[cfg(feature = "simd")]
177 let res = split_binary_poly_mle_u64(mle);
178
179 res
180}
181
182#[allow(dead_code)]
183fn split_binary_poly_mle_ref<const D: usize, const HALF_D: usize>(
184 mle: &DenseMultilinearExtension<BinaryRefPoly<D>>,
185) -> DenseMultilinearExtension<BinaryRefPoly<HALF_D>> {
186 let n = mle.evaluations.len();
187 let mut lo_evals = Vec::with_capacity(n);
188 let mut hi_evals = Vec::with_capacity(n);
189
190 for entry in &mle.evaluations {
191 let lo_arr: [Boolean; HALF_D] = std::array::from_fn(|i| entry[i]);
192 let hi_arr: [Boolean; HALF_D] = std::array::from_fn(|i| entry[add!(HALF_D, i)]);
193 lo_evals.push(BinaryRefPoly::<HALF_D>::new(lo_arr));
194 hi_evals.push(BinaryRefPoly::<HALF_D>::new(hi_arr));
195 }
196
197 lo_evals.extend(hi_evals);
199
200 DenseMultilinearExtension::from_evaluations_vec(add!(mle.num_vars, 1), lo_evals, Zero::zero())
201}
202
203#[allow(dead_code)]
204fn split_binary_poly_mle_u64<const D: usize, const HALF_D: usize>(
205 mle: &DenseMultilinearExtension<BinaryU64Poly<D>>,
206) -> DenseMultilinearExtension<BinaryU64Poly<HALF_D>> {
207 let n = mle.evaluations.len();
208 let mut lo_evals: Vec<BinaryU64Poly<HALF_D>> = Vec::with_capacity(n);
209 let mut hi_evals: Vec<BinaryU64Poly<HALF_D>> = Vec::with_capacity(n);
210
211 for entry in &mle.evaluations {
212 let bits: u64 = *entry.inner();
213 lo_evals.push(BinaryU64Poly::<HALF_D>::from(bits));
216 hi_evals.push(BinaryU64Poly::<HALF_D>::from(bits >> HALF_D));
217 }
218
219 lo_evals.extend(hi_evals);
222
223 DenseMultilinearExtension::from_evaluations_vec(add!(mle.num_vars, 1), lo_evals, Zero::zero())
224}
225
226fn mle_eval_msb_first<F: PrimeField>(values: Vec<F>, gammas: &[F], one: &F) -> F {
231 if gammas.is_empty() {
232 debug_assert_eq!(values.len(), 1);
233 return values.into_iter().next().expect("non-empty values");
234 }
235 debug_assert_eq!(values.len(), 1usize << gammas.len());
236
237 let half = values.len() >> 1;
238 let g = &gammas[0];
239 let one_minus_g = one.clone() - g;
240
241 let mut next: Vec<F> = Vec::with_capacity(half);
242 for i in 0..half {
243 let mut lo = one_minus_g.clone();
244 lo *= &values[i];
245 let mut hi = g.clone();
246 hi *= &values[add!(i, half)];
247 lo += &hi;
248 next.push(lo);
249 }
250 mle_eval_msb_first(next, &gammas[1..], one)
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::EitherOrBoth;
    use rand::{Rng, rng};
    use zinc_transcript::traits::GenTranscribable;

    /// Builds a `BinaryRefPoly` MLE and a `BinaryU64Poly` MLE from the same
    /// list of packed words, so both split implementations can be compared.
    fn build_matched_mles<const D: usize>(
        bits_list: &[u64],
    ) -> (
        DenseMultilinearExtension<BinaryRefPoly<D>>,
        DenseMultilinearExtension<BinaryU64Poly<D>>,
    ) {
        let len = bits_list.len();
        assert!(len.is_power_of_two(), "n must be a power of two");
        let num_vars = len.trailing_zeros() as usize;

        let mut ref_entries: Vec<BinaryRefPoly<D>> = Vec::with_capacity(len);
        let mut u64_entries: Vec<BinaryU64Poly<D>> = Vec::with_capacity(len);
        for &bits in bits_list {
            ref_entries.push(BinaryRefPoly::read_transcription_bytes_exact(
                &bits.to_le_bytes(),
            ));
            u64_entries.push(BinaryU64Poly::from(bits));
        }

        (
            DenseMultilinearExtension::from_evaluations_vec(num_vars, ref_entries, Zero::zero()),
            DenseMultilinearExtension::from_evaluations_vec(num_vars, u64_entries, Zero::zero()),
        )
    }

    /// Draws `len` random words and keeps only the bits selected by `mask`.
    fn random_masked<R: Rng>(source: &mut R, len: usize, mask: u64) -> Vec<u64> {
        (0..len).map(|_| source.random::<u64>() & mask).collect()
    }

    /// Splits the same input with both implementations and asserts that the
    /// results agree in shape and in every coefficient.
    fn assert_split_matches<const D: usize, const HALF_D: usize>(bits_list: Vec<u64>) {
        let (ref_mle, u64_mle) = build_matched_mles::<D>(&bits_list);

        let split_ref = split_binary_poly_mle_ref::<D, HALF_D>(&ref_mle);
        let split_u64 = split_binary_poly_mle_u64::<D, HALF_D>(&u64_mle);

        assert_eq!(split_ref.num_vars, split_u64.num_vars);
        assert_eq!(split_ref.evaluations.len(), split_u64.evaluations.len());

        let entry_pairs = split_ref
            .evaluations
            .iter()
            .zip(split_u64.evaluations.iter());
        for (idx, (r, u)) in entry_pairs.enumerate() {
            // First check: compare against the raw packed word.
            let packed = *u.inner();
            for i in 0..HALF_D {
                let expected = (packed >> i) & 1 != 0;
                assert_eq!(
                    r[i].inner(),
                    expected,
                    "mismatch at output entry {idx}, coefficient bit {i}",
                );
            }

            // Second check: compare through both coefficient iterators, which
            // also catches a difference in coefficient count.
            for (i, pair) in r.iter().zip_longest(u.iter()).enumerate() {
                let EitherOrBoth::Both(r_bit, u_bit) = pair else {
                    panic!("mismatch in number of coefficients at output entry {idx}");
                };
                assert_eq!(
                    r_bit.inner(),
                    *u_bit,
                    "mismatch at output entry {idx}, coefficient bit {i}",
                );
            }
        }
    }

    #[test]
    fn split_ref_and_u64_match_d4_exhaustive() {
        // Every possible 4-bit entry as a single-element table.
        (0u64..16).for_each(|bits| assert_split_matches::<4, 2>(vec![bits]));

        // Random four-entry tables of 4-bit values.
        let mut rng = rng();
        for _ in 0..32 {
            assert_split_matches::<4, 2>(random_masked(&mut rng, 4, 0xF));
        }
    }

    #[test]
    fn split_ref_and_u64_match_random() {
        let mut rng = rng();

        for n_log in 0..=3 {
            let bits_list = random_masked(&mut rng, 1usize << n_log, 0xF);
            assert_split_matches::<4, 2>(bits_list);
        }

        for n_log in 0..=4 {
            let bits_list = random_masked(&mut rng, 1usize << n_log, 0xFF);
            assert_split_matches::<8, 4>(bits_list);
        }

        for n_log in 0..=5 {
            let bits_list = random_masked(&mut rng, 1usize << n_log, 0xFFFF_FFFF);
            assert_split_matches::<32, 16>(bits_list);
        }

        for n_log in 0..=6 {
            let bits_list = random_masked(&mut rng, 1usize << n_log, u64::MAX);
            assert_split_matches::<64, 32>(bits_list);
        }
    }

    #[test]
    fn split_u64_pins_all_zero_entries() {
        let zeros = vec![0u64; 8];
        assert_split_matches::<8, 4>(zeros.clone());
        assert_split_matches::<32, 16>(zeros.clone());
        assert_split_matches::<64, 32>(zeros);
    }

    #[test]
    fn split_u64_handles_all_ones_d64() {
        assert_split_matches::<64, 32>(vec![u64::MAX; 4]);
    }
}