1use crate::{core::inner_product, ring::rq_vector::RqVector};
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5use super::{rq::Rq, zq::Zq};
6
7#[derive(Debug, Clone)]
9pub struct RqMatrix {
10 elements: Vec<RqVector>,
11 is_symmetric: bool,
12}
13
14impl RqMatrix {
15 pub fn new(elements: Vec<RqVector>, is_symmetric: bool) -> Self {
17 RqMatrix {
18 elements,
19 is_symmetric,
20 }
21 }
22
23 pub fn zero(row_len: usize, col_len: usize) -> Self {
24 RqMatrix::new(vec![RqVector::zero(col_len); row_len], false)
25 }
26
27 pub fn symmetric_zero(size: usize) -> Self {
28 Self {
29 elements: (0..size).map(|row| RqVector::zero(row + 1)).collect(),
30 is_symmetric: true,
31 }
32 }
33
34 pub fn get_row_len(&self) -> usize {
35 self.elements.len()
36 }
37
38 pub fn get_col_len(&self) -> usize {
39 let last_row = self.get_row_len() - 1;
40 self.elements[last_row].get_length()
41 }
42
43 pub fn get_cell(&self, row: usize, col: usize) -> &Rq {
44 if !self.is_symmetric || row >= col {
45 &self.elements[row].get_elements()[col]
46 } else {
47 &self.elements[col].get_elements()[row]
48 }
49 }
50
51 pub fn set_cell(&mut self, row: usize, col: usize, value: Rq) {
52 self.elements[row].set(col, value);
53 }
54
55 pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
57 Self {
58 elements: (0..row_len)
59 .map(|_| RqVector::random(rng, col_len))
60 .collect(),
61 is_symmetric: false,
62 }
63 }
64
65 pub fn symmetric_random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize) -> Self {
67 Self {
68 elements: (0..row_len)
69 .map(|row| RqVector::random(rng, row + 1))
70 .collect(),
71 is_symmetric: true,
72 }
73 }
74
75 pub fn get_elements(&self) -> &[RqVector] {
76 &self.elements
77 }
78
79 pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
80 let mut decomposed_vec = Vec::new();
81 for ring_vector in self.get_elements() {
82 for ring in ring_vector.get_elements() {
83 decomposed_vec.extend(ring.decompose(base, num_parts));
84 }
85 }
86 RqVector::new(decomposed_vec)
87 }
88}
89
90impl FromIterator<RqVector> for RqMatrix {
91 fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
92 let mut elements = Vec::new();
93 for item in iter {
94 elements.push(item);
95 }
96 RqMatrix::new(elements, false)
97 }
98}
99
100impl Mul<&RqVector> for &RqMatrix {
102 type Output = RqVector;
103
104 fn mul(self, rhs: &RqVector) -> Self::Output {
105 let mut result = RqVector::zero(self.elements.len());
106
107 for (i, row) in self.elements.iter().enumerate() {
108 result.set(
109 i,
110 inner_product::compute_linear_combination(row.get_elements(), rhs.get_elements()),
111 );
112 }
113 result
114 }
115}
116
#[cfg(test)]
mod tests {
    use rand::rng;

    use super::*;
    use crate::ring::rq::tests::generate_rq_from_zq_vector;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    /// Smoke test: constructing a large (256 x 1024) random matrix must
    /// complete without crashing.
    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut generator = rng();
        let _: RqMatrix = RqMatrix::random(&mut generator, 256, 1 << 10);
    }

    /// `set_cell` changes exactly the addressed cells and leaves every other
    /// cell of a zero matrix untouched.
    #[test]
    fn test_set_cell() {
        let mut matrix = RqMatrix::zero(10, 18);
        matrix.set_cell(4, 9, Rq::new([Zq::new(10); Rq::DEGREE]));
        matrix.set_cell(8, 1, Rq::new([Zq::new(3); Rq::DEGREE]));

        for (i, vector) in matrix.get_elements().iter().enumerate() {
            for (j, poly) in vector.get_elements().iter().enumerate() {
                let expected = match (i, j) {
                    (4, 9) => Rq::new([Zq::new(10); Rq::DEGREE]),
                    (8, 1) => Rq::new([Zq::new(3); Rq::DEGREE]),
                    _ => Rq::zero(),
                };
                assert_eq!(poly, &expected);
            }
        }
    }

    /// A symmetric random matrix stores a lower triangle and reads the same
    /// value for (i, j) and (j, i).
    #[test]
    fn test_symmetric_matrix() {
        let symmetric_matrix = RqMatrix::symmetric_random(&mut rng(), 12);
        assert_eq!(symmetric_matrix.get_row_len(), 12);
        assert_eq!(symmetric_matrix.get_col_len(), 12);

        // Row `i` of the triangular storage holds `i + 1` entries.
        for (i, row) in symmetric_matrix.get_elements().iter().enumerate() {
            assert_eq!(row.get_length(), i + 1);
        }

        for i in 0..symmetric_matrix.get_row_len() {
            for j in 0..symmetric_matrix.get_col_len() {
                assert_eq!(
                    symmetric_matrix.get_cell(i, j),
                    symmetric_matrix.get_cell(j, i)
                );
            }
        }
    }

    /// Collecting an iterator of vectors yields a matrix with those rows in
    /// the same order.
    #[test]
    fn test_rq_matrix_from_iterator() {
        let expected: Vec<RqVector> = (0..4)
            .map(|_| RqVector::random(&mut rng(), 5))
            .collect();
        let result: RqMatrix = expected.iter().cloned().collect();

        assert_eq!(result.get_elements(), &expected);
    }

    /// `zero` produces the requested dimensions and only zero polynomials.
    #[test]
    fn test_zero_matrix() {
        // A polynomial is zero when every coefficient equals `Zq::ZERO`.
        let is_polynomial_zero = |poly: &Rq| {
            poly.get_coefficients()
                .iter()
                .all(|&coeff| coeff == Zq::ZERO)
        };

        let matrix = RqMatrix::zero(10, 20);
        assert_eq!(matrix.get_row_len(), 10);
        assert_eq!(matrix.get_col_len(), 20);

        let all_cells_zero = matrix
            .get_elements()
            .iter()
            .all(|row| row.get_elements().iter().all(|cell| is_polynomial_zero(cell)));
        assert!(all_cells_zero);
    }

    /// Matrix-vector multiplication against hand-computed products, for a
    /// 1 x 2 and a 2 x 2 matrix.
    #[test]
    fn test_rqmatrix_mul() {
        // Case 1: a single-row matrix.
        let a11: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), Zq::new(6)]);
        let a12: Rq = generate_rq_from_zq_vector(vec![-Zq::new(5), -Zq::new(5)]);
        let x1: Rq = generate_rq_from_zq_vector(vec![Zq::ONE, Zq::ZERO]);
        let x2: Rq = generate_rq_from_zq_vector(vec![Zq::ZERO, Zq::new(4)]);
        let matrix_1 = RqMatrix::new(vec![RqVector::from(vec![a11, a12])], false);
        let vec_1 = RqVector::from(vec![x1, x2]);

        let result_1 = &matrix_1 * &vec_1;
        let expected_1 = RqVector::from(vec![generate_rq_from_zq_vector(vec![
            Zq::new(8),
            -Zq::new(14),
            -Zq::new(20),
        ])]);
        assert_eq!(result_1, expected_1);

        // Case 2: a two-row matrix.
        let b11: Rq = generate_rq_from_zq_vector(vec![-Zq::new(7), Zq::new(7)]);
        let b12: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), Zq::ZERO]);
        let b21: Rq = generate_rq_from_zq_vector(vec![Zq::new(8), -Zq::new(2)]);
        let b22: Rq = generate_rq_from_zq_vector(vec![-Zq::new(4), Zq::new(4)]);
        let y1: Rq = generate_rq_from_zq_vector(vec![Zq::NEG_ONE, -Zq::new(2)]);
        let y2: Rq = generate_rq_from_zq_vector(vec![-Zq::new(3), -Zq::new(3)]);
        let matrix_2 = RqMatrix::new(
            vec![
                RqVector::from(vec![b11, b12]),
                RqVector::from(vec![b21, b22]),
            ],
            false,
        );
        let vec_2 = RqVector::from(vec![y1, y2]);

        let result_2 = &matrix_2 * &vec_2;
        let expected_2 = RqVector::from(vec![
            generate_rq_from_zq_vector(vec![Zq::new(16), Zq::new(16), -Zq::new(14)]),
            generate_rq_from_zq_vector(vec![Zq::new(4), -Zq::new(14), -Zq::new(8)]),
        ]);
        assert_eq!(result_2, expected_2);
    }
}