labrador/ring/
rq_matrix.rs1use crate::ring::rq_vector::RqVector;
2use rand::{CryptoRng, Rng};
3use std::ops::Mul;
4
5use super::{rq::Rq, zq::Zq};
6
7#[derive(Debug, Clone)]
9pub struct RqMatrix {
10 elements: Vec<RqVector>,
11}
12
13impl RqMatrix {
14 pub const fn new(elements: Vec<RqVector>) -> Self {
16 RqMatrix { elements }
17 }
18
19 pub fn zero(row_len: usize, col_len: usize) -> Self {
20 RqMatrix::new(vec![RqVector::zero(col_len); row_len])
21 }
22
23 pub fn get_row_len(&self) -> usize {
24 self.elements.len()
25 }
26
27 pub fn get_col_len(&self) -> usize {
28 self.elements[0].get_length()
29 }
30
31 pub fn random<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
33 Self {
34 elements: (0..row_len)
35 .map(|_| RqVector::random(rng, col_len))
36 .collect(),
37 }
38 }
39
40 pub fn get_cell_symmetric(&self, row: usize, col: usize) -> Rq {
41 if row >= col {
42 self.elements[row].get_elements()[col].clone()
43 } else {
44 self.elements[col].get_elements()[row].clone()
45 }
46 }
47
48 pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R, row_len: usize, col_len: usize) -> Self {
50 Self {
51 elements: (0..row_len)
52 .map(|_| RqVector::random_ternary(rng, col_len))
53 .collect(),
54 }
55 }
56
57 pub fn get_elements(&self) -> &Vec<RqVector> {
58 &self.elements
59 }
60
61 pub fn decompose_each_cell(&self, base: Zq, num_parts: usize) -> RqVector {
62 let mut decomposed_vec = Vec::new();
63 for ring_vector in self.get_elements() {
64 for ring in ring_vector.get_elements() {
65 decomposed_vec.append(&mut ring.decompose(base, num_parts).get_elements().clone());
66 }
67 }
68 RqVector::new(decomposed_vec)
69 }
70}
71
72impl FromIterator<RqVector> for RqMatrix {
73 fn from_iter<T: IntoIterator<Item = RqVector>>(iter: T) -> Self {
74 let mut elements = Vec::new();
75 for item in iter {
76 elements.push(item);
77 }
78 RqMatrix::new(elements)
79 }
80}
81
82impl Mul<&RqVector> for &RqMatrix {
84 type Output = RqVector;
85
86 fn mul(self, rhs: &RqVector) -> Self::Output {
87 let mut result = RqVector::zero(self.elements.len());
88
89 for (i, row) in self.elements.iter().enumerate() {
90 result[i] = row * rhs;
91 }
92
93 result
94 }
95}
96
97impl Mul<&RqVector> for RqMatrix {
99 type Output = RqVector;
100
101 fn mul(self, rhs: &RqVector) -> Self::Output {
102 &self * rhs
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ring::rq::Rq;
    use crate::ring::zq::Zq;

    /// Smoke test: constructing a large (256 x 1024) random matrix must succeed.
    #[test]
    #[cfg(not(feature = "skip-slow-tests"))]
    fn rqmatrix_fits_stack() {
        let mut rng = rand::rng();
        let _: RqMatrix = RqMatrix::random(&mut rng, 256, 1 << 10);
    }

    /// `RqMatrix::zero` must report the requested dimensions and contain only zero cells.
    #[test]
    fn test_zero_matrix() {
        let zero_matrix = RqMatrix::zero(10, 20);

        assert_eq!(zero_matrix.get_row_len(), 10);
        assert_eq!(zero_matrix.get_col_len(), 20);

        let every_cell_is_zero = zero_matrix
            .get_elements()
            .iter()
            .flat_map(|row| row.get_elements().iter())
            .all(|cell| cell.is_zero());
        assert!(every_cell_is_zero);
    }

    /// Checks matrix-vector products against hand-computed results.
    ///
    /// The expected values are consistent with reading `Zq::new(u32::MAX - k)` as
    /// `-(k + 1)` and `Zq::MAX` as `-1`.
    #[test]
    fn test_rqmartrix_mul() {
        // Case 1: a 1x2 matrix times a vector of length 2.
        //   (8 + 6x) * 1 + (-5 - 5x) * 4x = 8 - 14x - 20x^2
        let a: Rq = vec![Zq::new(8), Zq::new(6)].into();
        let b: Rq = vec![Zq::new(u32::MAX - 4), Zq::new(u32::MAX - 4)].into();
        let c: Rq = vec![Zq::ONE, Zq::ZERO].into();
        let d: Rq = vec![Zq::ZERO, Zq::new(4)].into();
        let single_row_matrix: RqMatrix = RqMatrix::new(vec![RqVector::from(vec![a, b])]);
        let first_vector: RqVector = RqVector::from(vec![c, d]);

        let first_product = single_row_matrix * &first_vector;
        let first_expected_poly: Rq =
            vec![Zq::new(8), Zq::new(u32::MAX - 13), Zq::new(u32::MAX - 19)].into();
        let first_expected = RqVector::from(vec![first_expected_poly]);
        assert_eq!(first_product, first_expected);

        // Case 2: a 2x2 matrix times a vector of length 2.
        //   row 0: (-7 + 7x)(-1 - 2x) + (-3)(-3 - 3x)      = 16 + 16x - 14x^2
        //   row 1: (8 - 2x)(-1 - 2x) + (-4 + 4x)(-3 - 3x)  =  4 - 14x -  8x^2
        let m00: Rq = vec![Zq::new(u32::MAX - 6), Zq::new(7)].into();
        let m01: Rq = vec![Zq::new(u32::MAX - 2), Zq::ZERO].into();
        let m10: Rq = vec![Zq::new(8), Zq::new(u32::MAX - 1)].into();
        let m11: Rq = vec![Zq::new(u32::MAX - 3), Zq::new(4)].into();
        let v0: Rq = vec![Zq::MAX, Zq::new(u32::MAX - 1)].into();
        let v1: Rq = vec![Zq::new(u32::MAX - 2), Zq::new(u32::MAX - 2)].into();
        let square_matrix: RqMatrix = RqMatrix::new(vec![
            RqVector::from(vec![m00, m01]),
            RqVector::from(vec![m10, m11]),
        ]);
        let second_vector: RqVector = RqVector::from(vec![v0, v1]);

        let second_product = square_matrix * &second_vector;
        let expected_row_0: Rq = vec![Zq::new(16), Zq::new(16), Zq::new(u32::MAX - 13)].into();
        let expected_row_1: Rq =
            vec![Zq::new(4), Zq::new(u32::MAX - 13), Zq::new(u32::MAX - 7)].into();
        let second_expected = RqVector::from(vec![expected_row_0, expected_row_1]);
        assert_eq!(second_product, second_expected);
    }
}