1use crate::{core::inner_product, ring::rq_vector::RqVector};
6use rand::{CryptoRng, Rng};
7use std::ops::Mul;
8
9use super::{rq::Rq, zq::Zq};
10
11#[derive(Debug, Clone)]
13pub struct RqMatrix {
14 elements: Vec<RqVector>,
15 is_symmetric: bool,
17}
18
19impl RqMatrix {
20 pub fn new(elements: Vec<RqVector>, is_symmetric: bool) -> Self {
22 if is_symmetric {
23 debug_assert!(
24 elements.iter().enumerate().all(|(i, r)| r.len() == i + 1),
25 "symmetric storage must be lower-triangular"
26 );
27 } else {
28 let width = elements.first().map(|r| r.len()).unwrap_or(0);
29 debug_assert!(
30 elements.iter().all(|r| r.len() == width),
31 "row lengths differ"
32 );
33 }
34 Self {
35 elements,
36 is_symmetric,
37 }
38 }
39
40 pub fn zero(row_len: usize, col_len: usize) -> Self {
41 RqMatrix::new(vec![RqVector::zero(col_len); row_len], false)
42 }
43
44 pub fn symmetric_zero(size: usize) -> Self {
45 Self {
46 elements: (0..size).map(|row| RqVector::zero(row + 1)).collect(),
47 is_symmetric: true,
48 }
49 }
50
51 pub fn row_len(&self) -> usize {
52 self.elements.len()
53 }
54
55 pub fn col_len(&self) -> usize {
56 let last_row = self.row_len() - 1;
57 self.elements[last_row].len()
58 }
59
60 pub fn get_cell(&self, row: usize, col: usize) -> &Rq {
61 if !self.is_symmetric || row >= col {
62 &self.elements[row].elements()[col]
63 } else {
64 &self.elements[col].elements()[row]
65 }
66 }
67
68 pub fn set_cell(&mut self, row: usize, col: usize, value: Rq) {
69 self.elements[row].set(col, value);
70 }
71
72 pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
74 Self {
75 elements: (0..row_len)
76 .map(|_| RqVector::random(rng, col_len))
77 .collect(),
78 is_symmetric: false,
79 }
80 }
81
82 pub fn symmetric_random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize) -> Self {
84 Self {
85 elements: (0..row_len)
86 .map(|row| RqVector::random(rng, row + 1))
87 .collect(),
88 is_symmetric: true,
89 }
90 }
91
92 pub fn elements(&self) -> &[RqVector] {
93 &self.elements
94 }
95
96 pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
99 let mut decomposed_vec = Vec::new();
100 for ring_vector in self.elements() {
101 for ring in ring_vector.elements() {
102 decomposed_vec.extend(ring.decompose(base, num_parts));
103 }
104 }
105 RqVector::new(decomposed_vec)
106 }
107}
108
109impl FromIterator<RqVector> for RqMatrix {
110 fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
111 let mut elements = Vec::new();
112 for item in iter {
113 elements.push(item);
114 }
115 RqMatrix::new(elements, false)
116 }
117}
118
119impl Mul<&RqVector> for &RqMatrix {
121 type Output = RqVector;
122
123 fn mul(self, rhs: &RqVector) -> Self::Output {
124 debug_assert_eq!(self.col_len(), rhs.len(), "dimension mismatch");
125 let mut result = RqVector::zero(self.elements.len());
126
127 for (i, row) in self.elements.iter().enumerate() {
128 result.set(
129 i,
130 inner_product::compute_linear_combination(row.elements(), rhs.elements()),
131 );
132 }
133 result
134 }
135}
136
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::*;
    use crate::ring::rq::tests::generate_rq_from_zq_vector;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    /// Smoke test: builds a large (256 x 1024) random matrix. Per its name, this
    /// guards against the matrix not fitting on the stack.
    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut rng = rand::rng();
        let _: RqMatrix = RqMatrix::random(&mut rng, 256, 1 << 10);
    }

    /// `set_cell` writes exactly the targeted cells and leaves every other cell zero.
    #[test]
    fn test_set_cell() {
        let mut matrix = RqMatrix::zero(10, 18);
        matrix.set_cell(4, 9, Rq::new([Zq::new(10); Rq::DEGREE]));
        matrix.set_cell(8, 1, Rq::new([Zq::new(3); Rq::DEGREE]));

        for (i, vector) in matrix.elements().iter().enumerate() {
            for (j, poly) in vector.elements().iter().enumerate() {
                if (i == 4) && (j == 9) {
                    assert_eq!(poly, &Rq::new([Zq::new(10); Rq::DEGREE]))
                } else if (i == 8) && (j == 1) {
                    assert_eq!(poly, &Rq::new([Zq::new(3); Rq::DEGREE]))
                } else {
                    assert_eq!(poly, &Rq::zero())
                }
            }
        }
    }

    /// A random symmetric matrix is square, stores a lower triangle, and reads
    /// back the same value for `(i, j)` and `(j, i)`.
    #[test]
    fn test_symmetric_matrix() {
        let symmetric_matrix = RqMatrix::symmetric_random(&mut rng(), 12);
        assert_eq!(symmetric_matrix.row_len(), 12);
        assert_eq!(symmetric_matrix.col_len(), 12);
        for i in 0..symmetric_matrix.row_len() {
            assert_eq!(symmetric_matrix.elements()[i].len(), i + 1);
        }
        for i in 0..symmetric_matrix.row_len() {
            for j in 0..symmetric_matrix.col_len() {
                assert_eq!(
                    symmetric_matrix.get_cell(i, j),
                    symmetric_matrix.get_cell(j, i)
                )
            }
        }
    }

    /// Collecting row vectors keeps them in order.
    #[test]
    fn test_rq_matrix_from_iterator() {
        let mut rng = rng();
        let expected: Vec<RqVector> = (0..4).map(|_| RqVector::random(&mut rng, 5)).collect();
        let polynomial_matrix = expected.clone().into_iter();
        let result: RqMatrix = polynomial_matrix.collect();

        assert_eq!(result.elements(), &expected);
    }

    /// `RqMatrix::zero` has the requested dimensions and only zero coefficients.
    #[test]
    fn test_zero_matrix() {
        fn is_polynomial_zero(poly: &Rq) -> bool {
            poly.coeffs().iter().all(|&coeff| coeff == Zq::ZERO)
        }

        let matrix = RqMatrix::zero(10, 20);
        assert_eq!(matrix.row_len(), 10);
        assert_eq!(matrix.col_len(), 20);
        for row in matrix.elements() {
            for cell in row.elements() {
                assert!(is_polynomial_zero(cell));
            }
        }
    }

    /// Matrix-vector multiplication against hand-computed polynomial products.
    #[test]
    fn test_rqmatrix_mul() {
        let poly1: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), Zq::new(6)]);
        let poly2: Rq = generate_rq_from_zq_vector(vec![-Zq::new(5), -Zq::new(5)]);
        let poly3: Rq = generate_rq_from_zq_vector(vec![Zq::ONE, Zq::ZERO]);
        let poly4: Rq = generate_rq_from_zq_vector(vec![Zq::ZERO, Zq::new(4)]);
        let matrix_1: RqMatrix = RqMatrix::new(vec![RqVector::from(vec![poly1, poly2])], false);
        let vec_1: RqVector = RqVector::from(vec![poly3, poly4]);

        let result_1 = matrix_1.mul(&vec_1);
        let expected_poly_1 =
            generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(14), -Zq::new(20)]);
        let expected_1 = RqVector::from(vec![expected_poly_1]);
        assert_eq!(result_1, expected_1);

        let poly5: Rq = generate_rq_from_zq_vector(vec![-Zq::new(7), Zq::new(7)]);
        let poly6: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), Zq::ZERO]);
        let poly7: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(2)]);
        let poly8: Rq = generate_rq_from_zq_vector(vec![-Zq::new(4), Zq::new(4)]);
        let poly9: Rq = generate_rq_from_zq_vector(vec![Zq::NEG_ONE, -Zq::new(2)]);
        let poly10: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), -Zq::new(3)]);
        let matrix_2: RqMatrix = RqMatrix::new(
            vec![
                RqVector::from(vec![poly5, poly6]),
                RqVector::from(vec![poly7, poly8]),
            ],
            false,
        );
        let vec_2: RqVector = RqVector::from(vec![poly9, poly10]);

        let result_2 = matrix_2.mul(&vec_2);
        let expected_poly_2_1 =
            generate_rq_from_zq_vector(vec![Zq::new(16), Zq::new(16), -Zq::new(14)]);
        let expected_poly_2_2 =
            generate_rq_from_zq_vector(vec![Zq::new(4), -Zq::new(14), -Zq::new(8)]);
        let expected_2 = RqVector::from(vec![expected_poly_2_1, expected_poly_2_2]);
        assert_eq!(result_2, expected_2);
    }
}
259}