1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
3use rand::prelude::*;
4use std::fmt;
/// Element of the integers modulo `2^32`, stored as a single `u32` representative.
///
/// Equality and ordering are derived, so they compare the stored representatives
/// as plain unsigned integers. `Default` yields the representative `0`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Zq {
    /// Representative of the residue class; every `u32` is a valid value.
    value: u32,
}
12
13impl Zq {
14 pub const Q: u32 = u32::MAX.wrapping_add(1);
16 pub const ZERO: Self = Self::new(0);
18 pub const ONE: Self = Self::new(1);
20 pub const TWO: Self = Self::new(2);
22 pub const MAX: Self = Self::new(u32::MAX);
24
25 pub const fn new(value: u32) -> Self {
28 Self { value }
29 }
30
31 pub fn to_u128(&self) -> u128 {
32 u128::from(self.value)
33 }
34
35 pub const fn is_zero(&self) -> bool {
36 self.value == 0
37 }
38
39 #[allow(clippy::as_conversions)]
40 pub fn get_value(&self) -> usize {
41 self.value as usize
42 }
43
44 pub(crate) fn centered_mod(&self, bound: Self) -> Self {
51 assert!(
52 bound != Zq::ZERO,
53 "cannot get centered representative modulo for zero bound"
54 );
55 let bounded_coeff = Self::new(self.value % bound.value);
56 let half_bound = bound.scale_by(Self::TWO);
57
58 if bounded_coeff > half_bound {
59 bounded_coeff - bound
60 } else {
61 bounded_coeff
62 }
63 }
64
65 pub(crate) fn scale_by(&self, rhs: Self) -> Self {
75 assert!(rhs != Zq::ZERO, "cannot scale by zero");
76 Self::new(self.value / rhs.value)
77 }
78}
79
80macro_rules! impl_arithmetic {
82 ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
83 impl $trait for Zq {
84 type Output = Self;
85
86 fn $method(self, rhs: Self) -> Self::Output {
87 Self::new(self.value.$op(rhs.value))
88 }
89 }
90
91 impl $assign_trait for Zq {
92 fn $assign_method(&mut self, rhs: Self) {
93 self.value = self.value.$op(rhs.value);
94 }
95 }
96
97 impl $trait<Zq> for &Zq {
98 type Output = Zq;
99
100 fn $method(self, rhs: Zq) -> Self::Output {
101 Zq::new(self.value.$op(rhs.value))
102 }
103 }
104 };
105}
106
107impl_arithmetic!(Add, AddAssign, add, add_assign, wrapping_add);
108impl_arithmetic!(Sub, SubAssign, sub, sub_assign, wrapping_sub);
109impl_arithmetic!(Mul, MulAssign, mul, mul_assign, wrapping_mul);
110
111impl From<u32> for Zq {
112 fn from(value: u32) -> Self {
113 Self::new(value)
114 }
115}
116
117impl fmt::Display for Zq {
118 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119 write!(f, "{} (mod 2^32)", self.value)
121 }
122}
123
124#[derive(Clone, Copy, Debug)]
125pub struct UniformZq(UniformInt<u32>);
126
127impl UniformSampler for UniformZq {
128 type X = Zq;
129
130 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
131 where
132 B1: SampleBorrow<Self::X> + Sized,
133 B2: SampleBorrow<Self::X> + Sized,
134 {
135 UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
136 }
137 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
138 where
139 B1: SampleBorrow<Self::X> + Sized,
140 B2: SampleBorrow<Self::X> + Sized,
141 {
142 UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
143 }
144 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
145 self.0.sample(rng).into()
146 }
147}
148
149impl SampleUniform for Zq {
150 type Sampler = UniformZq;
151}
152
153impl Neg for Zq {
155 type Output = Zq;
156
157 fn neg(self) -> Zq {
161 if self.value == 0 {
163 self
164 } else {
165 Zq::MAX + Zq::ONE - self
166 }
167 }
168}
169
170pub trait ZqVector {
171 fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self;
172 fn conjugate_automorphism(&self) -> Self;
173 fn multiply(&self, other: &Self) -> Self;
174 fn add(&self, other: &Self) -> Self;
175 fn inner_product(&self, other: &Self) -> Zq;
176}
177
178impl ZqVector for Vec<Zq> {
179 fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self {
180 let uniform = UniformZq::new_inclusive(Zq::ZERO, Zq::MAX).unwrap();
182 (0..n).map(|_| uniform.sample(rng)).collect()
183 }
184
185 fn add(&self, other: &Self) -> Self {
187 let max_degree = self.len().max(other.len());
188 let mut coeffs = vec![Zq::ZERO; max_degree];
189 for (i, coeff) in coeffs.iter_mut().enumerate().take(max_degree) {
190 if i < self.len() {
191 *coeff += self[i];
192 }
193 if i < other.len() {
194 *coeff += other[i];
195 }
196 }
197 coeffs
198 }
199
200 fn multiply(&self, other: &Vec<Zq>) -> Vec<Zq> {
205 let mut result_coefficients = vec![Zq::new(0); self.len() + other.len() - 1];
206 for (i, &coeff1) in self.iter().enumerate() {
207 for (j, &coeff2) in other.iter().enumerate() {
208 result_coefficients[i + j] += coeff1 * coeff2;
209 }
210 }
211
212 if result_coefficients.len() > self.len() {
213 let q_minus_1 = Zq::MAX;
214 let (left, right) = result_coefficients.split_at_mut(self.len());
215 for (i, &overflow) in right.iter().enumerate() {
216 left[i] += overflow * q_minus_1;
217 }
218 result_coefficients.truncate(self.len());
219 }
220 result_coefficients
221 }
222
223 fn inner_product(&self, other: &Self) -> Zq {
225 self.iter()
226 .zip(other.iter())
227 .map(|(&a, &b)| a * b)
228 .fold(Zq::ZERO, |acc, x| acc + x)
229 }
230
231 fn conjugate_automorphism(&self) -> Vec<Zq> {
233 let q_minus_1 = Zq::MAX;
234 let mut new_coeffs = vec![Zq::ZERO; self.len()];
235 for (i, new_coeff) in new_coeffs.iter_mut().enumerate().take(self.len()) {
236 if i < self.len() {
237 if i == 0 {
238 *new_coeff = self[i];
239 } else {
240 *new_coeff = self[i] * q_minus_1;
241 }
242 } else {
243 *new_coeff = Zq::ZERO;
244 }
245 }
246 let reversed_coefficients = new_coeffs
247 .iter()
248 .take(1)
249 .cloned()
250 .chain(new_coeffs.iter().skip(1).rev().cloned())
251 .collect::<Vec<Zq>>();
252
253 reversed_coefficients
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Small operands: no wrap-around is involved.
    #[test]
    fn test_basic_arithmetic() {
        let (five, three) = (Zq::new(5), Zq::new(3));

        let sum = five + three;
        assert_eq!(sum.value, 8, "5 + 3 should be 8");

        let difference = five - three;
        assert_eq!(difference.value, 2, "5 - 3 should be 2");

        let product = five * three;
        assert_eq!(product.value, 15, "5 * 3 should be 15");
    }

    /// Addition and subtraction wrap around the modulus.
    #[test]
    fn test_wrapping_arithmetic() {
        let largest = Zq::MAX;
        let unit = Zq::ONE;

        assert_eq!((largest + unit).value, 0, "u32::MAX + 1 should wrap to 0");
        assert_eq!(
            (unit - largest).value,
            2,
            "1 - u32::MAX should wrap to 2 (mod 2^32)"
        );
    }

    /// Subtracting the largest representative behaves like adding one.
    #[test]
    fn test_subtraction_edge_cases() {
        let largest = Zq::MAX;

        assert_eq!((Zq::ONE - largest).value, 2);
        assert_eq!((Zq::TWO - largest).value, 3);
        assert_eq!((largest - largest).value, 0);
    }

    /// Multiplication wraps around the modulus.
    #[test]
    fn test_multiplication_wrapping() {
        let half_modulus = Zq::new(1 << 31);

        let doubled = half_modulus * Zq::TWO;
        assert_eq!(doubled.value, 0, "2^31 * 2 should wrap to 0");
    }

    /// Compound-assignment operators update the left-hand side in place.
    #[test]
    fn test_assignment_operators() {
        let step = Zq::new(3);
        let mut acc = Zq::new(5);

        acc += step;
        assert_eq!(acc.value, 8, "5 += 3 should be 8");

        acc -= step;
        assert_eq!(acc.value, 5, "8 -= 3 should be 5");

        acc *= step;
        assert_eq!(acc.value, 15, "5 *= 3 should be 15");
    }

    /// `From<u32>` stores the value unchanged.
    #[test]
    fn test_conversion_from_u32() {
        let converted: Zq = 5_u32.into();
        assert_eq!(converted.value, 5, "Conversion from u32 should preserve value");
    }

    /// Results that would be negative over the integers wrap to large representatives.
    #[test]
    fn test_negative_arithmetic() {
        let difference = Zq::new(3) - Zq::new(5);
        assert_eq!(
            difference.value,
            u32::MAX - 1,
            "3 - 5 should wrap to 2^32 - 2"
        );

        let mut acc = Zq::new(10);
        acc -= Zq::new(15);
        assert_eq!(acc.value, u32::MAX - 4, "10 -= 15 should wrap to 2^32 - 5");

        let minus_one = Zq::MAX;
        let two = Zq::TWO;
        assert_eq!(
            (minus_one * two).value,
            u32::MAX - 1,
            "(-1) * 2 should be -2 ≡ 2^32 - 2"
        );
    }

    /// `Display` prints the representative followed by the modulus annotation.
    #[test]
    fn test_display_implementation() {
        let small = Zq::new(5);
        let largest = Zq::MAX;

        assert_eq!(format!("{}", small), "5 (mod 2^32)");
        assert_eq!(format!("{}", largest), "4294967295 (mod 2^32)");
    }

    /// `Zq::MAX` is the same element as `0 - 1`.
    #[test]
    fn test_maximum_element() {
        let minus_one = Zq::ZERO - Zq::ONE;
        assert_eq!(Zq::MAX, minus_one);
    }

    /// Derived ordering compares the stored representatives.
    #[test]
    fn test_ord() {
        let hundred = Zq::new(100);
        let two_hundred = Zq::new(200);
        let also_hundred = Zq::new(100);
        let four_hundred = Zq::new(400);

        assert!(hundred.cmp(&two_hundred).is_lt());
        assert!(hundred.cmp(&also_hundred).is_eq());
        assert!(four_hundred.cmp(&two_hundred).is_gt());

        assert_eq!(hundred, also_hundred);
        assert!(hundred < two_hundred);
        assert!(four_hundred > two_hundred);
    }

    /// Negation yields the additive inverse and fixes zero.
    #[test]
    fn test_neg() {
        let hundred = Zq::new(100);
        let negated_hundred: Zq = -hundred;
        let negated_zero: Zq = -Zq::ZERO;

        assert_eq!(negated_hundred + hundred, Zq::ZERO);
        assert_eq!(negated_zero, Zq::ZERO);
    }
}