labrador/ring/
rq_matrix.rs1use crate::ring::rq_vector::RqVector;
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5#[derive(Debug, Clone)]
7pub struct RqMatrix<const M: usize, const N: usize, const D: usize> {
8 elements: Vec<RqVector<N, D>>,
9}
10
11impl<const M: usize, const N: usize, const D: usize> RqMatrix<M, N, D> {
12 pub const fn new(elements: Vec<RqVector<N, D>>) -> Self {
14 RqMatrix { elements }
15 }
16
17 pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
19 Self {
20 elements: (0..M).map(|_| RqVector::random(rng)).collect(),
21 }
22 }
23
24 pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R) -> Self {
26 Self {
27 elements: (0..M).map(|_| RqVector::random_ternary(rng)).collect(),
28 }
29 }
30}
31
32impl<const M: usize, const N: usize, const D: usize> Mul<&RqVector<N, D>> for &RqMatrix<M, N, D> {
34 type Output = RqVector<M, D>;
35
36 fn mul(self, rhs: &RqVector<N, D>) -> Self::Output {
37 let mut result = RqVector::zero();
38
39 for (i, row) in self.elements.iter().enumerate() {
40 result[i] = row * rhs;
41 }
42
43 result
44 }
45}
46
47impl<const M: usize, const N: usize, const D: usize> Mul<&RqVector<N, D>> for RqMatrix<M, N, D> {
49 type Output = RqVector<M, D>;
50
51 fn mul(self, rhs: &RqVector<N, D>) -> Self::Output {
52 &self * rhs
53 }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    /// Smoke test: sampling a large random matrix must complete without overflowing the stack.
    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut rng = rand::rng();
        let _: RqMatrix<256, { 1 << 10 }, 64> = RqMatrix::random(&mut rng);
    }

    /// Checks the matrix-vector product against hand-computed values.
    ///
    /// The expected values are consistent with reduction modulo `X^2 + 1` and
    /// wrapping `u32` coefficients. For example, `u32::MAX - 4` stands for `-5`.
    /// Both the borrowed (`&matrix * &vec`) and consuming (`matrix.mul(&vec)`)
    /// forms are exercised and must agree.
    #[test]
    fn test_rqmatrix_mul() {
        // 1x2 matrix times a length-2 vector.
        let poly1: Rq<2> = vec![Zq::new(8), Zq::new(6)].into();
        let poly2: Rq<2> = vec![Zq::new(u32::MAX - 4), Zq::new(u32::MAX - 4)].into();
        let poly3: Rq<2> = vec![Zq::ONE, Zq::ZERO].into();
        let poly4: Rq<2> = vec![Zq::ZERO, Zq::new(4)].into();
        let matrix_1: RqMatrix<1, 2, 2> = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])]);
        let vec_1: RqVector<2, 2> = RqVector::from(vec![poly3, poly4]);

        let expected_poly_1 = vec![Zq::new(28), Zq::new(u32::MAX - 13)].into();
        let expected_1 = RqVector::from(vec![expected_poly_1]);

        assert_eq!(&matrix_1 * &vec_1, expected_1);
        let result_1 = matrix_1.mul(&vec_1);
        assert_eq!(result_1, expected_1);

        // 2x2 matrix times a length-2 vector.
        let poly5: Rq<2> = vec![Zq::new(u32::MAX - 6), Zq::new(7)].into();
        let poly6: Rq<2> = vec![Zq::new(u32::MAX - 2), Zq::ZERO].into();
        let poly7: Rq<2> = vec![Zq::new(8), Zq::new(u32::MAX - 1)].into();
        let poly8: Rq<2> = vec![Zq::new(u32::MAX - 3), Zq::new(4)].into();
        let poly9: Rq<2> = vec![Zq::MAX, Zq::new(u32::MAX - 1)].into();
        let poly10: Rq<2> = vec![Zq::new(u32::MAX - 2), Zq::new(u32::MAX - 2)].into();
        let matrix_2: RqMatrix<2, 2, 2> = RqMatrix::new(vec![
            RqVector::from(vec![poly5, poly6]),
            RqVector::from(vec![poly7, poly8]),
        ]);
        let vec_2: RqVector<2, 2> = RqVector::from(vec![poly9, poly10]);

        let expected_poly_2_1 = vec![Zq::new(30), Zq::new(16)].into();
        let expected_poly_2_2 = vec![Zq::new(12), Zq::new(u32::MAX - 13)].into();
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);

        assert_eq!(&matrix_2 * &vec_2, expected_2);
        let result_2 = matrix_2.mul(&vec_2);
        assert_eq!(result_2, expected_2);
    }
}