1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
3use rand::prelude::*;
4use std::fmt;
5use std::iter::Sum;
/// An element of the ring `Z_q` with modulus `q = Zq::Q = 2^32 - 1`.
///
/// The residue is stored as its canonical representative in `[0, Q)`. `new` and every
/// arithmetic helper maintain that invariant, so the derived `PartialEq`/`Eq` compare
/// residues correctly and `Ord` orders elements by their representative.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Default)]
pub struct Zq {
    /// Canonical representative, always `< Zq::Q`.
    value: u32,
}

impl Zq {
    /// The modulus `q = 2^32 - 1`, widened to `u64` so intermediate sums and products fit.
    #[allow(clippy::as_conversions)] // `u64::from` is not callable in const context; widening is lossless.
    pub const Q: u64 = u32::MAX as u64;
    /// The additive identity.
    pub const ZERO: Self = Self::new(0);
    /// The multiplicative identity.
    pub const ONE: Self = Self::new(1);
    /// The element `2`.
    pub const TWO: Self = Self::new(2);
    /// The additive inverse of `ONE`, i.e. the representative `Q - 1 = u32::MAX - 1`.
    pub const NEG_ONE: Self = Self::new(u32::MAX - 1);

    /// Creates an element from `value`, reducing it modulo [`Zq::Q`].
    ///
    /// Because `Q == u32::MAX`, the only `u32` outside `[0, Q)` is `u32::MAX` itself,
    /// which is congruent to zero.
    pub const fn new(value: u32) -> Self {
        // Without this reduction `u32::MAX` would be a second, non-canonical encoding of
        // zero that compares unequal to `ZERO` and fails `is_zero`.
        let value = if value == u32::MAX { 0 } else { value };
        Self { value }
    }

    /// Returns the canonical representative widened to `u128`.
    pub fn to_u128(&self) -> u128 {
        u128::from(self.value)
    }

    /// Returns `true` if this element is zero.
    pub const fn is_zero(&self) -> bool {
        self.value == 0
    }

    /// Returns the canonical representative in `[0, Q)`.
    pub const fn get_value(&self) -> u32 {
        self.value
    }

    /// Reduces `self` modulo `bound` and maps the result to a centered representative.
    ///
    /// With `r = self.value % bound.value`, returns `r` when `r <= bound / 2` (integer
    /// division) and otherwise `r - bound` computed in `Z_q` (stored as `Q - (bound - r)`).
    ///
    /// NOTE(review): the reduction uses the integer representative, so an element that
    /// encodes a "negative" value (close to `Q`) is not reduced as a negative number unless
    /// `bound` divides `Q` — confirm this is what callers expect.
    ///
    /// # Panics
    /// Panics if `bound` is zero.
    pub(crate) fn centered_mod(&self, bound: Self) -> Self {
        assert!(
            bound != Zq::ZERO,
            "cannot get centered representative modulo for zero bound"
        );
        let bounded_coeff = Self::new(self.value % bound.value);
        let half_bound = bound.scale_by(Self::TWO);

        if bounded_coeff > half_bound {
            // Same as `bounded_coeff - bound`; calling the helper directly keeps this impl
            // independent of the operator-trait impls.
            bounded_coeff.sub_op(bound)
        } else {
            bounded_coeff
        }
    }

    /// Divides the representative of `self` by that of `rhs`, rounding toward zero.
    ///
    /// This is plain integer division, not multiplication by a modular inverse.
    ///
    /// # Panics
    /// Panics if `rhs` is zero.
    pub(crate) fn scale_by(&self, rhs: Self) -> Self {
        assert!(rhs != Zq::ZERO, "cannot scale by zero");
        Self::new(self.value / rhs.value)
    }

    /// Modular addition `(self + rhs) mod Q`.
    #[allow(clippy::as_conversions)] // The reduced value is `< Q <= u32::MAX`, so narrowing is lossless.
    fn add_op(self, rhs: Zq) -> Zq {
        let sum = (u64::from(self.value) + u64::from(rhs.value)) % Zq::Q;
        Zq::new(sum as u32)
    }

    /// Modular subtraction `(self - rhs) mod Q`.
    #[allow(clippy::as_conversions)] // The reduced value is `< Q <= u32::MAX`, so narrowing is lossless.
    fn sub_op(self, rhs: Zq) -> Zq {
        // Adding `Q` before subtracting keeps the unsigned intermediate non-negative.
        let diff = (u64::from(self.value) + Zq::Q - u64::from(rhs.value)) % Zq::Q;
        Zq::new(diff as u32)
    }

    /// Modular multiplication `(self * rhs) mod Q`.
    #[allow(clippy::as_conversions)] // The reduced value is `< Q <= u32::MAX`, so narrowing is lossless.
    fn mul_op(self, rhs: Zq) -> Zq {
        // Both factors are below 2^32, so their product fits in a `u64`.
        let prod = (u64::from(self.value) * u64::from(rhs.value)) % Zq::Q;
        Zq::new(prod as u32)
    }
}
98
99macro_rules! impl_arithmetic {
101 ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
102 impl $trait for Zq {
103 type Output = Self;
104
105 fn $method(self, rhs: Self) -> Self::Output {
106 self.$op(rhs)
107 }
108 }
109
110 impl $assign_trait for Zq {
111 fn $assign_method(&mut self, rhs: Self) {
112 *self = self.$op(rhs);
113 }
114 }
115
116 impl $trait<Zq> for &Zq {
117 type Output = Zq;
118
119 fn $method(self, rhs: Zq) -> Self::Output {
120 self.$op(rhs)
121 }
122 }
123
124 impl $trait<&Zq> for &Zq {
125 type Output = Zq;
126
127 fn $method(self, rhs: &Zq) -> Self::Output {
128 self.$op(*rhs)
129 }
130 }
131 };
132}
133
134impl_arithmetic!(Add, AddAssign, add, add_assign, add_op);
135impl_arithmetic!(Sub, SubAssign, sub, sub_assign, sub_op);
136impl_arithmetic!(Mul, MulAssign, mul, mul_assign, mul_op);
137
138impl Neg for Zq {
140 type Output = Zq;
141
142 fn neg(self) -> Zq {
146 if self.value == 0 {
148 self
149 } else {
150 #[allow(clippy::as_conversions)]
151 Zq::new(Zq::Q as u32 - self.get_value())
152 }
153 }
154}
155
156impl fmt::Display for Zq {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 write!(f, "{} (mod {})", self.value, Zq::Q)
160 }
161}
162
163#[derive(Clone, Copy, Debug)]
164pub struct UniformZq(UniformInt<u32>);
165
166impl UniformSampler for UniformZq {
167 type X = Zq;
168
169 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
170 where
171 B1: SampleBorrow<Self::X> + Sized,
172 B2: SampleBorrow<Self::X> + Sized,
173 {
174 UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
175 }
176 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
177 where
178 B1: SampleBorrow<Self::X> + Sized,
179 B2: SampleBorrow<Self::X> + Sized,
180 {
181 UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
182 }
183 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
184 Self::X::new(self.0.sample(rng))
185 }
186}
187
188impl SampleUniform for Zq {
189 type Sampler = UniformZq;
190}
191
192impl Sum for Zq {
193 fn sum<I>(iter: I) -> Self
195 where
196 I: Iterator<Item = Zq>,
197 {
198 iter.fold(Zq::ZERO, |acc, x| acc + x)
199 }
200}
201
202pub fn add_assign_two_zq_vectors(first: &mut [Zq], second: Vec<Zq>) {
203 first
204 .iter_mut()
205 .zip(second)
206 .for_each(|(first_coeff, second_coeff)| *first_coeff += second_coeff);
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    // Note: the modulus is Q = u32::MAX = 2^32 - 1, so NEG_ONE is u32::MAX - 1.

    #[test]
    fn test_to_u128() {
        let a = Zq::new(10);
        let b = a.to_u128();
        assert_eq!(b, 10u128);
    }

    #[test]
    fn test_is_zero() {
        let a = Zq::new(0);
        let b = Zq::new(10);
        assert!(a.is_zero());
        assert!(!b.is_zero());
    }

    #[test]
    fn test_get_value() {
        let a = Zq::new(1000);
        assert_eq!(a.get_value(), 1000u32);
    }

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::NEG_ONE;
        let b = Zq::ONE;

        assert_eq!(
            (a + b).value,
            0,
            "(Q - 1) + 1 should wrap to 0 (mod 2^32 - 1)"
        );
        assert_eq!(
            (b - a).value,
            2,
            "1 - (Q - 1) should wrap to 2 (mod 2^32 - 1)"
        );
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::NEG_ONE;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        assert_eq!(
            (a * two).value,
            1,
            "2^31 * 2 = 2^32 should wrap to 1 (mod 2^32 - 1)"
        );
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = Zq::new(5);
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        let result = small - large;
        assert_eq!(
            result.value,
            u32::MAX - 2,
            "3 - 5 should wrap to Q - 2 = 2^32 - 3"
        );

        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(
            x.value,
            u32::MAX - 5,
            "10 -= 15 should wrap to Q - 5 = 2^32 - 6"
        );

        let a = Zq::NEG_ONE;
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 2,
            "(-1) * 2 should be -2 ≡ Q - 2 = 2^32 - 3"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::NEG_ONE;
        assert_eq!(format!("{}", a), format!("5 (mod {})", Zq::Q));
        assert_eq!(format!("{}", max), format!("4294967294 (mod {})", Zq::Q));
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::NEG_ONE, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }
}
366}