// labrador/ring/rq_matrix.rs

1use crate::ring::rq_vector::RqVector;
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5/// Matrix of polynomials in Rq
6#[derive(Debug, Clone)]
7pub struct RqMatrix<const M: usize, const N: usize, const D: usize> {
8    elements: Vec<RqVector<N, D>>,
9}
10
11impl<const M: usize, const N: usize, const D: usize> RqMatrix<M, N, D> {
12    /// Constructor for the Matrix of polynomials in Rq
13    pub const fn new(elements: Vec<RqVector<N, D>>) -> Self {
14        RqMatrix { elements }
15    }
16
17    /// Create a random matrix of polynomials
18    pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
19        Self {
20            elements: (0..M).map(|_| RqVector::random(rng)).collect(),
21        }
22    }
23
24    /// Create a random matrix of polynomials with ternary coefficients
25    pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R) -> Self {
26        Self {
27            elements: (0..M).map(|_| RqVector::random_ternary(rng)).collect(),
28        }
29    }
30}
31
32// Implement matrix-vector multiplication for reference to matrix
33impl<const M: usize, const N: usize, const D: usize> Mul<&RqVector<N, D>> for &RqMatrix<M, N, D> {
34    type Output = RqVector<M, D>;
35
36    fn mul(self, rhs: &RqVector<N, D>) -> Self::Output {
37        let mut result = RqVector::zero();
38
39        for (i, row) in self.elements.iter().enumerate() {
40            result[i] = row * rhs;
41        }
42
43        result
44    }
45}
46
47// Implement matrix-vector multiplication for owned matrix by delegating to reference implementation
48impl<const M: usize, const N: usize, const D: usize> Mul<&RqVector<N, D>> for RqMatrix<M, N, D> {
49    type Output = RqVector<M, D>;
50
51    fn mul(self, rhs: &RqVector<N, D>) -> Self::Output {
52        &self * rhs
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut rng = rand::rng();
        let _: RqMatrix<256, { 1 << 10 }, 64> = RqMatrix::random(&mut rng);
    }

    #[test]
    fn test_rqmatrix_mul() {
        let poly1: Rq<2> = vec![Zq::new(8), Zq::new(6)].into();
        let poly2: Rq<2> = vec![Zq::new(u32::MAX - 4), Zq::new(u32::MAX - 4)].into();
        let poly3: Rq<2> = vec![Zq::ONE, Zq::ZERO].into();
        let poly4: Rq<2> = vec![Zq::ZERO, Zq::new(4)].into();
        let matrix_1: RqMatrix<1, 2, 2> = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])]);
        let vec_1: RqVector<2, 2> = RqVector::from(vec![poly3, poly4]);

        let result_1 = matrix_1.mul(&vec_1);
        let expected_poly_1 = vec![Zq::new(28), Zq::new(u32::MAX - 13)].into();
        let expected_1 = RqVector::from(vec![expected_poly_1]);
        assert_eq!(result_1, expected_1);

        let poly5: Rq<2> = vec![Zq::new(u32::MAX - 6), Zq::new(7)].into();
        let poly6: Rq<2> = vec![Zq::new(u32::MAX - 2), Zq::ZERO].into();
        let poly7: Rq<2> = vec![Zq::new(8), Zq::new(u32::MAX - 1)].into();
        let poly8: Rq<2> = vec![Zq::new(u32::MAX - 3), Zq::new(4)].into();
        let poly9: Rq<2> = vec![Zq::MAX, Zq::new(u32::MAX - 1)].into();
        let poly10: Rq<2> = vec![Zq::new(u32::MAX - 2), Zq::new(u32::MAX - 2)].into();
        let matrix_2: RqMatrix<2, 2, 2> = RqMatrix::new(vec![
            RqVector::from(vec![poly5, poly6]),
            RqVector::from(vec![poly7, poly8]),
        ]);
        let vec_2: RqVector<2, 2> = RqVector::from(vec![poly9, poly10]);

        let result_2 = matrix_2.mul(&vec_2);
        let expected_poly_2_1 = vec![Zq::new(30), Zq::new(16)].into();
        let expected_poly_2_2 = vec![Zq::new(12), Zq::new(u32::MAX - 13)].into();
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);
        assert_eq!(result_2, expected_2);
    }

    /// The borrowed (`&matrix * &vec`) and owned (`matrix * &vec`) operators
    /// must agree. Expected values, with X^2 = -1:
    ///   (1 + 2X)(3 + X) = 3 + 7X + 2X^2 = 1 + 7X
    ///   X(3 + X)        = 3X + X^2      = -1 + 3X
    #[test]
    fn test_rqmatrix_mul_by_reference_matches_owned() {
        let row_1: Rq<2> = vec![Zq::ONE, Zq::new(2)].into();
        let row_2: Rq<2> = vec![Zq::ZERO, Zq::ONE].into();
        let entry: Rq<2> = vec![Zq::new(3), Zq::ONE].into();
        let matrix: RqMatrix<2, 1, 2> = RqMatrix::new(vec![
            RqVector::from(vec![row_1]),
            RqVector::from(vec![row_2]),
        ]);
        let vector: RqVector<1, 2> = RqVector::from(vec![entry]);

        let expected_first: Rq<2> = vec![Zq::ONE, Zq::new(7)].into();
        let expected_second: Rq<2> = vec![Zq::MAX, Zq::new(3)].into();
        let expected: RqVector<2, 2> = RqVector::from(vec![expected_first, expected_second]);

        let by_ref = &matrix * &vector;
        let by_value = matrix.clone() * &vector;
        assert_eq!(by_ref, expected);
        assert_eq!(by_value, expected);
    }
}