labrador/ring/
rq_matrix.rs

//! Dense matrix of `Rq` polynomials with an optional *lower-triangular (symmetric)*
//! storage optimisation. A symmetric matrix stores **only** the lower triangle
//! (including the diagonal).
4
5use crate::{core::inner_product, ring::rq_vector::RqVector};
6use rand::{CryptoRng, Rng};
7use std::ops::Mul;
8
9use super::{rq::Rq, zq::Zq};
10
11/// Matrix of polynomials in Rq
12#[derive(Debug, Clone)]
13pub struct RqMatrix {
14    elements: Vec<RqVector>,
15    /// If `true`, `rows[i]` has length `i+1` and represents the lower triangle.
16    is_symmetric: bool,
17}
18
19impl RqMatrix {
20    /// Constructor for the Matrix of polynomials in Rq
21    pub fn new(elements: Vec<RqVector>, is_symmetric: bool) -> Self {
22        if is_symmetric {
23            debug_assert!(
24                elements.iter().enumerate().all(|(i, r)| r.len() == i + 1),
25                "symmetric storage must be lower-triangular"
26            );
27        } else {
28            let width = elements.first().map(|r| r.len()).unwrap_or(0);
29            debug_assert!(
30                elements.iter().all(|r| r.len() == width),
31                "row lengths differ"
32            );
33        }
34        Self {
35            elements,
36            is_symmetric,
37        }
38    }
39
40    pub fn zero(row_len: usize, col_len: usize) -> Self {
41        RqMatrix::new(vec![RqVector::zero(col_len); row_len], false)
42    }
43
44    pub fn symmetric_zero(size: usize) -> Self {
45        Self {
46            elements: (0..size).map(|row| RqVector::zero(row + 1)).collect(),
47            is_symmetric: true,
48        }
49    }
50
51    pub fn row_len(&self) -> usize {
52        self.elements.len()
53    }
54
55    pub fn col_len(&self) -> usize {
56        let last_row = self.row_len() - 1;
57        self.elements[last_row].len()
58    }
59
60    pub fn get_cell(&self, row: usize, col: usize) -> &Rq {
61        if !self.is_symmetric || row >= col {
62            &self.elements[row].elements()[col]
63        } else {
64            &self.elements[col].elements()[row]
65        }
66    }
67
68    pub fn set_cell(&mut self, row: usize, col: usize, value: Rq) {
69        self.elements[row].set(col, value);
70    }
71
72    /// Create a random matrix of polynomials
73    pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
74        Self {
75            elements: (0..row_len)
76                .map(|_| RqVector::random(rng, col_len))
77                .collect(),
78            is_symmetric: false,
79        }
80    }
81
82    /// Create a random symmetric matrix of polynomials
83    pub fn symmetric_random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize) -> Self {
84        Self {
85            elements: (0..row_len)
86                .map(|row| RqVector::random(rng, row + 1))
87                .collect(),
88            is_symmetric: true,
89        }
90    }
91
92    pub fn elements(&self) -> &[RqVector] {
93        &self.elements
94    }
95
96    /// Decompose every cell into `parts` low‑norm digits in base `base` and
97    /// concat all results into **one** long vector.
98    pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
99        let mut decomposed_vec = Vec::new();
100        for ring_vector in self.elements() {
101            for ring in ring_vector.elements() {
102                decomposed_vec.extend(ring.decompose(base, num_parts));
103            }
104        }
105        RqVector::new(decomposed_vec)
106    }
107}
108
109impl FromIterator<RqVector> for RqMatrix {
110    fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
111        let mut elements = Vec::new();
112        for item in iter {
113            elements.push(item);
114        }
115        RqMatrix::new(elements, false)
116    }
117}
118
119// Implement matrix-vector multiplication for reference to matrix
120impl Mul<&RqVector> for &RqMatrix {
121    type Output = RqVector;
122
123    fn mul(self, rhs: &RqVector) -> Self::Output {
124        debug_assert_eq!(self.col_len(), rhs.len(), "dimension mismatch");
125        let mut result = RqVector::zero(self.elements.len());
126
127        for (i, row) in self.elements.iter().enumerate() {
128            result.set(
129                i,
130                inner_product::compute_linear_combination(row.elements(), rhs.elements()),
131            );
132        }
133        result
134    }
135}
136
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::*;
    use crate::ring::rq::tests::generate_rq_from_zq_vector;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let _: RqMatrix = RqMatrix::random(&mut rng(), 256, 1 << 10);
    }

    #[test]
    fn test_set_cell() {
        let mut matrix = RqMatrix::zero(10, 18);
        matrix.set_cell(4, 9, Rq::new([Zq::new(10); Rq::DEGREE]));
        matrix.set_cell(8, 1, Rq::new([Zq::new(3); Rq::DEGREE]));

        for (i, vector) in matrix.elements().iter().enumerate() {
            for (j, poly) in vector.elements().iter().enumerate() {
                if (i == 4) && (j == 9) {
                    assert_eq!(poly, &Rq::new([Zq::new(10); Rq::DEGREE]))
                } else if (i == 8) && (j == 1) {
                    assert_eq!(poly, &Rq::new([Zq::new(3); Rq::DEGREE]))
                } else {
                    assert_eq!(poly, &Rq::zero())
                }
            }
        }
    }

    #[test]
    fn test_symmetric_matrix() {
        let symmetric_matrix = RqMatrix::symmetric_random(&mut rng(), 12);
        assert_eq!(symmetric_matrix.row_len(), 12);
        assert_eq!(symmetric_matrix.col_len(), 12);
        // Only the lower triangle is stored: row `i` holds `i + 1` cells.
        for i in 0..symmetric_matrix.row_len() {
            assert_eq!(symmetric_matrix.elements()[i].len(), i + 1);
        }
        for i in 0..symmetric_matrix.row_len() {
            for j in 0..symmetric_matrix.col_len() {
                assert_eq!(
                    symmetric_matrix.get_cell(i, j),
                    symmetric_matrix.get_cell(j, i)
                )
            }
        }
    }

    #[test]
    fn test_rq_matrix_from_iterator() {
        let expected = vec![
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
            RqVector::random(&mut rng(), 5),
        ];
        let result: RqMatrix = expected.clone().into_iter().collect();

        assert_eq!(result.elements(), &expected);
    }

    #[test]
    fn test_zero_matrix() {
        // Returns true when every coefficient of `poly` is zero.
        fn is_polynomial_zero(poly: &Rq) -> bool {
            poly.coeffs().iter().all(|&coeff| coeff == Zq::ZERO)
        }

        let matrix = RqMatrix::zero(10, 20);
        assert_eq!(matrix.row_len(), 10);
        assert_eq!(matrix.col_len(), 20);
        for row in matrix.elements() {
            for cell in row.elements() {
                assert!(is_polynomial_zero(cell));
            }
        }
    }

    #[test]
    fn test_rqmatrix_mul() {
        let poly1: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), Zq::new(6)]);
        let poly2: Rq = generate_rq_from_zq_vector(vec![-Zq::new(5), -Zq::new(5)]);
        let poly3: Rq = generate_rq_from_zq_vector(vec![Zq::ONE, Zq::ZERO]);
        let poly4: Rq = generate_rq_from_zq_vector(vec![Zq::ZERO, Zq::new(4)]);
        let matrix_1: RqMatrix = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])], false);
        let vec_1: RqVector = RqVector::from(vec![poly3, poly4]);

        let result_1 = &matrix_1 * &vec_1;
        let expected_poly_1 =
            generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(14), -Zq::new(20)]);
        let expected_1 = RqVector::from(vec![expected_poly_1]);
        assert_eq!(result_1, expected_1);

        let poly5: Rq = generate_rq_from_zq_vector(vec![-Zq::new(7), Zq::new(7)]);
        let poly6: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), Zq::ZERO]);
        let poly7: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(2)]);
        let poly8: Rq = generate_rq_from_zq_vector(vec![-Zq::new(4), Zq::new(4)]);
        let poly9: Rq = generate_rq_from_zq_vector(vec![Zq::NEG_ONE, -Zq::new(2)]);
        let poly10: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), -Zq::new(3)]);
        let matrix_2: RqMatrix = RqMatrix::new(
            vec![
                RqVector::from(vec![poly5, poly6]),
                RqVector::from(vec![poly7, poly8]),
            ],
            false,
        );
        let vec_2: RqVector = RqVector::from(vec![poly9, poly10]);

        let result_2 = &matrix_2 * &vec_2;
        let expected_poly_2_1 =
            generate_rq_from_zq_vector(vec![Zq::new(16), Zq::new(16), -Zq::new(14)]);
        let expected_poly_2_2 =
            generate_rq_from_zq_vector(vec![Zq::new(4), -Zq::new(14), -Zq::new(8)]);
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);
        assert_eq!(result_2, expected_2);
    }
}