Skip to main content

zip_plus/code/
iprs.rs

1mod pntt;
2
3use crate::{ZipError, code::LinearCode, pcs::structs::ZipTypes};
4use crypto_primitives::{FromPrimitiveWithConfig, FromWithConfig};
5use num_traits::{CheckedAdd, CheckedMul};
6use pntt::radix8::params::Config as PnttConfig;
7pub use pntt::radix8::params::{PnttConfigF65537, PnttInt, Radix8PnttParams};
8use std::{
9    fmt::Debug,
10    iter::Sum,
11    marker::PhantomData,
12    ops::{Add, AddAssign},
13};
14use zinc_utils::{from_ref::FromRef, mul_by_scalar::MulByScalar};
15
16/// Pseudo Reed-Solomon encoder over the integers. Internally uses a
17/// radix-8 NTT-style recursion with a base Vandermonde matrix sized
18/// `base_len x base_dim` (defaults to 64x32).
19#[derive(Clone)]
20pub struct IprsCode<Zt: ZipTypes, Config: PnttConfig, const REP: usize, const CHECK: bool> {
21    pntt_params: Radix8PnttParams<Config>,
22    _phantom: PhantomData<Zt>,
23}
24
25impl<Zt, Config, const REP: usize, const CHECK: bool> IprsCode<Zt, Config, REP, CHECK>
26where
27    Zt: ZipTypes,
28    Config: PnttConfig,
29{
30    pub fn new(row_len: usize, depth: usize) -> Result<Self, ZipError> {
31        // TODO(alex): Calculate max expected Zt::Cw::COEFF_BIT_WIDTH to ensure in
32        //             advance that the encoding will not overflow
33        Ok(Self {
34            pntt_params: Radix8PnttParams::new(row_len, depth, REP)?,
35            _phantom: Default::default(),
36        })
37    }
38
39    /// Create a new IPRS code with the optimal depth heuristics trying to keep
40    /// number of columns in the base matrix small.
41    /// Currently, keeps number of columns <= 2^8 but this might be tweaked in
42    /// the future.
43    pub fn new_with_optimal_depth(row_len: usize) -> Result<Self, ZipError> {
44        const MAX_BASE_COLS_LOG2: usize = 8;
45
46        let target_base_len = 1 << MAX_BASE_COLS_LOG2;
47        // We want depth to be at least 1.
48        let depth = 1.max(((1.max(row_len / target_base_len)).ilog2() as usize).div_ceil(3));
49
50        Self::new(row_len, depth)
51    }
52
53    /// Encode without modular reduction, purely over the integers.
54    fn encode_inner<In, Out>(&self, row: &[In]) -> Vec<Out>
55    where
56        In: for<'a> MulByScalar<&'a PnttInt, Out> + Clone + Send + Sync,
57        Out: CheckedAdd
58            + for<'a> AddAssign<&'a Out>
59            + for<'a> Add<&'a Out, Output = Out>
60            + CheckedMul
61            + for<'a> MulByScalar<&'a PnttInt>
62            + Sum
63            + FromRef<In>
64            + Clone
65            + Debug
66            + Send
67            + Sync,
68    {
69        assert_eq!(
70            row.len(),
71            self.pntt_params.row_len,
72            "Input length {} does not match expected row length {}",
73            row.len(),
74            self.pntt_params.row_len,
75        );
76
77        macro_rules! mul_fn {
78            () => {
79                |v, tw| {
80                    v.mul_by_scalar::<CHECK>(tw)
81                        .expect("Multiplication by twiddle should not overflow")
82                }
83            };
84        }
85
86        pntt::radix8::pntt::<_, _, _, CHECK>(row, &self.pntt_params, mul_fn!(), mul_fn!())
87    }
88
89    // Do the encoding but make use of the fact
90    // that we are dealing with a field.
91    fn encode_inner_f<F>(&self, row: &[F]) -> Vec<F>
92    where
93        F: FromWithConfig<PnttInt> + FromRef<F>,
94    {
95        assert_eq!(
96            row.len(),
97            self.pntt_params.row_len,
98            "Input length {} does not match expected row length {}",
99            row.len(),
100            self.pntt_params.row_len,
101        );
102
103        let mul_fn = |f: &F, tw: &PnttInt| f.clone() * F::from_with_cfg(*tw, f.cfg());
104
105        pntt::radix8::pntt::<_, _, _, CHECK>(row, &self.pntt_params, mul_fn, mul_fn)
106    }
107}
108
109impl<Zt: ZipTypes, Config, const REP: usize, const CHECK: bool> LinearCode<Zt>
110    for IprsCode<Zt, Config, REP, CHECK>
111where
112    Zt: ZipTypes,
113    Config: PnttConfig,
114    Zt::Eval: for<'a> MulByScalar<&'a PnttInt, Zt::Cw>,
115    Zt::CombR: for<'a> MulByScalar<&'a PnttInt>,
116    Zt::Cw: CheckedAdd + for<'a> MulByScalar<&'a PnttInt>,
117{
118    const REPETITION_FACTOR: usize = REP;
119
120    fn encode(&self, row: &[Zt::Eval]) -> Vec<Zt::Cw> {
121        assert_eq!(
122            row.len(),
123            self.pntt_params.row_len,
124            "Input length {} does not match expected row length {}",
125            row.len(),
126            self.pntt_params.row_len,
127        );
128
129        self.encode_inner(row)
130    }
131
132    fn row_len(&self) -> usize {
133        self.pntt_params.row_len
134    }
135
136    fn codeword_len(&self) -> usize {
137        self.pntt_params.codeword_len
138    }
139
140    fn params_string(&self) -> String {
141        format!(
142            "row_len={}, rate=1/{REP}, depth={}",
143            self.row_len(),
144            self.pntt_params.depth
145        )
146    }
147
148    fn encode_wide(&self, row: &[Zt::CombR]) -> Vec<Zt::CombR> {
149        self.encode_inner(row)
150    }
151
152    fn encode_f<F>(&self, row: &[F]) -> Vec<F>
153    where
154        F: FromPrimitiveWithConfig + FromRef<F>,
155    {
156        self.encode_inner_f(row)
157    }
158}
159
160impl<Zt, Config, const REP: usize, const CHECK: bool> Debug for IprsCode<Zt, Config, REP, CHECK>
161where
162    Zt: ZipTypes,
163    Config: PnttConfig,
164{
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        f.debug_struct("IprsCode")
167            .field("pntt_params", &self.pntt_params)
168            .finish()
169    }
170}
171
172impl<Zt, Config, const REP: usize, const CHECK: bool> PartialEq for IprsCode<Zt, Config, REP, CHECK>
173where
174    Config: PnttConfig,
175    Zt: ZipTypes,
176{
177    fn eq(&self, other: &Self) -> bool {
178        self.pntt_params == other.pntt_params
179    }
180}
181
182impl<Zt, Config, const REP: usize, const CHECK: bool> Eq for IprsCode<Zt, Config, REP, CHECK>
183where
184    Zt: ZipTypes,
185    Config: PnttConfig,
186{
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pcs::{structs::ZipPlus, test_utils::*};
    use crypto_bigint::U64;
    use crypto_primitives::{
        FixedSemiring, boolean::Boolean, crypto_bigint_int::Int, crypto_bigint_uint::Uint,
    };
    use rand::{
        distr::{Distribution, StandardUniform},
        prelude::ThreadRng,
    };
    use zinc_poly::{
        mle::{DenseMultilinearExtension, MultilinearExtensionRand},
        univariate::{
            binary::{BinaryPoly, BinaryPolyInnerProduct},
            dense::{DensePolyInnerProduct, DensePolynomial},
        },
    };
    use zinc_primality::MillerRabin;
    use zinc_transcript::traits::ConstTranscribable;
    use zinc_utils::{
        CHECKED,
        inner_product::{MBSInnerProduct, ScalarProduct},
        named::Named,
    };

    const INT_LIMBS: usize = U64::LIMBS;
    const N: usize = INT_LIMBS;
    const K: usize = INT_LIMBS * 4;
    const M: usize = INT_LIMBS * 8;
    type Zt = TestZipTypes<N, K, M>;

    type Code = IprsCode<Zt, PnttConfigF65537, REP_FACTOR, CHECKED>;

    /// Pins which `(row_len, depth)` combinations the constructors accept or reject.
    #[test]
    fn new_with_different_params() {
        assert!(Code::new(1, 0).is_ok());
        assert!(Code::new(8, 0).is_ok());
        assert!(Code::new(1, 1).is_err());
        assert!(Code::new(8, 1).is_ok());

        assert!(Code::new_with_optimal_depth(1).is_err());
        assert!(Code::new_with_optimal_depth(8).is_ok());
        assert!(Code::new_with_optimal_depth(12).is_err());
        assert!(Code::new_with_optimal_depth(16).is_ok());
    }

    /// Encodes the rows of a random MLE with `num_vars` variables.
    ///
    /// The code has repetition factor `REP` and checked twiddle multiplication.
    /// The test fails if encoding panics, e.g. on a reported twiddle-multiplication overflow.
    fn do_encode<Zt, const REP: usize>(num_vars: usize)
    where
        Zt: ZipTypes,
        Zt::Eval: for<'a> MulByScalar<&'a PnttInt, Zt::Cw>,
        Zt::CombR: for<'a> MulByScalar<&'a PnttInt>,
        Zt::Cw: CheckedAdd + for<'a> MulByScalar<&'a PnttInt>,
        StandardUniform: Distribution<Zt::Eval>,
    {
        let mut rng = ThreadRng::default();
        let poly_size: usize = 1 << num_vars;
        let mle = DenseMultilinearExtension::rand(num_vars, &mut rng);

        // Use the caller-supplied repetition factor; this was previously hard-coded to 4,
        // which silently ignored `REP`.
        let code =
            IprsCode::<Zt, PnttConfigF65537, REP, CHECKED>::new_with_optimal_depth(poly_size)
                .unwrap();
        let pp = ZipPlus::setup(poly_size, code);
        ZipPlus::<Zt, _>::encode_rows(&pp, &mle.evaluations);
    }

    /// Test the widest integer encoding used in benchmarks
    #[test]
    fn encode_bench_int() {
        #[derive(Clone, Debug)]
        struct BenchZipTypes {}
        impl ZipTypes for BenchZipTypes {
            const NUM_COLUMN_OPENINGS: usize = 147;
            type Eval = i32;
            type Cw = i128;
            type Fmod = Uint<{ INT_LIMBS * 4 }>;
            type PrimeTest = MillerRabin;
            type Chal = i128;
            type Pt = i128;
            type CombR = Int<{ INT_LIMBS * 3 }>;
            type Comb = Self::CombR;
            type EvalDotChal = ScalarProduct;
            type CombDotChal = ScalarProduct;
            type ArrCombRDotChal = MBSInnerProduct;
        }

        do_encode::<BenchZipTypes, 4>(14);
    }

    /// Test the widest binary polynomial encoding used in benchmarks
    #[test]
    fn encode_bench_poly() {
        const D_PLUS_ONE: usize = 32;

        #[derive(Clone, Debug)]
        struct BenchZipPlusTypes<CwCoeff>(PhantomData<CwCoeff>);
        impl<CwCoeff> ZipTypes for BenchZipPlusTypes<CwCoeff>
        where
            CwCoeff: ConstTranscribable
                + Copy
                + Default
                + FromRef<Boolean>
                + Named
                + FixedSemiring
                + Send
                + Sync,
            Int<5>: FromRef<CwCoeff>,
        {
            const NUM_COLUMN_OPENINGS: usize = 147;
            type Eval = BinaryPoly<D_PLUS_ONE>;
            type Cw = DensePolynomial<CwCoeff, D_PLUS_ONE>;
            type Fmod = Uint<{ INT_LIMBS * 4 }>;
            type PrimeTest = MillerRabin;
            type Chal = i128;
            type Pt = i128;
            type CombR = Int<{ INT_LIMBS * 5 }>;
            type Comb = DensePolynomial<Self::CombR, D_PLUS_ONE>;
            type EvalDotChal = BinaryPolyInnerProduct<Self::Chal, D_PLUS_ONE>;
            type CombDotChal = DensePolyInnerProduct<
                Self::CombR,
                Self::Chal,
                Self::CombR,
                MBSInnerProduct,
                D_PLUS_ONE,
            >;
            type ArrCombRDotChal = MBSInnerProduct;
        }

        do_encode::<BenchZipPlusTypes<i64>, 4>(14);
    }
}