1use crate::{from_ref::FromRef, mul_by_scalar::MulByScalar};
2use crypto_primitives::{FromWithConfig, PrimeField, boolean::Boolean};
3use num_traits::CheckedAdd;
4use thiserror::Error;
5
6pub trait InnerProduct<Lhs: ?Sized, Rhs, Output> {
8 fn inner_product<const CHECK: bool>(
11 lhs: &Lhs,
12 rhs: &[Rhs],
13 zero: Output,
14 ) -> Result<Output, InnerProductError>;
15}
16
17#[derive(Clone, Debug, PartialEq, Error)]
18pub enum InnerProductError {
19 #[error("The length of LHS and RHS does not match: LHS={lhs}, RHS={rhs}")]
20 LengthMismatch { lhs: usize, rhs: usize },
21 #[error("Arithmetic overflow")]
22 Overflow,
23}
24
/// Inner-product strategy that widens each left-hand element into the output
/// type via `FromRef` and scales it by the matching right-hand element via
/// `MulByScalar` before summing.
#[derive(Debug, Clone)]
pub struct MBSInnerProduct;
31
32impl<Lhs, Rhs, Out> InnerProduct<[Lhs], Rhs, Out> for MBSInnerProduct
33where
34 Out: FromRef<Lhs> + for<'a> MulByScalar<&'a Rhs> + CheckedAdd,
35{
36 #[allow(clippy::arithmetic_side_effects)] fn inner_product<const CHECK: bool>(
39 lhs: &[Lhs],
40 rhs: &[Rhs],
41 zero: Out,
42 ) -> Result<Out, InnerProductError> {
43 if lhs.len() != rhs.len() {
44 return Err(InnerProductError::LengthMismatch {
45 lhs: lhs.len(),
46 rhs: rhs.len(),
47 });
48 }
49
50 lhs.iter().zip(rhs).try_fold(zero, |acc, (l, r)| {
51 let widened = Out::from_ref(l);
52 let product = widened
53 .mul_by_scalar::<CHECK>(r)
54 .ok_or(InnerProductError::Overflow)?;
55 if CHECK {
56 acc.checked_add(&product).ok_or(InnerProductError::Overflow)
57 } else {
58 Ok(acc + product)
59 }
60 })
61 }
62}
63
64impl MBSInnerProduct {
65 #[allow(clippy::arithmetic_side_effects)]
66 pub fn inner_product_field<Lhs, F>(
67 lhs: &[Lhs],
68 rhs: &[F],
69 zero: F,
70 ) -> Result<F, InnerProductError>
71 where
72 F: PrimeField + for<'a> FromWithConfig<&'a Lhs>,
73 {
74 if lhs.len() != rhs.len() {
75 return Err(InnerProductError::LengthMismatch {
76 lhs: lhs.len(),
77 rhs: rhs.len(),
78 });
79 }
80 let cfg = zero.cfg().clone();
81
82 Ok(lhs.iter().zip(rhs).fold(zero, |acc, (a, r)| {
83 let product: F = F::from_with_cfg(a, &cfg) * r;
84 acc + product
85 }))
86 }
87}
88
/// Degenerate "inner product" of a single left-hand value with a one-element
/// right-hand slice: widens the value via `FromRef` and scales it via
/// `MulByScalar`.
#[derive(Debug, Clone)]
pub struct ScalarProduct;
94
95impl<Lhs, Rhs, Out> InnerProduct<Lhs, Rhs, Out> for ScalarProduct
96where
97 Out: for<'a> MulByScalar<&'a Rhs> + FromRef<Lhs>,
98{
99 fn inner_product<const CHECK: bool>(
102 lhs: &Lhs,
103 point: &[Rhs],
104 _zero: Out,
105 ) -> Result<Out, InnerProductError> {
106 if point.as_ref().len() != 1 {
107 Err(InnerProductError::LengthMismatch {
108 lhs: 1,
109 rhs: point.as_ref().len(),
110 })
111 } else {
112 Ok(Out::from_ref(lhs)
113 .mul_by_scalar::<CHECK>(&point[0])
114 .ok_or(InnerProductError::Overflow)?)
115 }
116 }
117}
118
/// Inner-product strategy for a slice of [`Boolean`] selectors. It sums,
/// converted via `FromRef`, the right-hand entries whose selector is set.
#[derive(Clone, Debug)]
pub struct BooleanInnerProductAdd;
123
124impl<Rhs: Clone, Out: FromRef<Rhs> + CheckedAdd> InnerProduct<[Boolean], Rhs, Out>
125 for BooleanInnerProductAdd
126{
127 #[allow(clippy::arithmetic_side_effects)] fn inner_product<const CHECK: bool>(
130 lhs: &[Boolean],
131 rhs: &[Rhs],
132 zero: Out,
133 ) -> Result<Out, InnerProductError> {
134 if lhs.len() != rhs.as_ref().len() {
135 return Err(InnerProductError::LengthMismatch {
136 lhs: lhs.len(),
137 rhs: rhs.as_ref().len(),
138 });
139 }
140
141 (0..lhs.len())
142 .filter(|&i| lhs[i].into_inner())
143 .try_fold(zero, |acc, i| {
144 let rhs = Out::from_ref(&rhs[i]);
145 if CHECK {
146 acc.checked_add(&rhs).ok_or(InnerProductError::Overflow)
147 } else {
148 Ok(acc + rhs)
149 }
150 })
151 }
152}
153
#[cfg(test)]
mod test {
    use crate::{CHECKED, UNCHECKED};
    use crypto_bigint::{U64, const_monty_params};
    use crypto_primitives::crypto_bigint_const_monty::ConstMontyField;
    use num_traits::ConstZero;

    use super::*;

    #[test]
    fn test_inner_product_basic() {
        let lhs = [1, 2, 3];
        let rhs = [4, 5, 6];
        assert_eq!(
            MBSInnerProduct::inner_product::<CHECKED>(&lhs, &rhs, 0),
            Ok(4 + 2 * 5 + 3 * 6)
        );
    }

    /// Mismatched operand lengths are rejected before any arithmetic.
    #[test]
    fn inner_product_length_mismatch() {
        let lhs = [1, 2, 3];
        let rhs = [4, 5];
        assert_eq!(
            MBSInnerProduct::inner_product::<CHECKED>(&lhs, &rhs, 0),
            Err(InnerProductError::LengthMismatch { lhs: 3, rhs: 2 })
        );
    }

    /// In checked mode a sum that exceeds the output type is reported.
    #[test]
    fn inner_product_checked_overflow() {
        let lhs = [i32::MAX, 1];
        let rhs = [1i32, 1];
        assert_eq!(
            MBSInnerProduct::inner_product::<CHECKED>(&lhs, &rhs, 0i32),
            Err(InnerProductError::Overflow)
        );
    }

    #[test]
    fn scalar_product() {
        let lhs = 42i32;
        let rhs = 23i128;

        assert_eq!(
            ScalarProduct::inner_product::<CHECKED>(&lhs, &[rhs], 0).unwrap(),
            i128::from(lhs) * rhs
        )
    }

    /// `ScalarProduct` accepts exactly one right-hand element.
    #[test]
    fn scalar_product_requires_single_point() {
        let lhs = 42i32;
        let empty: [i128; 0] = [];
        assert_eq!(
            ScalarProduct::inner_product::<CHECKED>(&lhs, &empty, 0i128),
            Err(InnerProductError::LengthMismatch { lhs: 1, rhs: 0 })
        );
        assert_eq!(
            ScalarProduct::inner_product::<CHECKED>(&lhs, &[1i128, 2], 0i128),
            Err(InnerProductError::LengthMismatch { lhs: 1, rhs: 2 })
        );
    }

    #[test]
    fn boolean_checked_eq_mbs_inner_product() {
        let lhs = [
            Boolean::from(true),
            Boolean::from(false),
            Boolean::from(true),
            Boolean::from(true),
        ];
        let rhs = [1i128, 2, 3, 4];

        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, 0),
            MBSInnerProduct::inner_product::<CHECKED>(&rhs, &lhs, 0i128)
        );
    }

    /// Mismatched selector/value lengths are rejected.
    #[test]
    fn boolean_length_mismatch() {
        let lhs = [Boolean::from(true), Boolean::from(false)];
        let rhs = [1i128];
        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, 0i128),
            Err(InnerProductError::LengthMismatch { lhs: 2, rhs: 1 })
        );
    }

    /// Checked mode reports overflow, but unset selectors never contribute.
    #[test]
    fn boolean_checked_overflow() {
        let rhs = [i128::MAX, 1];

        let both = [Boolean::from(true), Boolean::from(true)];
        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&both, &rhs, 0i128),
            Err(InnerProductError::Overflow)
        );

        let first_only = [Boolean::from(true), Boolean::from(false)];
        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&first_only, &rhs, 0i128),
            Ok(i128::MAX)
        );
    }

    const_monty_params!(Params, U64, "0000000000000007");

    #[test]
    fn boolean_unchecked_eq_boolean_checked() {
        let lhs = [
            Boolean::from(true),
            Boolean::from(false),
            Boolean::from(true),
            Boolean::from(true),
        ];
        let rhs = [
            ConstMontyField::<Params, 1>::from(1),
            ConstMontyField::<Params, 1>::from(2),
            ConstMontyField::<Params, 1>::from(3),
            ConstMontyField::<Params, 1>::from(4),
        ];

        assert_eq!(
            BooleanInnerProductAdd::inner_product::<CHECKED>(&lhs, &rhs, ConstMontyField::ZERO),
            BooleanInnerProductAdd::inner_product::<UNCHECKED>(&lhs, &rhs, ConstMontyField::ZERO)
        );
    }
}