zip_plus/code/iprs/pntt/radix8/
params.rs1use crate::ZipError;
2use ark_ff::{FftField, FpConfig};
3use num_traits::Euclid;
4use std::{fmt::Debug, marker::PhantomData};
5use zinc_utils::{mul, sub};
6
/// Signed integer type in which pseudo-NTT table entries (normalized roots of unity and
/// twiddles) are stored.
pub type PnttInt = i64;

10pub trait Config: Debug + Copy + PartialEq + Send + Sync {
12 type Field: FftField;
15 const FIELD_MODULUS: u32;
16
17 const BASE_TWIDDLES: [PnttInt; 8];
21
22 fn field_to_int_normalized(x: Self::Field) -> PnttInt;
27}
28
29#[derive(Clone, Debug, PartialEq, Eq)]
32pub struct Radix8PnttParams<C: Config> {
33 pub row_len: usize,
35 pub codeword_len: usize,
37 pub depth: usize,
39 pub base_len: usize,
41 pub base_dim: usize,
43 pub base_dim_log2: u32,
45 pub base_dim_mask: usize,
47 pub base_matrix: Vec<Vec<PnttInt>>, pub butterfly_twiddles: Vec<Vec<[[PnttInt; 8]; 7]>>,
53
54 _phantom: PhantomData<C>,
55}
56
57impl<C: Config> Radix8PnttParams<C> {
58 pub fn new(row_len: usize, depth: usize, rep_factor: usize) -> Result<Self, ZipError> {
60 let codeword_len = mul!(row_len, rep_factor);
61 if codeword_len >= C::FIELD_MODULUS as usize {
62 return Err(ZipError::InvalidPcsParam(
63 "Codeword length is more than the number of elements in the field".to_owned(),
64 ));
65 }
66
67 let coeff = 1_usize << mul!(3, depth);
68 let (base_len, base_len_rem) = row_len.div_rem_euclid(&coeff);
69 if base_len_rem != 0 {
70 return Err(ZipError::InvalidPcsParam(format!(
71 "Row length {row_len} must be a multiple of {coeff}"
72 )));
73 }
74
75 let base_dim = mul!(base_len, rep_factor);
76 if !base_dim.is_power_of_two() {
77 return Err(ZipError::InvalidPcsParam(format!(
78 "Base dimension {base_dim} must be a power of 2"
79 )));
80 }
81
82 let base_dim_log2: u32 = base_dim.trailing_zeros();
83
84 let base_dim_mask: usize = sub!(base_dim, 1);
85
86 Ok(Self {
87 row_len,
88 codeword_len,
89 depth,
90 base_len,
91 base_dim,
92 base_dim_log2,
93 base_dim_mask,
94 base_matrix: precompute::precompute_base_matrix::<C>(base_dim, base_len),
95 butterfly_twiddles: precompute::precompute_butterfly_twiddles::<C>(
96 base_dim,
97 codeword_len,
98 depth,
99 ),
100 _phantom: PhantomData,
101 })
102 }
103}
104
105mod precompute {
106 use super::{Config, PnttInt};
107 use ark_ff::Field;
108 use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
109 use itertools::Itertools;
110 use std::array;
111 use zinc_utils::mul;
112
113 #[allow(clippy::arithmetic_side_effects)]
114 pub(super) fn precompute_butterfly_twiddles<C: Config>(
115 base_dim: usize,
116 output_len: usize,
117 depth: usize,
118 ) -> Vec<Vec<[[PnttInt; 8]; 7]>> {
119 let roots_of_unity = precompute_roots_of_unity::<C>(output_len);
120
121 (0..depth)
122 .map(|k| {
123 let sub_chunk_length = base_dim * (1 << (3 * k));
124 let curr_prim_root_power = 1 << (3 * (depth - 1 - k));
125
126 (0..sub_chunk_length)
127 .map(|i| {
128 array::from_fn(|j_minus_1| {
129 let root = roots_of_unity[curr_prim_root_power * i * (j_minus_1 + 1)];
130
131 array::from_fn(|twiddle_idx| {
132 mul_and_normalize_twiddle(
133 C::BASE_TWIDDLES[twiddle_idx],
134 root,
135 C::FIELD_MODULUS,
136 )
137 })
138 })
139 })
140 .collect()
141 })
142 .collect()
143 }
144
145 pub(super) fn precompute_base_matrix<C: Config>(
146 base_dim: usize,
147 base_len: usize,
148 ) -> Vec<Vec<PnttInt>> {
149 let domain =
150 Radix2EvaluationDomain::<C::Field>::new(base_dim).expect("Failed to create NTT domain");
151
152 domain
153 .elements()
154 .map(|root| {
155 (0..base_len)
156 .map(move |i| C::field_to_int_normalized(root.pow([i as u64])))
157 .collect_vec()
158 })
159 .collect()
160 }
161
162 #[allow(clippy::arithmetic_side_effects)]
163 pub(super) fn precompute_roots_of_unity<C: Config>(n: usize) -> Vec<PnttInt> {
164 let domain =
165 Radix2EvaluationDomain::<C::Field>::new(n).expect("Failed to create NTT domain");
166
167 domain
168 .elements()
169 .map(C::field_to_int_normalized)
170 .collect_vec()
171 }
172
173 #[allow(clippy::arithmetic_side_effects, clippy::cast_possible_wrap)]
176 pub(super) fn normalize_field_element(x: u64, p: u32) -> i64 {
177 debug_assert!(x <= i64::MAX as u64);
178 let x = x as i64;
179 let p = i64::from(p);
180 if x >= (p - 1) / 2 { x - p } else { x }
181 }
182
183 #[allow(clippy::arithmetic_side_effects)]
184 fn mul_and_normalize_twiddle(twiddle: PnttInt, root: PnttInt, modulus: u32) -> PnttInt {
185 let twiddle_mod = to_positive_mod_repr(twiddle, modulus);
186 let root_mod = to_positive_mod_repr(root, modulus);
187 let product = mul!(twiddle_mod, root_mod) % u64::from(modulus);
188
189 normalize_field_element(product, modulus)
190 }
191
192 #[allow(clippy::cast_sign_loss)]
193 fn to_positive_mod_repr(value: PnttInt, modulus: u32) -> u64 {
194 value.rem_euclid(i64::from(modulus)) as u64
195 }
196}
197
198mod f65537 {
199 #![allow(non_local_definitions)]
200 use ark_ff::{Fp64, MontBackend, MontConfig};
201 #[derive(MontConfig)]
202 #[modulus = "65537"]
203 #[generator = "3"]
204 pub struct Config;
205
206 pub type Backend = MontBackend<Config, 1>;
207 pub type Field = Fp64<Backend>;
208
209 #[allow(clippy::cast_possible_truncation)] pub const MODULUS: u32 = Config::MODULUS.0[0] as u32;
211}
212
/// [`Config`] for the radix-8 pseudo-NTT over the prime field of order 65537.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PnttConfigF65537;

217impl Config for PnttConfigF65537 {
218 type Field = f65537::Field;
219 const FIELD_MODULUS: u32 = f65537::MODULUS;
220 const BASE_TWIDDLES: [PnttInt; 8] = [1, 4096, -256, 16, -1, -4096, 256, -16];
221
222 fn field_to_int_normalized(x: Self::Field) -> PnttInt {
223 let big_int = f65537::Backend::into_bigint(x);
224 precompute::normalize_field_element(big_int.0[0], Self::FIELD_MODULUS)
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that the hard-coded `C::BASE_TWIDDLES` equal the normalized elements of the
    /// size-8 radix-2 domain of `C::Field`, in domain order.
    fn assert_twiddles_match_domain<C: Config>() {
        let from_domain = precompute::precompute_roots_of_unity::<C>(8);
        let hardcoded = Vec::from(C::BASE_TWIDDLES);
        assert_eq!(from_domain, hardcoded);
    }

    #[test]
    fn check_twiddles() {
        assert_twiddles_match_domain::<PnttConfigF65537>();
    }
}