labrador/ring/
rq_matrix.rs

1use crate::ring::rq_vector::RqVector;
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5use super::{rq::Rq, zq::Zq};
6
7/// Matrix of polynomials in Rq
8#[derive(Debug, Clone)]
9pub struct RqMatrix {
10    elements: Vec<RqVector>,
11}
12
13impl RqMatrix {
14    /// Constructor for the Matrix of polynomials in Rq
15    pub const fn new(elements: Vec<RqVector>) -> Self {
16        RqMatrix { elements }
17    }
18
19    pub fn zero(row_len: usize, col_len: usize) -> Self {
20        RqMatrix::new(vec![RqVector::zero(col_len); row_len])
21    }
22
23    pub fn get_row_len(&self) -> usize {
24        self.elements.len()
25    }
26
27    pub fn get_col_len(&self) -> usize {
28        self.elements[0].get_length()
29    }
30
31    /// Create a random matrix of polynomials
32    pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
33        Self {
34            elements: (0..row_len)
35                .map(|_| RqVector::random(rng, col_len))
36                .collect(),
37        }
38    }
39
40    pub fn get_cell_symmetric(&self, row: usize, col: usize) -> Rq {
41        if row >= col {
42            self.elements[row].get_elements()[col].clone()
43        } else {
44            self.elements[col].get_elements()[row].clone()
45        }
46    }
47
48    /// Create a random matrix of polynomials with ternary coefficients
49    pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
50        Self {
51            elements: (0..row_len)
52                .map(|_| RqVector::random_ternary(rng, col_len))
53                .collect(),
54        }
55    }
56
57    pub fn get_elements(&self) -> &Vec<RqVector> {
58        &self.elements
59    }
60
61    pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
62        let mut decomposed_vec = Vec::new();
63        for ring_vector in self.get_elements() {
64            for ring in ring_vector.get_elements() {
65                decomposed_vec.append(&mut ring.decompose(base, num_parts).get_elements().clone());
66            }
67        }
68        RqVector::new(decomposed_vec)
69    }
70}
71
72impl FromIterator<RqVector> for RqMatrix {
73    fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
74        let mut elements = Vec::new();
75        for item in iter {
76            elements.push(item);
77        }
78        RqMatrix::new(elements)
79    }
80}
81
82// Implement matrix-vector multiplication for reference to matrix
83impl Mul<&RqVector> for &RqMatrix {
84    type Output = RqVector;
85
86    fn mul(self, rhs: &RqVector) -> Self::Output {
87        let mut result = RqVector::zero(self.elements.len());
88
89        for (i, row) in self.elements.iter().enumerate() {
90            result[i] = row * rhs;
91        }
92
93        result
94    }
95}
96
97// Implement matrix-vector multiplication for owned matrix by delegating to reference implementation
98impl Mul<&RqVector> for RqMatrix {
99    type Output = RqVector;
100
101    fn mul(self, rhs: &RqVector) -> Self::Output {
102        &self * rhs
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut rng = rand::rng();
        let _: RqMatrix = RqMatrix::random(&mut rng, 256, 1 << 10);
    }

    #[test]
    fn test_zero_matrix() {
        let matrix = RqMatrix::zero(10, 20);
        assert_eq!(matrix.get_row_len(), 10);
        assert_eq!(matrix.get_col_len(), 20);
        for row in matrix.get_elements() {
            for cell in row.get_elements() {
                assert!(cell.is_zero());
            }
        }
    }

    #[test]
    fn test_rqmatrix_mul() {
        let poly1: Rq = vec![Zq::new(8), Zq::new(6)].into();
        let poly2: Rq = vec![Zq::new(u32::MAX - 4), Zq::new(u32::MAX - 4)].into();
        let poly3: Rq = vec![Zq::ONE, Zq::ZERO].into();
        let poly4: Rq = vec![Zq::ZERO, Zq::new(4)].into();
        let matrix_1: RqMatrix = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])]);
        let vec_1: RqVector = RqVector::from(vec![poly3, poly4]);

        let result_1 = matrix_1.mul(&vec_1);
        let expected_poly_1 =
            vec![Zq::new(8), Zq::new(u32::MAX - 13), Zq::new(u32::MAX - 19)].into();
        let expected_1 = RqVector::from(vec![expected_poly_1]);
        assert_eq!(result_1, expected_1);

        let poly5: Rq = vec![Zq::new(u32::MAX - 6), Zq::new(7)].into();
        let poly6: Rq = vec![Zq::new(u32::MAX - 2), Zq::ZERO].into();
        let poly7: Rq = vec![Zq::new(8), Zq::new(u32::MAX - 1)].into();
        let poly8: Rq = vec![Zq::new(u32::MAX - 3), Zq::new(4)].into();
        let poly9: Rq = vec![Zq::MAX, Zq::new(u32::MAX - 1)].into();
        let poly10: Rq = vec![Zq::new(u32::MAX - 2), Zq::new(u32::MAX - 2)].into();
        let matrix_2: RqMatrix = RqMatrix::new(vec![
            RqVector::from(vec![poly5, poly6]),
            RqVector::from(vec![poly7, poly8]),
        ]);
        let vec_2: RqVector = RqVector::from(vec![poly9, poly10]);

        let result_2 = matrix_2.mul(&vec_2);
        let expected_poly_2_1 = vec![Zq::new(16), Zq::new(16), Zq::new(u32::MAX - 13)].into();
        let expected_poly_2_2 =
            vec![Zq::new(4), Zq::new(u32::MAX - 13), Zq::new(u32::MAX - 7)].into();
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);
        assert_eq!(result_2, expected_2);
    }

    #[test]
    fn test_get_cell_symmetric_mirrors_upper_triangle() {
        let a: Rq = vec![Zq::ONE, Zq::ZERO].into();
        let b: Rq = vec![Zq::new(2), Zq::new(3)].into();
        let c: Rq = vec![Zq::ZERO, Zq::new(5)].into();
        // Lower-triangular storage: row i holds entries (i, 0..=i).
        let matrix = RqMatrix::new(vec![
            RqVector::from(vec![a.clone()]),
            RqVector::from(vec![b.clone(), c.clone()]),
        ]);
        assert_eq!(matrix.get_cell_symmetric(0, 0), a);
        assert_eq!(matrix.get_cell_symmetric(1, 0), b);
        assert_eq!(matrix.get_cell_symmetric(0, 1), b);
        assert_eq!(matrix.get_cell_symmetric(1, 1), c);
    }

    #[test]
    fn test_from_iterator_keeps_row_order() {
        let rows = vec![RqVector::zero(3), RqVector::zero(2)];
        let matrix: RqMatrix = rows.clone().into_iter().collect();
        assert_eq!(matrix.get_row_len(), 2);
        assert_eq!(matrix.get_elements(), &rows);
    }
}