Skip to main content

zip_plus/code/iprs/pntt/radix8/
params.rs

1use crate::ZipError;
2use ark_ff::{FftField, FpConfig};
3use num_traits::Euclid;
4use std::{fmt::Debug, marker::PhantomData};
5use zinc_utils::{mul, sub};
6
/// The signed integer type used to store pseudo NTT twiddles and base-matrix entries.
pub type PnttInt = i64;
9
10/// Configuration of radix-8 pseudo NTT.
11pub trait Config: Debug + Copy + PartialEq + Send + Sync {
12    /// The field used to generate the twiddle factors
13    /// and the base matrix for this pseudo NTT.
14    type Field: FftField;
15    const FIELD_MODULUS: u32;
16
17    /// The coefficients used to combine subresults.
18    /// They are the 8-th roots of unity from the field `Self::Field`
19    /// lifted to `PnttInt`.
20    const BASE_TWIDDLES: [PnttInt; 8];
21
22    /// A helper to get an integer representation that
23    /// lies in the range `[-(p - 1)/2, (p - 1)/2]` from a field element.
24    // TODO(alex): If we need it somewhere else, it would make sense to make this a
25    // property of the field.
26    fn field_to_int_normalized(x: Self::Field) -> PnttInt;
27}
28
29// Precomputed parameters needed for
30// pseudo NTT algorithm.
31#[derive(Clone, Debug, PartialEq, Eq)]
32pub struct Radix8PnttParams<C: Config> {
33    /// The length of the pseudo NTT's input.
34    pub row_len: usize,
35    /// The length of the pseudo NTT's output.
36    pub codeword_len: usize,
37    /// The number of steps where NTT is performed recursively.
38    pub depth: usize,
39    /// The number of columns of the base matrix.
40    pub base_len: usize,
41    /// The number of rows of the base matrix, always a power of 2.
42    pub base_dim: usize,
43    /// log2 of the number of rows of the base matrix.
44    pub base_dim_log2: u32,
45    /// The mask to compute `i % base_dim`.
46    pub base_dim_mask: usize,
47    /// The base matrix of the pseudo NTT.
48    pub base_matrix: Vec<Vec<PnttInt>>, // TODO(Alex): Maybe use DenseRowMatrix for this?
49    /// Precomputed twiddles for every stage that already contain the relevant
50    /// root-of-unity factor. This lets the butterfly apply a single
51    /// multiplication per term instead of two.
52    pub butterfly_twiddles: Vec<Vec<[[PnttInt; 8]; 7]>>,
53
54    _phantom: PhantomData<C>,
55}
56
57impl<C: Config> Radix8PnttParams<C> {
58    /// Precompute pseudo NTT parameters.
59    pub fn new(row_len: usize, depth: usize, rep_factor: usize) -> Result<Self, ZipError> {
60        let codeword_len = mul!(row_len, rep_factor);
61        if codeword_len >= C::FIELD_MODULUS as usize {
62            return Err(ZipError::InvalidPcsParam(
63                "Codeword length is more than the number of elements in the field".to_owned(),
64            ));
65        }
66
67        let coeff = 1_usize << mul!(3, depth);
68        let (base_len, base_len_rem) = row_len.div_rem_euclid(&coeff);
69        if base_len_rem != 0 {
70            return Err(ZipError::InvalidPcsParam(format!(
71                "Row length {row_len} must be a multiple of {coeff}"
72            )));
73        }
74
75        let base_dim = mul!(base_len, rep_factor);
76        if !base_dim.is_power_of_two() {
77            return Err(ZipError::InvalidPcsParam(format!(
78                "Base dimension {base_dim} must be a power of 2"
79            )));
80        }
81
82        let base_dim_log2: u32 = base_dim.trailing_zeros();
83
84        let base_dim_mask: usize = sub!(base_dim, 1);
85
86        Ok(Self {
87            row_len,
88            codeword_len,
89            depth,
90            base_len,
91            base_dim,
92            base_dim_log2,
93            base_dim_mask,
94            base_matrix: precompute::precompute_base_matrix::<C>(base_dim, base_len),
95            butterfly_twiddles: precompute::precompute_butterfly_twiddles::<C>(
96                base_dim,
97                codeword_len,
98                depth,
99            ),
100            _phantom: PhantomData,
101        })
102    }
103}
104
105mod precompute {
106    use super::{Config, PnttInt};
107    use ark_ff::Field;
108    use ark_poly::{EvaluationDomain, Radix2EvaluationDomain};
109    use itertools::Itertools;
110    use std::array;
111    use zinc_utils::mul;
112
113    #[allow(clippy::arithmetic_side_effects)]
114    pub(super) fn precompute_butterfly_twiddles<C: Config>(
115        base_dim: usize,
116        output_len: usize,
117        depth: usize,
118    ) -> Vec<Vec<[[PnttInt; 8]; 7]>> {
119        let roots_of_unity = precompute_roots_of_unity::<C>(output_len);
120
121        (0..depth)
122            .map(|k| {
123                let sub_chunk_length = base_dim * (1 << (3 * k));
124                let curr_prim_root_power = 1 << (3 * (depth - 1 - k));
125
126                (0..sub_chunk_length)
127                    .map(|i| {
128                        array::from_fn(|j_minus_1| {
129                            let root = roots_of_unity[curr_prim_root_power * i * (j_minus_1 + 1)];
130
131                            array::from_fn(|twiddle_idx| {
132                                mul_and_normalize_twiddle(
133                                    C::BASE_TWIDDLES[twiddle_idx],
134                                    root,
135                                    C::FIELD_MODULUS,
136                                )
137                            })
138                        })
139                    })
140                    .collect()
141            })
142            .collect()
143    }
144
145    pub(super) fn precompute_base_matrix<C: Config>(
146        base_dim: usize,
147        base_len: usize,
148    ) -> Vec<Vec<PnttInt>> {
149        let domain =
150            Radix2EvaluationDomain::<C::Field>::new(base_dim).expect("Failed to create NTT domain");
151
152        domain
153            .elements()
154            .map(|root| {
155                (0..base_len)
156                    .map(move |i| C::field_to_int_normalized(root.pow([i as u64])))
157                    .collect_vec()
158            })
159            .collect()
160    }
161
162    #[allow(clippy::arithmetic_side_effects)]
163    pub(super) fn precompute_roots_of_unity<C: Config>(n: usize) -> Vec<PnttInt> {
164        let domain =
165            Radix2EvaluationDomain::<C::Field>::new(n).expect("Failed to create NTT domain");
166
167        domain
168            .elements()
169            .map(C::field_to_int_normalized)
170            .collect_vec()
171    }
172
173    /// Field normalization for at most 32-bit fields.
174    /// Might have unpleasant overflows if used for bigger fields.
175    #[allow(clippy::arithmetic_side_effects, clippy::cast_possible_wrap)]
176    pub(super) fn normalize_field_element(x: u64, p: u32) -> i64 {
177        debug_assert!(x <= i64::MAX as u64);
178        let x = x as i64;
179        let p = i64::from(p);
180        if x >= (p - 1) / 2 { x - p } else { x }
181    }
182
183    #[allow(clippy::arithmetic_side_effects)]
184    fn mul_and_normalize_twiddle(twiddle: PnttInt, root: PnttInt, modulus: u32) -> PnttInt {
185        let twiddle_mod = to_positive_mod_repr(twiddle, modulus);
186        let root_mod = to_positive_mod_repr(root, modulus);
187        let product = mul!(twiddle_mod, root_mod) % u64::from(modulus);
188
189        normalize_field_element(product, modulus)
190    }
191
192    #[allow(clippy::cast_sign_loss)]
193    fn to_positive_mod_repr(value: PnttInt, modulus: u32) -> u64 {
194        value.rem_euclid(i64::from(modulus)) as u64
195    }
196}
197
198mod f65537 {
199    #![allow(non_local_definitions)]
200    use ark_ff::{Fp64, MontBackend, MontConfig};
201    #[derive(MontConfig)]
202    #[modulus = "65537"]
203    #[generator = "3"]
204    pub struct Config;
205
206    pub type Backend = MontBackend<Config, 1>;
207    pub type Field = Fp64<Backend>;
208
209    #[allow(clippy::cast_possible_truncation)] // We know modulus is small enough.
210    pub const MODULUS: u32 = Config::MODULUS.0[0] as u32;
211}
212
/// Radix-8 pseudo NTT configuration over the prime field with modulus
/// 65537 (`2^16 + 1`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PnttConfigF65537;
216
217impl Config for PnttConfigF65537 {
218    type Field = f65537::Field;
219    const FIELD_MODULUS: u32 = f65537::MODULUS;
220    const BASE_TWIDDLES: [PnttInt; 8] = [1, 4096, -256, 16, -1, -4096, 256, -16];
221
222    fn field_to_int_normalized(x: Self::Field) -> PnttInt {
223        let big_int = f65537::Backend::into_bigint(x);
224        precompute::normalize_field_element(big_int.0[0], Self::FIELD_MODULUS)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `C::BASE_TWIDDLES` equals the normalized elements of the
    /// size-8 evaluation domain of `C::Field`, i.e. the 8th roots of unity.
    fn check_twiddles_generic<C: Config>() {
        assert_eq!(
            precompute::precompute_roots_of_unity::<C>(8),
            C::BASE_TWIDDLES.to_vec()
        );
    }

    #[test]
    fn check_twiddles() {
        check_twiddles_generic::<PnttConfigF65537>();
    }
}