1mod pntt;
2
3use crate::{ZipError, code::LinearCode, pcs::structs::ZipTypes};
4use crypto_primitives::{FromPrimitiveWithConfig, FromWithConfig};
5use num_traits::{CheckedAdd, CheckedMul};
6use pntt::radix8::params::Config as PnttConfig;
7pub use pntt::radix8::params::{PnttConfigF65537, PnttInt, Radix8PnttParams};
8use std::{
9 fmt::Debug,
10 iter::Sum,
11 marker::PhantomData,
12 ops::{Add, AddAssign},
13};
14use zinc_utils::{from_ref::FromRef, mul_by_scalar::MulByScalar};
15
16#[derive(Clone)]
20pub struct IprsCode<Zt: ZipTypes, Config: PnttConfig, const REP: usize, const CHECK: bool> {
21 pntt_params: Radix8PnttParams<Config>,
22 _phantom: PhantomData<Zt>,
23}
24
25impl<Zt, Config, const REP: usize, const CHECK: bool> IprsCode<Zt, Config, REP, CHECK>
26where
27 Zt: ZipTypes,
28 Config: PnttConfig,
29{
30 pub fn new(row_len: usize, depth: usize) -> Result<Self, ZipError> {
31 Ok(Self {
34 pntt_params: Radix8PnttParams::new(row_len, depth, REP)?,
35 _phantom: Default::default(),
36 })
37 }
38
39 pub fn new_with_optimal_depth(row_len: usize) -> Result<Self, ZipError> {
44 const MAX_BASE_COLS_LOG2: usize = 8;
45
46 let target_base_len = 1 << MAX_BASE_COLS_LOG2;
47 let depth = 1.max(((1.max(row_len / target_base_len)).ilog2() as usize).div_ceil(3));
49
50 Self::new(row_len, depth)
51 }
52
53 fn encode_inner<In, Out>(&self, row: &[In]) -> Vec<Out>
55 where
56 In: for<'a> MulByScalar<&'a PnttInt, Out> + Clone + Send + Sync,
57 Out: CheckedAdd
58 + for<'a> AddAssign<&'a Out>
59 + for<'a> Add<&'a Out, Output = Out>
60 + CheckedMul
61 + for<'a> MulByScalar<&'a PnttInt>
62 + Sum
63 + FromRef<In>
64 + Clone
65 + Debug
66 + Send
67 + Sync,
68 {
69 assert_eq!(
70 row.len(),
71 self.pntt_params.row_len,
72 "Input length {} does not match expected row length {}",
73 row.len(),
74 self.pntt_params.row_len,
75 );
76
77 macro_rules! mul_fn {
78 () => {
79 |v, tw| {
80 v.mul_by_scalar::<CHECK>(tw)
81 .expect("Multiplication by twiddle should not overflow")
82 }
83 };
84 }
85
86 pntt::radix8::pntt::<_, _, _, CHECK>(row, &self.pntt_params, mul_fn!(), mul_fn!())
87 }
88
89 fn encode_inner_f<F>(&self, row: &[F]) -> Vec<F>
92 where
93 F: FromWithConfig<PnttInt> + FromRef<F>,
94 {
95 assert_eq!(
96 row.len(),
97 self.pntt_params.row_len,
98 "Input length {} does not match expected row length {}",
99 row.len(),
100 self.pntt_params.row_len,
101 );
102
103 let mul_fn = |f: &F, tw: &PnttInt| f.clone() * F::from_with_cfg(*tw, f.cfg());
104
105 pntt::radix8::pntt::<_, _, _, CHECK>(row, &self.pntt_params, mul_fn, mul_fn)
106 }
107}
108
109impl<Zt: ZipTypes, Config, const REP: usize, const CHECK: bool> LinearCode<Zt>
110 for IprsCode<Zt, Config, REP, CHECK>
111where
112 Zt: ZipTypes,
113 Config: PnttConfig,
114 Zt::Eval: for<'a> MulByScalar<&'a PnttInt, Zt::Cw>,
115 Zt::CombR: for<'a> MulByScalar<&'a PnttInt>,
116 Zt::Cw: CheckedAdd + for<'a> MulByScalar<&'a PnttInt>,
117{
118 const REPETITION_FACTOR: usize = REP;
119
120 fn encode(&self, row: &[Zt::Eval]) -> Vec<Zt::Cw> {
121 assert_eq!(
122 row.len(),
123 self.pntt_params.row_len,
124 "Input length {} does not match expected row length {}",
125 row.len(),
126 self.pntt_params.row_len,
127 );
128
129 self.encode_inner(row)
130 }
131
132 fn row_len(&self) -> usize {
133 self.pntt_params.row_len
134 }
135
136 fn codeword_len(&self) -> usize {
137 self.pntt_params.codeword_len
138 }
139
140 fn params_string(&self) -> String {
141 format!(
142 "row_len={}, rate=1/{REP}, depth={}",
143 self.row_len(),
144 self.pntt_params.depth
145 )
146 }
147
148 fn encode_wide(&self, row: &[Zt::CombR]) -> Vec<Zt::CombR> {
149 self.encode_inner(row)
150 }
151
152 fn encode_f<F>(&self, row: &[F]) -> Vec<F>
153 where
154 F: FromPrimitiveWithConfig + FromRef<F>,
155 {
156 self.encode_inner_f(row)
157 }
158}
159
160impl<Zt, Config, const REP: usize, const CHECK: bool> Debug for IprsCode<Zt, Config, REP, CHECK>
161where
162 Zt: ZipTypes,
163 Config: PnttConfig,
164{
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 f.debug_struct("IprsCode")
167 .field("pntt_params", &self.pntt_params)
168 .finish()
169 }
170}
171
172impl<Zt, Config, const REP: usize, const CHECK: bool> PartialEq for IprsCode<Zt, Config, REP, CHECK>
173where
174 Config: PnttConfig,
175 Zt: ZipTypes,
176{
177 fn eq(&self, other: &Self) -> bool {
178 self.pntt_params == other.pntt_params
179 }
180}
181
182impl<Zt, Config, const REP: usize, const CHECK: bool> Eq for IprsCode<Zt, Config, REP, CHECK>
183where
184 Zt: ZipTypes,
185 Config: PnttConfig,
186{
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pcs::{structs::ZipPlus, test_utils::*};
    use crypto_bigint::U64;
    use crypto_primitives::{
        FixedSemiring, boolean::Boolean, crypto_bigint_int::Int, crypto_bigint_uint::Uint,
    };
    use rand::{
        distr::{Distribution, StandardUniform},
        prelude::ThreadRng,
    };
    use zinc_poly::{
        mle::{DenseMultilinearExtension, MultilinearExtensionRand},
        univariate::{
            binary::{BinaryPoly, BinaryPolyInnerProduct},
            dense::{DensePolyInnerProduct, DensePolynomial},
        },
    };
    use zinc_primality::MillerRabin;
    use zinc_transcript::traits::ConstTranscribable;
    use zinc_utils::{
        CHECKED,
        inner_product::{MBSInnerProduct, ScalarProduct},
        named::Named,
    };

    const INT_LIMBS: usize = U64::LIMBS;
    const N: usize = INT_LIMBS;
    const K: usize = INT_LIMBS * 4;
    const M: usize = INT_LIMBS * 8;
    type Zt = TestZipTypes<N, K, M>;

    type Code = IprsCode<Zt, PnttConfigF65537, REP_FACTOR, CHECKED>;

    #[test]
    fn new_with_different_params() {
        assert!(Code::new(1, 0).is_ok());
        assert!(Code::new(8, 0).is_ok());
        assert!(Code::new(1, 1).is_err());
        assert!(Code::new(8, 1).is_ok());

        assert!(Code::new_with_optimal_depth(1).is_err());
        assert!(Code::new_with_optimal_depth(8).is_ok());
        assert!(Code::new_with_optimal_depth(12).is_err());
        assert!(Code::new_with_optimal_depth(16).is_ok());
    }

    /// Encodes the evaluations of a random `num_vars`-variate MLE with a rate-`1/REP`
    /// `IprsCode` whose row length is `2^num_vars`, via `ZipPlus::encode_rows`.
    ///
    /// The generic is named `Z` so it does not shadow the module-level `Zt` alias.
    fn do_encode<Z, const REP: usize>(num_vars: usize)
    where
        Z: ZipTypes,
        Z::Eval: for<'a> MulByScalar<&'a PnttInt, Z::Cw>,
        Z::CombR: for<'a> MulByScalar<&'a PnttInt>,
        Z::Cw: CheckedAdd + for<'a> MulByScalar<&'a PnttInt>,
        StandardUniform: Distribution<Z::Eval>,
    {
        let mut rng = ThreadRng::default();
        let poly_size: usize = 1 << num_vars;
        let mle = DenseMultilinearExtension::rand(num_vars, &mut rng);

        // Honor the caller's `REP`; previously a hard-coded `4` silently ignored it.
        let code = IprsCode::<Z, PnttConfigF65537, REP, CHECKED>::new_with_optimal_depth(poly_size)
            .unwrap();
        let pp = ZipPlus::setup(poly_size, code);
        ZipPlus::<Z, _>::encode_rows(&pp, &mle.evaluations);
    }

    #[test]
    fn encode_bench_int() {
        #[derive(Clone, Debug)]
        struct BenchZipTypes {}
        impl ZipTypes for BenchZipTypes {
            const NUM_COLUMN_OPENINGS: usize = 147;
            type Eval = i32;
            type Cw = i128;
            type Fmod = Uint<{ INT_LIMBS * 4 }>;
            type PrimeTest = MillerRabin;
            type Chal = i128;
            type Pt = i128;
            type CombR = Int<{ INT_LIMBS * 3 }>;
            type Comb = Self::CombR;
            type EvalDotChal = ScalarProduct;
            type CombDotChal = ScalarProduct;
            type ArrCombRDotChal = MBSInnerProduct;
        }

        do_encode::<BenchZipTypes, 4>(14);
    }

    #[test]
    fn encode_bench_poly() {
        const D_PLUS_ONE: usize = 32;

        #[derive(Clone, Debug)]
        struct BenchZipPlusTypes<CwCoeff>(PhantomData<CwCoeff>);
        impl<CwCoeff> ZipTypes for BenchZipPlusTypes<CwCoeff>
        where
            CwCoeff: ConstTranscribable
                + Copy
                + Default
                + FromRef<Boolean>
                + Named
                + FixedSemiring
                + Send
                + Sync,
            Int<5>: FromRef<CwCoeff>,
        {
            const NUM_COLUMN_OPENINGS: usize = 147;
            type Eval = BinaryPoly<D_PLUS_ONE>;
            type Cw = DensePolynomial<CwCoeff, D_PLUS_ONE>;
            type Fmod = Uint<{ INT_LIMBS * 4 }>;
            type PrimeTest = MillerRabin;
            type Chal = i128;
            type Pt = i128;
            type CombR = Int<{ INT_LIMBS * 5 }>;
            type Comb = DensePolynomial<Self::CombR, D_PLUS_ONE>;
            type EvalDotChal = BinaryPolyInnerProduct<Self::Chal, D_PLUS_ONE>;
            type CombDotChal = DensePolyInnerProduct<
                Self::CombR,
                Self::Chal,
                Self::CombR,
                MBSInnerProduct,
                D_PLUS_ONE,
            >;
            type ArrCombRDotChal = MBSInnerProduct;
        }

        do_encode::<BenchZipPlusTypes<i64>, 4>(14);
    }
}