Skip to main content

zinc_poly/univariate/
nat_evaluation.rs

1use std::convert::identity;
2
3use crypto_primitives::{FromPrimitiveWithConfig, Semiring};
4
5use crate::{EvaluatablePolynomial, EvaluationError, Polynomial};
6
/// Univariate polynomial represented by its values at the natural numbers
/// `0, 1, 2, ...`.
///
/// `evaluations[i]` holds `P(i)`. A polynomial of degree `d` is determined by
/// `d + 1` consecutive values.
#[derive(Debug)]
#[derive(Clone, PartialEq)]
pub struct NatEvaluatedPoly<F> {
    /// The values `P(0), P(1), P(2), ...`, in that order.
    pub evaluations: Vec<F>,
}

impl<F> NatEvaluatedPoly<F> {
    /// Wraps `evaluations` (read as `P(0), P(1), ...`) without copying it.
    #[inline(always)]
    pub const fn new(evaluations: Vec<F>) -> Self {
        NatEvaluatedPoly { evaluations }
    }
}
20
21impl<F: Clone> Polynomial<F> for NatEvaluatedPoly<F> {
22    const DEGREE_BOUND: usize = usize::MAX;
23}
24
25impl<F: FromPrimitiveWithConfig> EvaluatablePolynomial<F, F> for NatEvaluatedPoly<F> {
26    type EvaluationPoint = F;
27
28    /// Interpolate the *unique* univariate polynomial of degree *at most*
29    /// `evaluations.len()-1` passing through the y-values in `evaluations` at x
30    /// = 0,..., evaluations.len()-1
31    /// and evaluate this  polynomial at `point`. In other words, efficiently
32    /// compute  $\sum_{i=0}^{len\ evaluations - 1} evaluations\[i\] *
33    /// (\prod_{j!=i} (\text{point} - j)/(i-j))$.
34    // All the arithmetic ops in the function
35    // are made sure to not overflow.
36    #[allow(clippy::arithmetic_side_effects, clippy::cast_possible_wrap)]
37    fn evaluate_at_point(&self, point: &Self::EvaluationPoint) -> Result<F, EvaluationError> {
38        let evaluations = &self.evaluations;
39        // TODO(Alex): Once we have benches, it's worth checking
40        //             if we're even winning anything
41        //             with specialized branches above.
42
43        // We will need these a few times
44        let point = point.clone();
45        let config = point.cfg();
46        let zero = F::zero_with_cfg(config);
47        let one = F::one_with_cfg(config);
48
49        let len = evaluations.len();
50
51        let mut evals = vec![];
52
53        let mut prod = point.clone();
54        evals.push(point.clone());
55
56        //`prod = \prod_{j} (x - j)`
57        // we return early if 0 <= x < len, i.e. if the desired value has been passed
58        let mut j = zero.clone();
59        for i in 1..len {
60            if point == j {
61                return Ok(evaluations[i - 1].clone());
62            }
63            j += &one;
64
65            let tmp = point.clone() - j.clone();
66            evals.push(tmp.clone());
67            prod *= tmp;
68        }
69
70        if point == j {
71            return Ok(evaluations[len - 1].clone());
72        }
73
74        let mut res = zero;
75        // we want to compute \prod (j!=i) (i-j) for a given i
76        //
77        // we start from the last step, which is
78        //  denom[len-1] = (len-1) * (len-2) *... * 2 * 1
79        // the step before that is
80        //  denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1
81        // and the step before that is
82        //  denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2
83        //
84        // i.e., for any i, the one before this will be derived from
85        //  denom[i-1] = - denom[i] * (len-i) / i
86        //
87        // that is, we only need to store
88        // - the last denom for i = len-1, and
89        // - the ratio between the current step and the last step, which is the product
90        //   of -(len-i) / i from all previous steps and we store this product as a
91        //   fraction number to reduce field divisions.
92
93        // We know
94        //  - 2^61 < factorial(20) < 2^62
95        //  - 2^122 < factorial(33) < 2^123
96        // so we will be able to compute the ratio
97        //  - for len <= 20 with i64
98        //  - for len <= 33 with i128
99        //  - for len >  33 with BigInt
100        if evaluations.len() <= 20 {
101            let last_denom: F = F::from_with_cfg(factorial(len - 1, identity), config);
102
103            let mut ratio_numerator = 1i64;
104            let mut ratio_denominator = 1u64;
105
106            for i in (0..len).rev() {
107                let ratio_numerator_f = F::from_with_cfg(ratio_numerator, config);
108
109                let ratio_denominator_f = F::from_with_cfg(ratio_denominator, config);
110
111                let x = prod.clone() * ratio_denominator_f
112                    / (last_denom.clone() * ratio_numerator_f * &evals[i]);
113
114                res += &(evaluations[i].clone() * x);
115
116                // compute ratio for the next step which is current_ratio * -(len-i)/i
117                if i != 0 {
118                    // Using intentionally, overflow isn't possible
119                    ratio_numerator *= -(len as i64 - i as i64);
120                    ratio_denominator *= i as u64;
121                }
122            }
123        } else if evaluations.len() <= 33 {
124            let last_denom = F::from_with_cfg(factorial(len - 1, u128::from), config);
125            let mut ratio_numerator = 1i128;
126            let mut ratio_denominator = 1u128;
127
128            for i in (0..len).rev() {
129                let ratio_numerator_f = F::from_with_cfg(ratio_numerator, config);
130
131                let ratio_denominator_f = F::from_with_cfg(ratio_denominator, config);
132
133                let x: F = prod.clone() * ratio_denominator_f
134                    / (last_denom.clone() * ratio_numerator_f * &evals[i]);
135                res += &(evaluations[i].clone() * x);
136
137                // compute ratio for the next step which is current_ratio * -(len-i)/i
138                if i != 0 {
139                    ratio_numerator *= -(len as i128 - i as i128);
140                    ratio_denominator *= i as u128;
141                }
142            }
143        } else {
144            // since we are using field operations, we can merge
145            // `last_denom` and `ratio_numerator` into a single field element.
146            let mut denom_up = factorial(len - 1, |u| F::from_with_cfg(u, config));
147            let mut denom_down = one;
148
149            for i in (0..len).rev() {
150                let x = prod.clone() * &denom_down / (denom_up.clone() * &evals[i]);
151                res += &(evaluations[i].clone() * x);
152
153                // compute denom for the next step is -current_denom * (len-i)/i
154                if i != 0 {
155                    let denom_up_factor = F::from_with_cfg((len - i) as u64, config);
156                    denom_up *= -denom_up_factor;
157
158                    let denom_down_factor = F::from_with_cfg(i as u64, config);
159                    denom_down *= denom_down_factor;
160                }
161            }
162        }
163
164        Ok(res)
165    }
166}
167
168/// Compute the factorial(a) = 1 * 2 * ... * a.
169#[allow(clippy::arithmetic_side_effects)]
170fn factorial<R, F>(a: usize, from_u64: F) -> R
171where
172    R: Semiring,
173    F: Fn(u64) -> R + Send + Sync,
174{
175    (1..=(a as u64))
176        .map(&from_u64)
177        .reduce(|mut acc, next| {
178            acc *= next;
179            acc
180        })
181        .unwrap_or(from_u64(1))
182}
183
#[cfg(test)]
mod tests {
    use crypto_bigint::{Odd, modular::MontyParams};
    use crypto_primitives::{FromWithConfig, crypto_bigint_monty::F256};
    use itertools::Itertools;

    use crate::{EvaluatablePolynomial, univariate::nat_evaluation::NatEvaluatedPoly};

    const LIMBS: usize = 4;
    type F = F256;

    fn test_config() -> MontyParams<LIMBS> {
        let modulus = crypto_bigint::Uint::<LIMBS>::from_be_hex(
            "0000000000000000000000000000000000860995AE68FC80E1B1BD1E39D54B33",
        );
        let modulus = Odd::new(modulus).expect("modulus should be odd");
        MontyParams::new(modulus)
    }

    /// Evaluations of `P(x) = x^2 + 3` at `x = 0, ..., len - 1`.
    fn quadratic_poly(len: i32) -> NatEvaluatedPoly<F> {
        NatEvaluatedPoly::new(
            (0..len)
                .map(|x| F::from_with_cfg(x * x + 3, &test_config()))
                .collect_vec(),
        )
    }

    #[test]
    fn evaluate_nat_evaluation() {
        let field_elem = F::from_with_cfg(100, &test_config());

        let poly = NatEvaluatedPoly::new(
            (0..1024)
                .map(|x| F::from_with_cfg(x, &test_config()))
                .collect_vec(),
        );

        let res = poly.evaluate_at_point(&field_elem).unwrap();

        assert_eq!(res, F::from_with_cfg(100, &test_config()));
    }

    /// A point outside `0..len` forces the Lagrange interpolation path. The
    /// lengths cover the i64 (len <= 20), i128 (len <= 33) and field-only
    /// (len > 33) branches, including both sides of each threshold.
    #[test]
    fn evaluate_off_grid_all_branches() {
        for len in [3i32, 5, 20, 21, 33, 34, 40] {
            let poly = quadratic_poly(len);
            let point = F::from_with_cfg(1000i32, &test_config());
            let res = poly.evaluate_at_point(&point).unwrap();
            assert_eq!(
                res,
                F::from_with_cfg(1_000_003i32, &test_config()),
                "len = {len}"
            );
        }
    }

    /// Points on the grid return the stored evaluation, including the last one.
    #[test]
    fn evaluate_on_grid_returns_stored_value() {
        let poly = quadratic_poly(5);
        for x in 0..5i32 {
            let point = F::from_with_cfg(x, &test_config());
            let res = poly.evaluate_at_point(&point).unwrap();
            assert_eq!(res, F::from_with_cfg(x * x + 3, &test_config()), "x = {x}");
        }
    }
}