Skip to main content

zinc_utils/
inner_product.rs

1use crate::{from_ref::FromRef, mul_by_scalar::MulByScalar};
2use crypto_primitives::{FromWithConfig, PrimeField, boolean::Boolean};
3use num_traits::CheckedAdd;
4use thiserror::Error;
5
6/// A trait for inner product algorithms implementations.
7pub trait InnerProduct<Lhs: ?Sized, Rhs, Output> {
8    /// The main entry point for the inner product.
9    /// `CHECK` determines whether the implementation should check for overflow.
10    fn inner_product<const CHECK: bool>(
11        lhs: &Lhs,
12        rhs: &[Rhs],
13        zero: Output,
14    ) -> Result<Output, InnerProductError>;
15}
16
17#[derive(Clone, Debug, PartialEq, Error)]
18pub enum InnerProductError {
19    #[error("The length of LHS and RHS does not match: LHS={lhs}, RHS={rhs}")]
20    LengthMismatch { lhs: usize, rhs: usize },
21    #[error("Arithmetic overflow")]
22    Overflow,
23}
24
/// Inner product built on top of the `MulByScalar` and `CheckedAdd` traits.
///
/// Each term is produced by converting the left element into the output type
/// and multiplying it by the right element with `mul_by_scalar`. The terms are
/// then summed with either plain `add` or `checked_add`.
#[derive(Debug, Clone)]
pub struct MBSInnerProduct;
31
32impl<Lhs, Rhs, Out> InnerProduct<[Lhs], Rhs, Out> for MBSInnerProduct
33where
34    Out: FromRef<Lhs> + for<'a> MulByScalar<&'a Rhs> + CheckedAdd,
35{
36    /// The mul-by-scalar inner product.
37    #[allow(clippy::arithmetic_side_effects)] // Used in unchecked mode
38    fn inner_product<const CHECK: bool>(
39        lhs: &[Lhs],
40        rhs: &[Rhs],
41        zero: Out,
42    ) -> Result<Out, InnerProductError> {
43        if lhs.len() != rhs.len() {
44            return Err(InnerProductError::LengthMismatch {
45                lhs: lhs.len(),
46                rhs: rhs.len(),
47            });
48        }
49
50        lhs.iter().zip(rhs).try_fold(zero, |acc, (l, r)| {
51            let widened = Out::from_ref(l);
52            let product = widened
53                .mul_by_scalar::<CHECK>(r)
54                .ok_or(InnerProductError::Overflow)?;
55            if CHECK {
56                acc.checked_add(&product).ok_or(InnerProductError::Overflow)
57            } else {
58                Ok(acc + product)
59            }
60        })
61    }
62}
63
64impl MBSInnerProduct {
65    #[allow(clippy::arithmetic_side_effects)]
66    pub fn inner_product_field<Lhs, F>(
67        lhs: &[Lhs],
68        rhs: &[F],
69        zero: F,
70    ) -> Result<F, InnerProductError>
71    where
72        F: PrimeField + for<'a> FromWithConfig<&'a Lhs>,
73    {
74        if lhs.len() != rhs.len() {
75            return Err(InnerProductError::LengthMismatch {
76                lhs: lhs.len(),
77                rhs: rhs.len(),
78            });
79        }
80        let cfg = zero.cfg().clone();
81
82        Ok(lhs.iter().zip(rhs).fold(zero, |acc, (a, r)| {
83            let product: F = F::from_with_cfg(a, &cfg) * r;
84            acc + product
85        }))
86    }
87}
88
/// Inner product of length-1 vectors, i.e. a single scalar multiplication.
///
/// The lone left value is multiplied by the lone right component using
/// `mul_by_scalar`.
#[derive(Debug, Clone)]
pub struct ScalarProduct;
94
95impl<Lhs, Rhs, Out> InnerProduct<Lhs, Rhs, Out> for ScalarProduct
96where
97    Out: for<'a> MulByScalar<&'a Rhs> + FromRef<Lhs>,
98{
99    /// A scalar inner product. Assumes `Lhs` is a scalar type
100    /// and always asserts that `point` has only one component.
101    fn inner_product<const CHECK: bool>(
102        lhs: &Lhs,
103        point: &[Rhs],
104        _zero: Out,
105    ) -> Result<Out, InnerProductError> {
106        if point.as_ref().len() != 1 {
107            Err(InnerProductError::LengthMismatch {
108                lhs: 1,
109                rhs: point.as_ref().len(),
110            })
111        } else {
112            Ok(Out::from_ref(lhs)
113                .mul_by_scalar::<CHECK>(&point[0])
114                .ok_or(InnerProductError::Overflow)?)
115        }
116    }
117}
118
/// Inner product for a slice of `Boolean` flags against a slice of values.
///
/// Instead of multiplying, it sums (with `add` or `checked_add`) the RHS
/// elements whose matching flag in the boolean slice is `true`.
#[derive(Clone, Debug)]
pub struct BooleanInnerProductAdd;
123
124impl<Rhs: Clone, Out: FromRef<Rhs> + CheckedAdd> InnerProduct<[Boolean], Rhs, Out>
125    for BooleanInnerProductAdd
126{
127    /// Boolean inner product.
128    #[allow(clippy::arithmetic_side_effects)] // Used in unchecked mode
129    fn inner_product<const CHECK: bool>(
130        lhs: &[Boolean],
131        rhs: &[Rhs],
132        zero: Out,
133    ) -> Result<Out, InnerProductError> {
134        if lhs.len() != rhs.as_ref().len() {
135            return Err(InnerProductError::LengthMismatch {
136                lhs: lhs.len(),
137                rhs: rhs.as_ref().len(),
138            });
139        }
140
141        (0..lhs.len())
142            .filter(|&i| lhs[i].into_inner())
143            .try_fold(zero, |acc, i| {
144                let rhs = Out::from_ref(&rhs[i]);
145                if CHECK {
146                    acc.checked_add(&rhs).ok_or(InnerProductError::Overflow)
147                } else {
148                    Ok(acc + rhs)
149                }
150            })
151    }
152}
153
#[cfg(test)]
mod test {
    use crate::{CHECKED, UNCHECKED};
    use crypto_bigint::{U64, const_monty_params};
    use crypto_primitives::crypto_bigint_const_monty::ConstMontyField;
    use num_traits::ConstZero;

    use super::*;

    #[test]
    fn test_inner_product_basic() {
        let lhs = [1, 2, 3];
        let rhs = [4, 5, 6];
        assert_eq!(
            MBSInnerProduct::inner_product::<CHECKED>(&lhs, &rhs, 0),
            Ok(4 + 2 * 5 + 3 * 6)
        );
    }

    // Mismatched operand lengths must be reported, not panic.
    #[test]
    fn mbs_length_mismatch() {
        let lhs = [1, 2];
        let rhs = [3];
        assert_eq!(
            MBSInnerProduct::inner_product::<CHECKED>(&lhs, &rhs, 0),
            Err(InnerProductError::LengthMismatch { lhs: 2, rhs: 1 })
        );
    }

    #[test]
    fn scalar_product() {
        let lhs = 42i32;
        let rhs = 23i128;

        assert_eq!(
            ScalarProduct::inner_product::<CHECKED>(&lhs, &[rhs], 0).unwrap(),
            i128::from(lhs) * rhs
        )
    }

    // A "scalar" product with more than one RHS component is a length error.
    #[test]
    fn scalar_product_rejects_non_scalar_point() {
        let lhs = 42i32;
        assert_eq!(
            ScalarProduct::inner_product::<CHECKED>(&lhs, &[1i128, 2], 0i128),
            Err(InnerProductError::LengthMismatch { lhs: 1, rhs: 2 })
        );
    }

    #[test]
    fn boolean_checked_eq_mbs_inner_product() {
        let lhs = [
            Boolean::from(true),
            Boolean::from(false),
            Boolean::from(true),
            Boolean::from(true),
        ];
        let rhs = [1i128, 2, 3, 4];

        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, 0),
            MBSInnerProduct::inner_product::<CHECKED>(&rhs, &lhs, 0i128)
        );
    }

    #[test]
    fn boolean_length_mismatch() {
        let lhs = [Boolean::from(true)];
        let rhs = [1i128, 2];
        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, 0i128),
            Err(InnerProductError::LengthMismatch { lhs: 1, rhs: 2 })
        );
    }

    // In checked mode, summing past the output type's maximum is reported as Overflow.
    #[test]
    fn boolean_checked_reports_overflow() {
        let lhs = [Boolean::from(true), Boolean::from(true)];
        let rhs = [i128::MAX, 1];
        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, 0i128),
            Err(InnerProductError::Overflow)
        );
    }

    const_monty_params!(Params, U64, "0000000000000007");

    #[test]
    fn boolean_unchecked_eq_boolean_checked() {
        let lhs = [
            Boolean::from(true),
            Boolean::from(false),
            Boolean::from(true),
            Boolean::from(true),
        ];
        let rhs = [
            ConstMontyField::<Params, 1>::from(1),
            ConstMontyField::<Params, 1>::from(2),
            ConstMontyField::<Params, 1>::from(3),
            ConstMontyField::<Params, 1>::from(4),
        ];

        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, ConstMontyField::ZERO),
            BooleanInnerProductAdd::inner_product::<UNCHECKED>(&lhs, &rhs, ConstMontyField::ZERO)
        );
    }
}