Skip to main content

zip_plus/code/
raa.rs

1use crate::{code::LinearCode, pcs::structs::ZipTypes, utils::shuffle_seeded};
2use crypto_primitives::PrimeField;
3use num_traits::CheckedAdd;
4use std::{fmt::Debug, marker::PhantomData, ops::AddAssign};
5use zinc_poly::ConstCoeffBitWidth;
6use zinc_utils::{add, from_ref::FromRef, mul};
7
/// Compile-time switches that select how [`RaaCode`] carries out its encoding
/// steps. Implementors are zero-sized marker types.
pub trait RaaConfig: Copy + Send + Sync {
    /// `true`: apply each permutation by re-running the seeded shuffle directly
    /// on the codeword buffer.
    /// `false`: copy the codeword through a permutation table that was
    /// precomputed when the code was constructed.
    const PERMUTE_IN_PLACE: bool;

    /// `true`: accumulate with checked addition.
    /// `false`: accumulate with plain, unchecked `+=`.
    // TODO: Unify with `CHECK_FOR_OVERFLOW` in `zinc_poly`
    const CHECK_FOR_OVERFLOWS: bool;
}
16
17/// Implementation of a repeat-accumulate-accumulate (RAA) codes over the binary
18/// field, as defined by the Blaze paper (https://eprint.iacr.org/2024/1609)
19#[derive(Clone)]
20pub struct RaaCode<Zt: ZipTypes, Config: RaaConfig, const REP: usize> {
21    pub(crate) row_len: usize,
22    /// Randomness seed for the first permutation
23    pub(crate) perm_1_seed: u64,
24
25    /// Randomness seed for the second permutation
26    pub(crate) perm_2_seed: u64,
27
28    /// First permutation
29    pub(crate) perm_1: Vec<usize>,
30
31    /// Second permutation
32    pub(crate) perm_2: Vec<usize>,
33
34    phantom: PhantomData<(Zt, Config)>,
35}
36
37impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> RaaCode<Zt, Config, REP> {
38    pub fn new(row_len: usize) -> Self {
39        assert!(
40            REP.is_power_of_two(),
41            "Repetition factor must be a power of two"
42        );
43        assert!(
44            row_len.is_power_of_two(),
45            "Row length must be a power of two"
46        );
47
48        // Width of each entry in codeword vector, in bits.
49        // For RAA it's initial_bits + 2*log2(codeword_len),
50        // where codeword_len = row_len * REP and the factor of 2
51        // comes from the two accumulation steps.
52        let codeword_width_bits = {
53            let initial_bits =
54                u32::try_from(Zt::Eval::COEFF_BIT_WIDTH).expect("Size of EvalR type is too large");
55
56            let row_len_log = row_len.ilog2();
57            let rep_factor_log = REP.ilog2();
58            add!(
59                initial_bits,
60                add!(mul!(row_len_log, 2), mul!(rep_factor_log, 2))
61            )
62        };
63        let codeword_type_bits =
64            u32::try_from(Zt::Cw::COEFF_BIT_WIDTH).expect("Size of CwR type is too large");
65        assert!(
66            codeword_type_bits >= codeword_width_bits,
67            "Cannot fit {codeword_width_bits}-bit wide codeword entries in {} bits entries",
68            codeword_type_bits
69        );
70
71        // We don't need a secure/unpredictable randomness here, so use fixed seeds
72        const PERM_1_SEED: u64 = 1;
73        const PERM_2_SEED: u64 = 2;
74
75        let codeword_len = mul!(row_len, REP);
76
77        let mut perm_1: Vec<usize> = (0..codeword_len).collect();
78        shuffle_seeded(&mut perm_1, PERM_1_SEED);
79        let mut perm_2: Vec<usize> = (0..codeword_len).collect();
80        shuffle_seeded(&mut perm_2, PERM_2_SEED);
81
82        Self {
83            row_len,
84            perm_1_seed: PERM_1_SEED,
85            perm_2_seed: PERM_2_SEED,
86            perm_1,
87            perm_2,
88            phantom: PhantomData,
89        }
90    }
91
92    /// Do the actual encoding, as per RAA spec
93    fn encode_inner<In, Out>(&self, row: &[In]) -> Vec<Out>
94    where
95        Out: CheckedAdd + for<'a> AddAssign<&'a Out> + FromRef<In> + Clone,
96    {
97        debug_assert_eq!(
98            row.len(),
99            self.row_len,
100            "Row length must match the code's row length"
101        );
102
103        let mut result: Vec<Out> = repeat(row, REP);
104        if Config::PERMUTE_IN_PLACE {
105            shuffle_seeded(&mut result, self.perm_1_seed);
106        } else {
107            result = clone_shuffled(&result, &self.perm_1);
108        }
109        if Config::CHECK_FOR_OVERFLOWS {
110            accumulate(&mut result);
111        } else {
112            accumulate_unchecked(&mut result);
113        }
114        if Config::PERMUTE_IN_PLACE {
115            shuffle_seeded(&mut result, self.perm_2_seed);
116        } else {
117            result = clone_shuffled(&result, &self.perm_2);
118        }
119        if Config::CHECK_FOR_OVERFLOWS {
120            accumulate(&mut result);
121        } else {
122            accumulate_unchecked(&mut result);
123        }
124        debug_assert_eq!(result.len(), self.codeword_len());
125        result
126    }
127}
128
129impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> LinearCode<Zt>
130    for RaaCode<Zt, Config, REP>
131{
132    const REPETITION_FACTOR: usize = REP;
133
134    fn row_len(&self) -> usize {
135        self.row_len
136    }
137
138    #[allow(clippy::arithmetic_side_effects)]
139    fn codeword_len(&self) -> usize {
140        self.row_len * REP
141    }
142
143    fn params_string(&self) -> String {
144        format!("row_len={}, rate=1/{REP}", self.row_len())
145    }
146
147    fn encode(&self, row: &[Zt::Eval]) -> Vec<Zt::Cw> {
148        self.encode_inner(row)
149    }
150
151    fn encode_wide(&self, row: &[Zt::CombR]) -> Vec<Zt::CombR> {
152        self.encode_inner(row)
153    }
154
155    fn encode_f<F>(&self, row: &[F]) -> Vec<F>
156    where
157        F: PrimeField + FromRef<F>,
158    {
159        self.encode_inner(row)
160    }
161}
162
163impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Debug for RaaCode<Zt, Config, REP> {
164    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
165        f.debug_struct("RaaCode")
166            .field("row_len", &self.row_len)
167            .field("perm_1_seed", &self.perm_1_seed)
168            .field("perm_2_seed", &self.perm_2_seed)
169            .finish()
170    }
171}
172
173impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> PartialEq for RaaCode<Zt, Config, REP> {
174    fn eq(&self, other: &Self) -> bool {
175        self.row_len == other.row_len
176            && self.perm_1_seed == other.perm_1_seed
177            && self.perm_2_seed == other.perm_2_seed
178    }
179}
180
181impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Eq for RaaCode<Zt, Config, REP> {}
182
183/// Repeat the given slice N times, e.g `[1,2,3] => [1,2,3,1,2,3]`
184#[allow(clippy::arithmetic_side_effects)]
185pub(crate) fn repeat<In, Out: FromRef<In> + Clone>(
186    input: &[In],
187    repetition_factor: usize,
188) -> Vec<Out> {
189    input
190        .iter()
191        .map(Out::from_ref)
192        .cycle()
193        .take(input.len() * repetition_factor)
194        .collect()
195}
196
197/// Perform an operation equivalent to multiplying the slice in-place by the
198/// accumulation matrix from the RAA code - a lower triangular matrix of the
199/// appropriate size, i.e. a matrix looking like this:
200///
201/// ```text
202/// 1 0 0 0
203/// 1 1 0 0
204/// 1 1 1 0
205/// 1 1 1 1
206/// ```
207#[allow(clippy::arithmetic_side_effects)] // Clippy is too dumb to realize `i - 1` is safe here
208pub(crate) fn accumulate<I>(input: &mut [I])
209where
210    I: CheckedAdd + Clone,
211{
212    if let Some(first) = input.first().cloned() {
213        let mut acc = first;
214        for curr in input.iter_mut().skip(1) {
215            acc = add!(*curr, acc, "Accumulation overflow");
216            *curr = acc.clone();
217        }
218    }
219}
220
/// Unchecked counterpart of `accumulate`: replaces each element of `input` with
/// the sum of itself and every preceding element (inclusive prefix sum), using
/// plain `+=`.
///
/// Overflow is not detected here. In debug builds a primitive type's `+=` may
/// still panic; in release builds that depends on the type's `AddAssign`.
/// An empty slice is left untouched.
///
/// Iterating over the tail of the slice avoids bounds checks without the
/// `unsafe` `get_unchecked` calls used previously.
// Unchecked `+=` is the whole point of this variant; the encoder selects it when
// `RaaConfig::CHECK_FOR_OVERFLOWS` is `false`.
#[allow(clippy::arithmetic_side_effects)]
pub(crate) fn accumulate_unchecked<I>(input: &mut [I])
where
    I: for<'a> AddAssign<&'a I> + Clone,
{
    if let Some((first, rest)) = input.split_first_mut() {
        let mut acc = first.clone();
        for curr in rest {
            acc += &*curr;
            curr.clone_from(&acc);
        }
    }
}
237
/// Gathers `data` through a precomputed permutation: position `j` of the result
/// holds a clone of `data[perm[j]]`, so the output has `perm.len()` entries.
///
/// Panics if any index in `perm` is out of bounds for `data`.
pub(crate) fn clone_shuffled<T>(data: &[T], perm: &[usize]) -> Vec<T>
where
    T: Clone,
{
    perm.iter().map(|&src| &data[src]).cloned().collect()
}
245
#[cfg(test)]
mod tests {
    use crypto_bigint::U64;
    use crypto_primitives::crypto_bigint_int::Int;
    use num_traits::Zero;

    use super::*;
    use crate::{code::LinearCode, pcs::test_utils::TestZipTypes, utils::shuffle_seeded};

    const REPETITION_FACTOR: usize = 4;

    // Define common types for testing
    const INT_LIMBS: usize = U64::LIMBS;

    const N: usize = INT_LIMBS;
    const K: usize = INT_LIMBS * 4;
    const M: usize = INT_LIMBS * 8;

    /// `RaaConfig` whose two switches are const generics, so one test body can
    /// be run against every combination.
    #[derive(Clone, Copy)]
    struct RaaConfigGeneric<const PERMUTE_IN_PLACE: bool, const CHECK_FOR_OVERFLOWS: bool>;

    impl<const PERMUTE_IN_PLACE: bool, const CHECK_FOR_OVERFLOWS: bool> RaaConfig
        for RaaConfigGeneric<PERMUTE_IN_PLACE, CHECK_FOR_OVERFLOWS>
    {
        const PERMUTE_IN_PLACE: bool = PERMUTE_IN_PLACE;
        const CHECK_FOR_OVERFLOWS: bool = CHECK_FOR_OVERFLOWS;
    }

    /// Runs `$f` on a freshly built code: the first arm runs it once for every
    /// `RaaConfigGeneric` combination, the second for a single given config.
    macro_rules! test_raa {
        ($zt:ty, $row_len:expr, $f:expr) => {
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<false, false>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<false, true>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<true, false>);
            test_raa!($zt, $row_len, $f, RaaConfigGeneric<true, true>);
        };

        ($zt:ty, $row_len:expr, $f:expr, $config:ty) => {
            {
                let code = RaaCode::<$zt, $config, REPETITION_FACTOR>::new($row_len);
                $f(&code)
            }
        };
    }

    #[test]
    fn repeat_function_duplicates_row_correctly() {
        type I = Int<N>;
        let input = [10, 20].map(I::from);

        let repetition_factor = 3;

        let repeated_output = repeat::<_, I>(&input, repetition_factor);

        let expected_output: Vec<_> = [10, 20, 10, 20, 10, 20].into_iter().map(I::from).collect();
        assert_eq!(
            repeated_output, expected_output,
            "Failed on repetition factor > 1"
        );

        let empty_input: Vec<I> = vec![];
        let repeated_empty = repeat::<_, I>(&empty_input, 5);
        assert!(repeated_empty.is_empty(), "Failed on empty input vector");

        let repeated_once = repeat::<_, I>(&input, 1);
        assert_eq!(repeated_once, input, "Failed on repetition factor of 1");
    }

    #[test]
    fn accumulate_function_computes_cumulative_sum() {
        type I = Int<N>;

        /// Checks that both the checked and the unchecked accumulation turn
        /// `input` into `expected`.
        fn check(input: Vec<I>, expected: Vec<I>, what: &str) {
            let mut checked = input.clone();
            accumulate(&mut checked);
            assert_eq!(checked, expected, "accumulate: {what}");

            let mut unchecked = input;
            accumulate_unchecked(&mut unchecked);
            assert_eq!(unchecked, expected, "accumulate_unchecked: {what}");
        }

        check(
            [1, 2, 3, 4].into_iter().map(I::from).collect(),
            [1, 3, 6, 10].into_iter().map(I::from).collect(),
            "Failed on positive integers",
        );
        check(
            [5, 0, 2, 0].into_iter().map(I::from).collect(),
            [5, 5, 7, 7].into_iter().map(I::from).collect(),
            "Failed on vector with zeros",
        );
        check(
            [-1, 5, -10, 2].into_iter().map(I::from).collect(),
            [-1, 4, -6, -4].into_iter().map(I::from).collect(),
            "Failed on vector with negative numbers",
        );
        check(vec![], vec![], "Failed on empty vector");
    }

    #[test]
    fn shuffle_is_deterministic_for_a_given_seed() {
        type I = Int<N>;
        let original: Vec<I> = (1..=10).map(I::from).collect();
        let mut vec1 = original.clone();
        let mut vec2 = original.clone();
        let mut vec3 = original.clone();

        let seed1 = 12345;
        let seed2 = 54321;

        shuffle_seeded(&mut vec1, seed1);
        shuffle_seeded(&mut vec2, seed1);
        shuffle_seeded(&mut vec3, seed2);

        assert_eq!(
            vec1, vec2,
            "Shuffling with the same seed should produce the same result"
        );
        assert_ne!(
            vec1, vec3,
            "Shuffling with different seeds should produce different results"
        );
        assert_ne!(
            vec1, original,
            "Shuffled vector should not be the same as the original"
        );
        assert_ne!(
            vec3, original,
            "Shuffled vector should not be the same as the original"
        );
    }

    /// `RaaConfig::PERMUTE_IN_PLACE` relies on gathering through a permutation
    /// built by shuffling `0..len` giving the same result as shuffling the data
    /// itself with the same seed.
    #[test]
    fn clone_shuffled_matches_in_place_shuffle() {
        let data: Vec<usize> = (100..164).collect();

        let mut perm: Vec<usize> = (0..data.len()).collect();
        shuffle_seeded(&mut perm, 7);

        let mut in_place = data.clone();
        shuffle_seeded(&mut in_place, 7);

        assert_eq!(
            clone_shuffled(&data, &perm),
            in_place,
            "Precomputed permutation must match the in-place shuffle"
        );
    }

    #[test]
    #[allow(clippy::arithmetic_side_effects)] // False alert
    fn encoding_preserves_linearity() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let a: Vec<Int<N>> = (1..=4).map(Int::<N>::from).collect();
            let b: Vec<Int<N>> = (5..=8).map(Int::<N>::from).collect();
            let sum_ab: Vec<Int<N>> = a.iter().zip(b.iter()).map(|(x, y)| *x + y).collect();

            let encode_a: Vec<Int<K>> = code.encode(&a);
            let encode_b: Vec<Int<K>> = code.encode(&b);
            let encode_sum_ab: Vec<Int<K>> = code.encode(&sum_ab);

            let sum_encode_ab: Vec<Int<K>> = encode_a
                .iter()
                .zip(encode_b.iter())
                .map(|(x, y)| *x + y)
                .collect();

            assert_eq!(encode_sum_ab, sum_encode_ab);
        });
    }

    /// Since our shuffle seeds are fixed, we can test the encoding
    /// against a known output.
    #[test]
    fn encoding_produces_predictable_results() {
        let a: Vec<Int<N>> = (1..=4).map(Int::<N>::from).collect();

        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let encode_a: Vec<Int<K>> = code.encode(&a);
            assert_eq!(
                encode_a,
                [
                    0x1E, 0x36, 0x39, 0x5A, 0x70, 0x7E, 0xA5, 0xC1, 0xCB, 0xDC, 0xF9, 0x11E, 0x124,
                    0x14C, 0x14D, 0x160
                ]
                .map(Int::<K>::from)
            );
        });
    }

    #[test]
    fn encoding_zero_vector_results_in_zero_codeword() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let zero_vector: Vec<_> = vec![Int::<N>::zero(); code.row_len()];
            let encoded_vector: Vec<Int<K>> = code.encode(&zero_vector);

            let expected_codeword: Vec<Int<K>> = vec![Int::zero(); code.codeword_len()];

            assert_eq!(
                encoded_vector, expected_codeword,
                "Encoding a zero vector should result in a zero codeword"
            );
        });
    }

    #[test]
    fn in_place_permutation_should_not_affect_order() {
        let data: Vec<Int<N>> = (1..=1024).map(Int::<N>::from).collect();
        let row_len = data.len();
        let codeword_1: Vec<Int<K>> = {
            let code_in_place = RaaCode::<
                TestZipTypes<N, K, M>,
                RaaConfigGeneric<true, true>,
                REPETITION_FACTOR,
            >::new(row_len);
            code_in_place.encode(&data)
        };

        let codeword_2: Vec<Int<K>> = {
            let code_cloning = RaaCode::<
                TestZipTypes<N, K, M>,
                RaaConfigGeneric<false, true>,
                REPETITION_FACTOR,
            >::new(row_len);
            code_cloning.encode(&data)
        };
        assert_eq!(
            codeword_1, codeword_2,
            "In-place permutation should not affect the final codeword"
        );
    }

    /// A 1-limb codeword type cannot hold entries that grow by
    /// `2 * log2(2^15 * 4)` bits, so construction must fail on the width check
    /// specifically (not on some unrelated panic).
    #[test]
    #[should_panic(expected = "Cannot fit")]
    fn constructor_panics_on_insufficient_codeword_width() {
        const N: usize = 1;
        const K: usize = 1;

        let _code = RaaCode::<
            TestZipTypes<N, K, N>,
            RaaConfigGeneric<false, true>,
            REPETITION_FACTOR,
        >::new(1 << 15);
    }

    #[test]
    #[should_panic(expected = "Row length must match the code's row length")]
    #[cfg(debug_assertions)]
    fn encode_panics_on_mismatched_row_length() {
        test_raa!(TestZipTypes<N, K, M>, 4, |code: &RaaCode<_, _, _>| {
            let incorrect_row = vec![Int::<N>::from(1), Int::<N>::from(2), Int::<N>::from(3)];
            let _: Vec<Int<K>> = code.encode(&incorrect_row);
        });
    }
}