1use crate::ring::Norms;
2use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
3use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
4use rand::prelude::*;
5use std::fmt;
6use std::iter::Sum;
7
/// An element of the integers modulo [`Zq::Q`], stored as a single `u32` representative.
///
/// Equality and ordering are derived, so they compare the stored representative directly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Zq {
    /// Stored representative of the residue class.
    value: u32,
}
15
16impl Zq {
17 #[allow(clippy::as_conversions)]
19 pub const Q: u32 = u32::MAX;
20
21 pub const ZERO: Self = Self::new(0);
23 pub const ONE: Self = Self::new(1);
24 pub const TWO: Self = Self::new(2);
25 pub const NEG_ONE: Self = Self::new(u32::MAX - 1);
27
28 pub const fn new(value: u32) -> Self {
31 Self { value }
32 }
33
34 pub fn to_u128(&self) -> u128 {
35 u128::from(self.value)
36 }
37
38 pub fn get_value(&self) -> u32 {
39 self.value
40 }
41
42 pub const fn is_zero(&self) -> bool {
43 self.value == 0
44 }
45
46 #[allow(clippy::as_conversions)]
48 pub fn is_larger_than_half(&self) -> bool {
49 self.value > (Self::Q - 1) / 2
50 }
51
52 #[allow(clippy::as_conversions)]
54 pub(crate) fn centered_mod(&self) -> i128 {
55 let bound = Self::Q as i128;
56 let value = self.value as i128;
57
58 if value > (bound - 1) / 2 {
59 value - bound
60 } else {
61 value
62 }
63 }
64
65 pub(crate) fn div_floor_by(&self, rhs: u32) -> Self {
67 assert_ne!(rhs, 0, "division by zero");
68 Self::new(self.value / rhs)
69 }
70
71 pub(crate) fn decompose(&self, bound: Self, num_parts: usize) -> Vec<Zq> {
74 assert!(bound >= Self::TWO, "base must be ≥ 2");
75 assert_ne!(num_parts, 0, "num_parts cannot be zero");
76
77 let mut parts = vec![Self::ZERO; num_parts];
78 let half_bound = bound.div_floor_by(2);
79 let mut abs_self = match self.is_larger_than_half() {
80 true => -(*self),
81 false => *self,
82 };
83
84 for part in &mut parts {
85 let mut remainder = Self::new(abs_self.value % bound.value);
86 if remainder > half_bound {
87 remainder -= bound;
88 }
89 *part = match self.is_larger_than_half() {
90 true => -remainder,
91 false => remainder,
92 };
93 abs_self = Self::new((abs_self - remainder).value / bound.value);
94 if abs_self == Self::ZERO {
95 break;
96 }
97 }
98 parts
99 }
100
101 #[allow(clippy::as_conversions)]
102 fn add_op(self, rhs: Zq) -> Zq {
103 let sum = (self.value as u64 + rhs.value as u64) % Zq::Q as u64;
104 Zq::new(sum as u32)
105 }
106
107 #[allow(clippy::as_conversions)]
108 fn sub_op(self, rhs: Zq) -> Zq {
109 let sub = (self.value as u64 + Zq::Q as u64 - rhs.value as u64) % Zq::Q as u64;
110 Zq::new(sub as u32)
111 }
112
113 #[allow(clippy::as_conversions)]
114 fn mul_op(self, b: Zq) -> Zq {
115 let prod = (self.value as u64 * b.value as u64) % Zq::Q as u64;
116 Zq::new(prod as u32)
117 }
118}
119
120macro_rules! impl_arithmetic {
122 ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
123 impl $trait for Zq {
124 type Output = Self;
125
126 fn $method(self, rhs: Self) -> Self::Output {
127 self.$op(rhs)
128 }
129 }
130
131 impl $assign_trait for Zq {
132 fn $assign_method(&mut self, rhs: Self) {
133 *self = self.$op(rhs);
134 }
135 }
136
137 impl $trait<Zq> for &Zq {
138 type Output = Zq;
139
140 fn $method(self, rhs: Zq) -> Self::Output {
141 self.$op(rhs)
142 }
143 }
144
145 impl $trait<&Zq> for &Zq {
146 type Output = Zq;
147
148 fn $method(self, rhs: &Zq) -> Self::Output {
149 self.$op(*rhs)
150 }
151 }
152 };
153}
154
155impl_arithmetic!(Add, AddAssign, add, add_assign, add_op);
156impl_arithmetic!(Sub, SubAssign, sub, sub_assign, sub_op);
157impl_arithmetic!(Mul, MulAssign, mul, mul_assign, mul_op);
158
159impl Neg for Zq {
161 type Output = Zq;
162
163 fn neg(self) -> Zq {
167 if self.value == 0 {
169 self
170 } else {
171 #[allow(clippy::as_conversions)]
172 Zq::new(Zq::Q - self.get_value())
173 }
174 }
175}
176
177impl fmt::Display for Zq {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 write!(f, "{} (mod {})", self.value, Zq::Q)
181 }
182}
183
184#[derive(Clone, Copy, Debug)]
185pub struct UniformZq(UniformInt<u32>);
186
187impl UniformSampler for UniformZq {
188 type X = Zq;
189
190 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
191 where
192 B1: SampleBorrow<Self::X> + Sized,
193 B2: SampleBorrow<Self::X> + Sized,
194 {
195 UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
196 }
197 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
198 where
199 B1: SampleBorrow<Self::X> + Sized,
200 B2: SampleBorrow<Self::X> + Sized,
201 {
202 UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
203 }
204 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
205 Self::X::new(self.0.sample(rng))
206 }
207}
208
209impl SampleUniform for Zq {
210 type Sampler = UniformZq;
211}
212
213impl Sum for Zq {
214 fn sum<I>(iter: I) -> Self
216 where
217 I: Iterator<Item = Zq>,
218 {
219 iter.fold(Zq::ZERO, |acc, x| acc + x)
220 }
221}
222
223pub fn add_assign_two_zq_vectors(lhs: &mut [Zq], rhs: Vec<Zq>) {
225 debug_assert_eq!(lhs.len(), rhs.len(), "vector length mismatch");
226 lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l += r);
227}
228
229impl Norms for [Zq] {
231 type NormType = u128;
232
233 #[allow(clippy::as_conversions)]
234 fn l2_norm_squared(&self) -> Self::NormType {
235 self.iter().fold(0u128, |acc, coeff| {
236 let c = coeff.centered_mod();
237 acc + (c * c) as u128
238 })
239 }
240
241 #[allow(clippy::as_conversions)]
242 fn linf_norm(&self) -> Self::NormType {
243 self.iter()
244 .map(|coeff| coeff.centered_mod().unsigned_abs())
245 .max()
246 .unwrap_or(0)
247 }
248}
249
/// Unit tests for construction, accessors, arithmetic, ordering and formatting of `Zq`.
///
/// All wrap-around expectations are stated relative to the modulus `q = Zq::Q`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_u128() {
        let a = Zq::new(10);
        let b = a.to_u128();
        assert_eq!(b, 10u128);
    }

    #[test]
    fn test_is_zero() {
        let a = Zq::new(0);
        let b = Zq::new(10);
        assert!(a.is_zero());
        assert!(!b.is_zero());
    }

    #[test]
    fn test_get_value() {
        let a = Zq::new(1000);
        assert_eq!(a.get_value(), 1000u32);
    }

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_wrapping_arithmetic() {
        // `NEG_ONE` is `q - 1`, the largest reduced representative.
        let a = Zq::NEG_ONE;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "(q - 1) + 1 should wrap to 0");
        assert_eq!((b - a).value, 2, "1 - (q - 1) should wrap to 2 (mod q)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::NEG_ONE;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        // 2^32 = (2^32 - 1) + 1, so the product reduces to 1 modulo q.
        assert_eq!((a * two).value, 1, "2^31 * 2 should wrap to 1");
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = Zq::new(5);
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        let result = small - large;
        assert_eq!(result.value, Zq::Q - 2, "3 - 5 should wrap to q - 2");

        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, Zq::Q - 5, "10 -= 15 should wrap to q - 5");

        let a = Zq::NEG_ONE;
        let b = Zq::TWO;
        assert_eq!((a * b).value, Zq::Q - 2, "(-1) * 2 should be -2 ≡ q - 2");
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::NEG_ONE;
        assert_eq!(format!("{a}"), format!("5 (mod {})", Zq::Q));
        assert_eq!(format!("{max}"), format!("4294967294 (mod {})", Zq::Q));
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::NEG_ONE, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }

    #[test]
    fn test_centered_mod() {
        let a = -Zq::new(1);
        assert_eq!(-1, a.centered_mod());

        let a = Zq::new(4294967103);
        assert_eq!(a, -Zq::new(192));
        assert_eq!(-192, a.centered_mod());
    }
}
418
/// Tests for the `Norms` implementation on slices of `Zq`.
#[cfg(test)]
mod norm_tests {
    use super::*;

    /// Builds a `Zq` vector from signed coefficients; negative entries become `-Zq::new(|c|)`.
    fn zq_vec(coeffs: &[i64]) -> Vec<Zq> {
        coeffs
            .iter()
            .map(|&c| {
                let magnitude = u32::try_from(c.unsigned_abs())
                    .expect("test coefficient magnitude fits in u32");
                if c < 0 {
                    -Zq::new(magnitude)
                } else {
                    Zq::new(magnitude)
                }
            })
            .collect()
    }

    #[test]
    fn test_l2_norm() {
        let zq_vector = zq_vec(&[1, 2, 3, 4, 5, 6, 7]);

        assert_eq!(zq_vector.l2_norm_squared(), 140);
    }

    #[test]
    fn test_l2_norm_with_negative_values() {
        let zq_vector = zq_vec(&[1, 2, 3, -4, -5, -6, -7]);

        assert_eq!(zq_vector.l2_norm_squared(), 140);
    }

    #[test]
    fn test_linf_norm() {
        // Largest magnitude comes from a negative coefficient.
        let zq_vector = zq_vec(&[1, 200, 300, 40, -5, -6, -700000]);
        assert_eq!(zq_vector.linf_norm(), 700000);

        // Largest magnitude comes from a positive coefficient, just above a negative one.
        let zq_vector = zq_vec(&[1000000, 200, 300, 40, -5, -6, -999999]);
        assert_eq!(zq_vector.linf_norm(), 1000000);

        // Small values, including zero.
        let zq_vector = zq_vec(&[1, 2, 3, -4, 0, -3, -2, -1]);
        assert_eq!(zq_vector.linf_norm(), 4);
    }
}
495
/// Tests for `Zq::decompose`: digit values, round-tripping, and the digit-size bound.
#[cfg(test)]
mod decomposition_tests {
    use crate::ring::{zq::Zq, Norms};

    /// Recombines least-significant-first digits as `sum(parts[i] * base^i)` in `Zq`.
    fn recompose(parts: &[Zq], base: Zq) -> Zq {
        let mut power = Zq::ONE;
        let mut total = Zq::ZERO;
        for &part in parts {
            total += part * power;
            power *= base;
        }
        total
    }

    #[test]
    fn test_zq_decomposition() {
        let (base, parts) = (Zq::new(12), 10);
        let pos_zq = Zq::new(29);
        let neg_zq = -Zq::new(29);

        let pos_decomposed = pos_zq.decompose(base, parts);
        let neg_decomposed = neg_zq.decompose(base, parts);

        // 29 = 5 + 2 * 12; the remaining eight slots stay zero.
        let mut expected_pos = vec![Zq::ZERO; parts];
        expected_pos[0] = Zq::new(5);
        expected_pos[1] = Zq::new(2);
        assert_eq!(pos_decomposed, expected_pos);

        // The negative input yields the same digits with flipped sign.
        let mut expected_neg = vec![Zq::ZERO; parts];
        expected_neg[0] = -Zq::new(5);
        expected_neg[1] = -Zq::new(2);
        assert_eq!(neg_decomposed, expected_neg);
    }

    #[test]
    fn test_zq_recomposition() {
        let (base, parts) = (Zq::new(1802), 10);
        let neg_zq = -Zq::new(16200);

        let decomposed = neg_zq.decompose(base, parts);
        assert_eq!(recompose(&decomposed, base), neg_zq);
    }

    #[test]
    fn test_zq_recomposition_positive() {
        let (base, parts) = (Zq::new(1802), 10);
        let pos_zq = Zq::new(23071);

        let decomposed = pos_zq.decompose(base, parts);
        assert_eq!(recompose(&decomposed, base), pos_zq);
    }

    #[test]
    fn test_linf_norm() {
        let (base, parts) = (Zq::new(1802), 10);
        let pos_zq = Zq::new(16200);

        let pos_decomposed = pos_zq.decompose(base, parts);
        // Every digit must fit in the balanced range, i.e. magnitude at most base / 2 (= 901).
        let half_base = base.to_u128() / 2;
        assert!(pos_decomposed.linf_norm() <= half_base);
    }
}
580}