Skip to main content

zip_plus/code/
raa_sign_flip.rs

1use super::raa::*;
2use crate::{code::LinearCode, pcs::structs::ZipTypes, utils::shuffle_seeded};
3use crypto_primitives::{PrimeField, Ring};
4use num_traits::{CheckedAdd, CheckedNeg};
5use std::{
6    fmt::Debug,
7    ops::{AddAssign, Neg},
8};
9use zinc_utils::{from_ref::FromRef, neg};
10
11/// Implementation of a repeat-accumulate-accumulate (RAA) codes.
12/// Flips signs of every second entry in the codeword, starting from the second
13/// one.
14#[derive(Clone)]
15pub struct RaaSignFlippingCode<Zt: ZipTypes, Config: RaaConfig, const REP: usize> {
16    raa: RaaCode<Zt, Config, REP>,
17}
18
19impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> RaaSignFlippingCode<Zt, Config, REP>
20where
21    Zt::Cw: Ring,
22{
23    pub fn new(row_len: usize) -> Self {
24        Self {
25            raa: RaaCode::new(row_len),
26        }
27    }
28
29    /// Do the actual encoding, as per RAA spec
30    fn encode_inner<In, Out>(&self, row: &[In]) -> Vec<Out>
31    where
32        Out: Neg<Output = Out>
33            + CheckedNeg
34            + CheckedAdd
35            + for<'a> AddAssign<&'a Out>
36            + FromRef<In>
37            + Clone,
38    {
39        debug_assert_eq!(
40            row.len(),
41            self.raa.row_len,
42            "Row length must match the code's row length"
43        );
44
45        let mut result: Vec<Out> = repeat(row, REP);
46        flip_even_signs(&mut result, Config::CHECK_FOR_OVERFLOWS);
47        if Config::PERMUTE_IN_PLACE {
48            shuffle_seeded(&mut result, self.raa.perm_1_seed);
49        } else {
50            result = clone_shuffled(&result, &self.raa.perm_1);
51        }
52        if Config::CHECK_FOR_OVERFLOWS {
53            accumulate(&mut result);
54        } else {
55            accumulate_unchecked(&mut result);
56        }
57        flip_even_signs(&mut result, Config::CHECK_FOR_OVERFLOWS);
58        if Config::PERMUTE_IN_PLACE {
59            shuffle_seeded(&mut result, self.raa.perm_2_seed);
60        } else {
61            result = clone_shuffled(&result, &self.raa.perm_2);
62        }
63        if Config::CHECK_FOR_OVERFLOWS {
64            accumulate(&mut result);
65        } else {
66            accumulate_unchecked(&mut result);
67        }
68        debug_assert_eq!(result.len(), self.codeword_len());
69        result
70    }
71}
72
73impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> LinearCode<Zt>
74    for RaaSignFlippingCode<Zt, Config, REP>
75where
76    Zt::Cw: Ring,
77{
78    const REPETITION_FACTOR: usize = REP;
79
80    fn row_len(&self) -> usize {
81        self.raa.row_len()
82    }
83
84    #[allow(clippy::arithmetic_side_effects)]
85    fn codeword_len(&self) -> usize {
86        self.raa.codeword_len()
87    }
88
89    fn params_string(&self) -> String {
90        self.raa.params_string()
91    }
92
93    fn encode(&self, row: &[Zt::Eval]) -> Vec<Zt::Cw> {
94        self.encode_inner(row)
95    }
96
97    fn encode_wide(&self, row: &[Zt::CombR]) -> Vec<Zt::CombR> {
98        self.encode_inner(row)
99    }
100
101    fn encode_f<F>(&self, row: &[F]) -> Vec<F>
102    where
103        F: PrimeField + FromRef<F>,
104    {
105        self.encode_inner(row)
106    }
107}
108
109impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Debug
110    for RaaSignFlippingCode<Zt, Config, REP>
111{
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("SignFlipping")
114            .field("row_len", &self.raa.row_len)
115            .field("perm_1_seed", &self.raa.perm_1_seed)
116            .field("perm_2_seed", &self.raa.perm_2_seed)
117            .finish()
118    }
119}
120
121impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> PartialEq
122    for RaaSignFlippingCode<Zt, Config, REP>
123{
124    fn eq(&self, other: &Self) -> bool {
125        self.raa == other.raa
126    }
127}
128
129impl<Zt: ZipTypes, Config: RaaConfig, const REP: usize> Eq
130    for RaaSignFlippingCode<Zt, Config, REP>
131{
132}
133
134/// Flip every other entry in the codeword, starting from the second one.
135fn flip_even_signs<Out>(result: &mut [Out], check_for_overflows: bool)
136where
137    Out: Neg<Output = Out> + CheckedNeg + Clone,
138{
139    if check_for_overflows {
140        flip_even_signs_checked(result);
141    } else {
142        flip_even_signs_unchecked(result);
143    }
144}
145
146fn flip_even_signs_checked<Out>(result: &mut [Out])
147where
148    Out: CheckedNeg + Clone,
149{
150    // Flip every other entry in the codeword
151    for i in (1..result.len()).step_by(2) {
152        result[i] = neg!(result[i]);
153    }
154}
155
/// Negates every entry at an odd index (the 2nd, 4th, … element) of `result`
/// with plain `Neg`, performing no overflow check.
fn flip_even_signs_unchecked<Out>(result: &mut [Out])
where
    Out: Neg<Output = Out> + Clone,
{
    for entry in result.iter_mut().skip(1).step_by(2) {
        *entry = entry.clone().neg();
    }
}