1use crate::{code::LinearCode, pcs::structs::ZipTypes, utils::shuffle_seeded};
2use crypto_primitives::PrimeField;
3use num_traits::CheckedAdd;
4use std::{fmt::Debug, marker::PhantomData, ops::AddAssign};
5use zinc_poly::ConstCoeffBitWidth;
6use zinc_utils::{add, from_ref::FromRef, mul};
7
/// Compile-time switches selecting how [`RaaCode`] carries out its
/// permute-and-accumulate passes.
pub trait RaaConfig: Send + Sync + Copy {
    /// When `true`, each permutation pass calls `shuffle_seeded` directly on the
    /// working codeword with the code's stored seed; when `false`, the pass
    /// builds a fresh vector from the code's precomputed permutation table via
    /// `clone_shuffled`.
    const PERMUTE_IN_PLACE: bool;

    /// When `true`, accumulation passes use the overflow-checked `accumulate`;
    /// when `false`, they use `accumulate_unchecked`.
    const CHECK_FOR_OVERFLOWS: bool;
}
16
17#[derive(Clone)]
20pub struct RaaCode<Zt: ZipTypes, Config: RaaConfig, const REP: usize> {
21 pub(crate) row_len: usize,
22 pub(crate) perm_1_seed: u64,
24
25 pub(crate) perm_2_seed: u64,
27
28 pub(crate) perm_1: Vec<usize>,
30
31 pub(crate) perm_2: Vec<usize>,
33
34 phantom: PhantomData<(Zt, Config)>,
35}
36
37impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> RaaCode<Zt, Config, REP> {
38 pub fn new(row_len: usize) -> Self {
39 assert!(
40 REP.is_power_of_two(),
41 "Repetition factor must be a power of two"
42 );
43 assert!(
44 row_len.is_power_of_two(),
45 "Row length must be a power of two"
46 );
47
48 let codeword_width_bits = {
53 let initial_bits =
54 u32::try_from(Zt::Eval::COEFF_BIT_WIDTH).expect("Size of EvalR type is too large");
55
56 let row_len_log = row_len.ilog2();
57 let rep_factor_log = REP.ilog2();
58 add!(
59 initial_bits,
60 add!(mul!(row_len_log, 2), mul!(rep_factor_log, 2))
61 )
62 };
63 let codeword_type_bits =
64 u32::try_from(Zt::Cw::COEFF_BIT_WIDTH).expect("Size of CwR type is too large");
65 assert!(
66 codeword_type_bits >= codeword_width_bits,
67 "Cannot fit {codeword_width_bits}-bit wide codeword entries in {} bits entries",
68 codeword_type_bits
69 );
70
71 const PERM_1_SEED: u64 = 1;
73 const PERM_2_SEED: u64 = 2;
74
75 let codeword_len = mul!(row_len, REP);
76
77 let mut perm_1: Vec<usize> = (0..codeword_len).collect();
78 shuffle_seeded(&mut perm_1, PERM_1_SEED);
79 let mut perm_2: Vec<usize> = (0..codeword_len).collect();
80 shuffle_seeded(&mut perm_2, PERM_2_SEED);
81
82 Self {
83 row_len,
84 perm_1_seed: PERM_1_SEED,
85 perm_2_seed: PERM_2_SEED,
86 perm_1,
87 perm_2,
88 phantom: PhantomData,
89 }
90 }
91
92 fn encode_inner<In, Out>(&self, row: &[In]) -> Vec<Out>
94 where
95 Out: CheckedAdd + for<'a> AddAssign<&'a Out> + FromRef<In> + Clone,
96 {
97 debug_assert_eq!(
98 row.len(),
99 self.row_len,
100 "Row length must match the code's row length"
101 );
102
103 let mut result: Vec<Out> = repeat(row, REP);
104 if Config::PERMUTE_IN_PLACE {
105 shuffle_seeded(&mut result, self.perm_1_seed);
106 } else {
107 result = clone_shuffled(&result, &self.perm_1);
108 }
109 if Config::CHECK_FOR_OVERFLOWS {
110 accumulate(&mut result);
111 } else {
112 accumulate_unchecked(&mut result);
113 }
114 if Config::PERMUTE_IN_PLACE {
115 shuffle_seeded(&mut result, self.perm_2_seed);
116 } else {
117 result = clone_shuffled(&result, &self.perm_2);
118 }
119 if Config::CHECK_FOR_OVERFLOWS {
120 accumulate(&mut result);
121 } else {
122 accumulate_unchecked(&mut result);
123 }
124 debug_assert_eq!(result.len(), self.codeword_len());
125 result
126 }
127}
128
129impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> LinearCode<Zt>
130 for RaaCode<Zt, Config, REP>
131{
132 const REPETITION_FACTOR: usize = REP;
133
134 fn row_len(&self) -> usize {
135 self.row_len
136 }
137
138 #[allow(clippy::arithmetic_side_effects)]
139 fn codeword_len(&self) -> usize {
140 self.row_len * REP
141 }
142
143 fn params_string(&self) -> String {
144 format!("row_len={}, rate=1/{REP}", self.row_len())
145 }
146
147 fn encode(&self, row: &[Zt::Eval]) -> Vec<Zt::Cw> {
148 self.encode_inner(row)
149 }
150
151 fn encode_wide(&self, row: &[Zt::CombR]) -> Vec<Zt::CombR> {
152 self.encode_inner(row)
153 }
154
155 fn encode_f<F>(&self, row: &[F]) -> Vec<F>
156 where
157 F: PrimeField + FromRef<F>,
158 {
159 self.encode_inner(row)
160 }
161}
162
163impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Debug for RaaCode<Zt, Config, REP> {
164 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165 f.debug_struct("RaaCode")
166 .field("row_len", &self.row_len)
167 .field("perm_1_seed", &self.perm_1_seed)
168 .field("perm_2_seed", &self.perm_2_seed)
169 .finish()
170 }
171}
172
173impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> PartialEq for RaaCode<Zt, Config, REP> {
174 fn eq(&self, other: &Self) -> bool {
175 self.row_len == other.row_len
176 && self.perm_1_seed == other.perm_1_seed
177 && self.perm_2_seed == other.perm_2_seed
178 }
179}
180
181impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Eq for RaaCode<Zt, Config, REP> {}
182
183#[allow(clippy::arithmetic_side_effects)]
185pub(crate) fn repeat<In, Out: FromRef<In> + Clone>(
186 input: &[In],
187 repetition_factor: usize,
188) -> Vec<Out> {
189 input
190 .iter()
191 .map(Out::from_ref)
192 .cycle()
193 .take(input.len() * repetition_factor)
194 .collect()
195}
196
197#[allow(clippy::arithmetic_side_effects)] pub(crate) fn accumulate<I>(input: &mut [I])
209where
210 I: CheckedAdd + Clone,
211{
212 if let Some(first) = input.first().cloned() {
213 let mut acc = first;
214 for curr in input.iter_mut().skip(1) {
215 acc = add!(*curr, acc, "Accumulation overflow");
216 *curr = acc.clone();
217 }
218 }
219}
220
/// Replaces every entry of `input` with the running (prefix) sum ending at it,
/// without any overflow check of its own.
///
/// Overflow behaviour is whatever `I`'s `AddAssign<&I>` implementation does;
/// use `accumulate` when overflow must be detected. Empty and single-element
/// slices are left unchanged.
#[allow(clippy::arithmetic_side_effects)]
pub(crate) fn accumulate_unchecked<I>(input: &mut [I])
where
    I: for<'a> AddAssign<&'a I> + Clone,
{
    // Walking the tail with an iterator is already free of bounds checks, so
    // the previous `get_unchecked` / `unsafe` indexing bought nothing.
    if let Some((first, rest)) = input.split_first_mut() {
        let mut acc = first.clone();
        for curr in rest {
            acc += &*curr;
            *curr = acc.clone();
        }
    }
}
237
/// Returns a new vector whose `j`-th entry is a clone of `data[perm[j]]`.
///
/// The output has `perm.len()` entries.
///
/// # Panics
///
/// Panics if any index in `perm` is out of bounds for `data`.
pub(crate) fn clone_shuffled<T>(data: &[T], perm: &[usize]) -> Vec<T>
where
    T: Clone,
{
    let mut shuffled = Vec::with_capacity(perm.len());
    for &source in perm {
        shuffled.push(data[source].clone());
    }
    shuffled
}
245
#[cfg(test)]
mod tests {
    use crypto_bigint::U64;
    use crypto_primitives::crypto_bigint_int::Int;
    use num_traits::Zero;

    use super::*;
    use crate::{code::LinearCode, pcs::test_utils::TestZipTypes, utils::shuffle_seeded};

    const REPETITION_FACTOR: usize = 4;

    const INT_LIMBS: usize = U64::LIMBS;

    const N: usize = INT_LIMBS;
    const K: usize = INT_LIMBS * 4;
    const M: usize = INT_LIMBS * 8;

    /// Test-only [`RaaConfig`] whose two switches are const generic parameters,
    /// so every flag combination can be instantiated.
    #[derive(Clone, Copy)]
    struct RaaConfigGeneric<const PERMUTE_IN_PLACE: bool, const CHECK_FOR_OVERFLOWS: bool>;

    impl<const PERMUTE_IN_PLACE: bool, const CHECK_FOR_OVERFLOWS: bool> RaaConfig
        for RaaConfigGeneric<PERMUTE_IN_PLACE, CHECK_FOR_OVERFLOWS>
    {
        const PERMUTE_IN_PLACE: bool = PERMUTE_IN_PLACE;
        const CHECK_FOR_OVERFLOWS: bool = CHECK_FOR_OVERFLOWS;
    }

    /// `test_raa!(Zt, row_len, f)` calls `f(&code)` once for each of the four
    /// `RaaConfigGeneric` flag combinations. The four-argument form builds a
    /// single code with an explicit config type.
    macro_rules! test_raa {
        ($zt:ty, $row_len:expr, $f:expr) => {
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<false, false>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<false, true>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<true, false>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<true, true>);
        };

        ($zt:ty, $row_len:expr, $f:expr, $config:ty) => {
            {
                let code = RaaCode::<$zt, $config, REPETITION_FACTOR>::new($row_len);
                $f(&code)
            }
        };
    }

    #[test]
    fn repeat_function_duplicates_row_correctly() {
        type I = Int<N>;
        let input = [10, 20].map(I::from);

        let repetition_factor = 3;

        let repeated_output = repeat::<_, I>(&input, repetition_factor);

        let expected_output: Vec<_> = [10, 20, 10, 20, 10, 20].into_iter().map(I::from).collect();
        assert_eq!(
            repeated_output, expected_output,
            "Failed on repetition factor > 1"
        );

        let empty_input: Vec<I> = vec![];
        let repeated_empty = repeat::<_, I>(&empty_input, 5);
        assert!(repeated_empty.is_empty(), "Failed on empty input vector");

        let repeated_once = repeat::<_, I>(&input, 1);
        assert_eq!(repeated_once, input, "Failed on repetition factor of 1");
    }

    #[test]
    fn accumulate_function_computes_cumulative_sum() {
        type I = Int<N>;

        // Runs both the checked and the unchecked accumulation on `input` and
        // compares each result with `expected`, so the two variants stay in sync.
        let check = |input: Vec<I>, expected: Vec<I>, case: &str| {
            let mut checked = input.clone();
            accumulate(&mut checked);
            assert_eq!(checked, expected, "accumulate: Failed on {case}");

            let mut unchecked = input;
            accumulate_unchecked(&mut unchecked);
            assert_eq!(unchecked, expected, "accumulate_unchecked: Failed on {case}");
        };

        check(
            [1, 2, 3, 4].into_iter().map(I::from).collect(),
            [1, 3, 6, 10].into_iter().map(I::from).collect(),
            "positive integers",
        );
        check(
            [5, 0, 2, 0].into_iter().map(I::from).collect(),
            [5, 5, 7, 7].into_iter().map(I::from).collect(),
            "vector with zeros",
        );
        check(
            [-1, 5, -10, 2].into_iter().map(I::from).collect(),
            [-1, 4, -6, -4].into_iter().map(I::from).collect(),
            "vector with negative numbers",
        );
        check(Vec::new(), Vec::new(), "empty vector");
    }

    #[test]
    fn shuffle_is_deterministic_for_a_given_seed() {
        type I = Int<N>;
        let original: Vec<I> = (1..=10).map(I::from).collect();
        let mut vec1 = original.clone();
        let mut vec2 = original.clone();
        let mut vec3 = original.clone();

        let seed1 = 12345;
        let seed2 = 54321;

        shuffle_seeded(&mut vec1, seed1);
        shuffle_seeded(&mut vec2, seed1);
        shuffle_seeded(&mut vec3, seed2);

        assert_eq!(
            vec1, vec2,
            "Shuffling with the same seed should produce the same result"
        );
        assert_ne!(
            vec1, vec3,
            "Shuffling with different seeds should produce different results"
        );
        assert_ne!(
            vec1, original,
            "Shuffled vector should not be the same as the original"
        );
        assert_ne!(
            vec3, original,
            "Shuffled vector should not be the same as the original"
        );
    }

    #[test]
    #[allow(clippy::arithmetic_side_effects)]
    fn encoding_preserves_linearity() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let a: Vec<Int<N>> = (1..=4).map(Int::<N>::from).collect();
            let b: Vec<Int<N>> = (5..=8).map(Int::<N>::from).collect();
            let sum_ab: Vec<Int<N>> = a.iter().zip(b.iter()).map(|(x, y)| *x + y).collect();

            let encode_a: Vec<Int<K>> = code.encode(&a);
            let encode_b: Vec<Int<K>> = code.encode(&b);
            let encode_sum_ab: Vec<Int<K>> = code.encode(&sum_ab);

            let sum_encode_ab: Vec<Int<K>> = encode_a
                .iter()
                .zip(encode_b.iter())
                .map(|(x, y)| *x + y)
                .collect();

            assert_eq!(encode_sum_ab, sum_encode_ab);
        });
    }

    #[test]
    fn encoding_produces_predictable_results() {
        let a: Vec<Int<N>> = (1..=4).map(Int::<N>::from).collect();

        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let encode_a: Vec<Int<K>> = code.encode(&a);
            assert_eq!(
                encode_a,
                [
                    0x1E, 0x36, 0x39, 0x5A, 0x70, 0x7E, 0xA5, 0xC1, 0xCB, 0xDC, 0xF9, 0x11E, 0x124,
                    0x14C, 0x14D, 0x160
                ]
                .map(Int::<K>::from)
            );
        });
    }

    #[test]
    fn encoding_zero_vector_results_in_zero_codeword() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let zero_vector: Vec<_> = vec![Int::<N>::zero(); code.row_len()];
            let encoded_vector: Vec<Int<K>> = code.encode(&zero_vector);

            let expected_codeword: Vec<Int<K>> = vec![Int::zero(); code.codeword_len()];

            assert_eq!(
                encoded_vector, expected_codeword,
                "Encoding a zero vector should result in a zero codeword"
            );
        });
    }

    #[test]
    fn in_place_permutation_should_not_affect_order() {
        let data: Vec<Int<N>> = (1..=1024).map(Int::<N>::from).collect();
        let row_len = data.len();
        let codeword_1: Vec<Int<K>> = {
            let code_in_place = RaaCode::<
                TestZipTypes<N, K, M>,
                RaaConfigGeneric<true, true>,
                REPETITION_FACTOR,
            >::new(row_len);
            code_in_place.encode(&data)
        };

        let codeword_2: Vec<Int<K>> = {
            let code_cloning = RaaCode::<
                TestZipTypes<N, K, M>,
                RaaConfigGeneric<false, true>,
                REPETITION_FACTOR,
            >::new(row_len);
            code_cloning.encode(&data)
        };
        assert_eq!(
            codeword_1, codeword_2,
            "In-place permutation should not affect the final codeword"
        );
    }

    #[test]
    #[should_panic(expected = "Cannot fit")]
    fn constructor_panics_on_insufficient_codeword_width() {
        // Eval and Cw are the same type here, so the required width (input
        // width plus accumulation growth) can never fit in Cw. Pinning the
        // message ensures the test fails if `new` panics for another reason.
        const N: usize = 1;
        const K: usize = 1;

        let _code = RaaCode::<
            TestZipTypes<N, K, N>,
            RaaConfigGeneric<false, true>,
            REPETITION_FACTOR,
        >::new(1 << 15);
    }

    #[test]
    #[should_panic(expected = "Row length must match the code's row length")]
    #[cfg(debug_assertions)]
    fn encode_panics_on_mismatched_row_length() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let incorrect_row = vec![Int::<N>::from(1), Int::<N>::from(2), Int::<N>::from(3)];
            let _: Vec<Int<K>> = code.encode(&incorrect_row);
        });
    }
}
479}