labrador/ring/rq_matrix.rs

1use crate::{core::inner_product, ring::rq_vector::RqVector};
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5use super::{rq::Rq, zq::Zq};
6
7/// Matrix of polynomials in Rq
8#[derive(Debug, Clone)]
9pub struct RqMatrix {
10    elements: Vec<RqVector>,
11    is_symmetric: bool,
12}
13
14impl RqMatrix {
15    /// Constructor for the Matrix of polynomials in Rq
16    pub fn new(elements: Vec<RqVector>, is_symmetric: bool) -> Self {
17        RqMatrix {
18            elements,
19            is_symmetric,
20        }
21    }
22
23    pub fn zero(row_len: usize, col_len: usize) -> Self {
24        RqMatrix::new(vec![RqVector::zero(col_len); row_len], false)
25    }
26
27    pub fn symmetric_zero(size: usize) -> Self {
28        Self {
29            elements: (0..size).map(|row| RqVector::zero(row + 1)).collect(),
30            is_symmetric: true,
31        }
32    }
33
34    pub fn get_row_len(&self) -> usize {
35        self.elements.len()
36    }
37
38    pub fn get_col_len(&self) -> usize {
39        let last_row = self.get_row_len() - 1;
40        self.elements[last_row].get_length()
41    }
42
43    pub fn get_cell(&self, row: usize, col: usize) -> &Rq {
44        if !self.is_symmetric || row >= col {
45            &self.elements[row].get_elements()[col]
46        } else {
47            &self.elements[col].get_elements()[row]
48        }
49    }
50
51    pub fn set_cell(&mut self, row: usize, col: usize, value: Rq) {
52        self.elements[row].set(col, value);
53    }
54
55    /// Create a random matrix of polynomials
56    pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
57        Self {
58            elements: (0..row_len)
59                .map(|_| RqVector::random(rng, col_len))
60                .collect(),
61            is_symmetric: false,
62        }
63    }
64
65    /// Create a random symmetric matrix of polynomials
66    pub fn symmetric_random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize) -> Self {
67        Self {
68            elements: (0..row_len)
69                .map(|row| RqVector::random(rng, row + 1))
70                .collect(),
71            is_symmetric: true,
72        }
73    }
74
75    pub fn get_elements(&self) -> &[RqVector] {
76        &self.elements
77    }
78
79    pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
80        let mut decomposed_vec = Vec::new();
81        for ring_vector in self.get_elements() {
82            for ring in ring_vector.get_elements() {
83                decomposed_vec.extend(ring.decompose(base, num_parts));
84            }
85        }
86        RqVector::new(decomposed_vec)
87    }
88}
89
90impl FromIterator<RqVector> for RqMatrix {
91    fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
92        let mut elements = Vec::new();
93        for item in iter {
94            elements.push(item);
95        }
96        RqMatrix::new(elements, false)
97    }
98}
99
100// Implement matrix-vector multiplication for reference to matrix
101impl Mul<&RqVector> for &RqMatrix {
102    type Output = RqVector;
103
104    fn mul(self, rhs: &RqVector) -> Self::Output {
105        let mut result = RqVector::zero(self.elements.len());
106
107        for (i, row) in self.elements.iter().enumerate() {
108            result.set(
109                i,
110                inner_product::compute_linear_combination(row.get_elements(), rhs.get_elements()),
111            );
112        }
113        result
114    }
115}
116
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::*;
    use crate::ring::rq::tests::generate_rq_from_zq_vector;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    /// Returns true when every coefficient of `poly` is zero.
    fn is_polynomial_zero(poly: &Rq) -> bool {
        poly.get_coefficients()
            .iter()
            .all(|&coeff| coeff == Zq::ZERO)
    }

    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut generator = rng();
        let _: RqMatrix = RqMatrix::random(&mut generator, 256, 1 << 10);
    }

    #[test]
    fn test_set_cell() {
        let mut matrix = RqMatrix::zero(10, 18);
        matrix.set_cell(4, 9, Rq::new([Zq::new(10); Rq::DEGREE]));
        matrix.set_cell(8, 1, Rq::new([Zq::new(3); Rq::DEGREE]));

        for (i, vector) in matrix.get_elements().iter().enumerate() {
            for (j, poly) in vector.get_elements().iter().enumerate() {
                match (i, j) {
                    (4, 9) => assert_eq!(poly, &Rq::new([Zq::new(10); Rq::DEGREE])),
                    (8, 1) => assert_eq!(poly, &Rq::new([Zq::new(3); Rq::DEGREE])),
                    _ => assert_eq!(poly, &Rq::zero()),
                }
            }
        }
    }

    #[test]
    fn test_symmetric_matrix() {
        let symmetric_matrix = RqMatrix::symmetric_random(&mut rng(), 12);
        assert_eq!(symmetric_matrix.get_row_len(), 12);
        assert_eq!(symmetric_matrix.get_col_len(), 12);
        for i in 0..symmetric_matrix.get_row_len() {
            assert_eq!(symmetric_matrix.get_elements()[i].get_length(), i + 1);
        }
        for i in 0..symmetric_matrix.get_row_len() {
            for j in 0..symmetric_matrix.get_col_len() {
                assert_eq!(
                    symmetric_matrix.get_cell(i, j),
                    symmetric_matrix.get_cell(j, i)
                )
            }
        }
    }

    #[test]
    fn test_rq_matrix_from_iterator() {
        let expected = vec![
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
        ];
        let result: RqMatrix = expected.clone().into_iter().collect();

        assert_eq!(result.get_elements(), &expected);
    }

    #[test]
    fn test_zero_matrix() {
        let matrix = RqMatrix::zero(10, 20);
        assert_eq!(matrix.get_row_len(), 10);
        assert_eq!(matrix.get_col_len(), 20);
        for row in matrix.get_elements() {
            for cell in row.get_elements() {
                assert!(is_polynomial_zero(cell));
            }
        }
    }

    #[test]
    fn test_rqmatrix_mul() {
        let poly1: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), Zq::new(6)]);
        let poly2: Rq = generate_rq_from_zq_vector(vec![-Zq::new(5), -Zq::new(5)]);
        let poly3: Rq = generate_rq_from_zq_vector(vec![Zq::ONE, Zq::ZERO]);
        let poly4: Rq = generate_rq_from_zq_vector(vec![Zq::ZERO, Zq::new(4)]);
        let matrix_1: RqMatrix = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])], false);
        let vec_1: RqVector = RqVector::from(vec![poly3, poly4]);

        let result_1 = &matrix_1 * &vec_1;
        let expected_poly_1 =
            generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(14), -Zq::new(20)]);
        let expected_1 = RqVector::from(vec![expected_poly_1]);
        assert_eq!(result_1, expected_1);

        let poly5: Rq = generate_rq_from_zq_vector(vec![-Zq::new(7), Zq::new(7)]);
        let poly6: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), Zq::ZERO]);
        let poly7: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(2)]);
        let poly8: Rq = generate_rq_from_zq_vector(vec![-Zq::new(4), Zq::new(4)]);
        let poly9: Rq = generate_rq_from_zq_vector(vec![Zq::NEG_ONE, -Zq::new(2)]);
        let poly10: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), -Zq::new(3)]);
        let matrix_2: RqMatrix = RqMatrix::new(
            vec![
                RqVector::from(vec![poly5, poly6]),
                RqVector::from(vec![poly7, poly8]),
            ],
            false,
        );
        let vec_2: RqVector = RqVector::from(vec![poly9, poly10]);

        let result_2 = &matrix_2 * &vec_2;
        let expected_poly_2_1 =
            generate_rq_from_zq_vector(vec![Zq::new(16), Zq::new(16), -Zq::new(14)]);
        let expected_poly_2_2 =
            generate_rq_from_zq_vector(vec![Zq::new(4), -Zq::new(14), -Zq::new(8)]);
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);
        assert_eq!(result_2, expected_2);
    }
}