1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
3use rand::prelude::*;
4use std::fmt;
5#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Default)]
8pub struct Zq {
9 value: u32,
11}
12
13impl Zq {
14 pub const Q: u32 = u32::MAX.wrapping_add(1);
16 pub const ZERO: Self = Self::new(0);
18 pub const ONE: Self = Self::new(1);
20 pub const TWO: Self = Self::new(2);
22 pub const MAX: Self = Self::new(u32::MAX);
24
25 pub const fn new(value: u32) -> Self {
28 Self { value }
29 }
30
31 pub fn to_u128(&self) -> u128 {
32 u128::from(self.value)
33 }
34
35 pub const fn is_zero(&self) -> bool {
36 self.value == 0
37 }
38
39 pub(crate) fn centered_mod(&self, bound: Self) -> Self {
46 assert!(
47 bound != Zq::ZERO,
48 "cannot get centered representative modulo for zero bound"
49 );
50 let bounded_coeff = Self::new(self.value % bound.value);
51 let half_bound = bound.scale_by(Self::TWO);
52
53 if bounded_coeff > half_bound {
54 bounded_coeff - bound
55 } else {
56 bounded_coeff
57 }
58 }
59
60 pub(crate) fn scale_by(&self, rhs: Self) -> Self {
70 assert!(rhs != Zq::ZERO, "cannot scale by zero");
71 Self::new(self.value / rhs.value)
72 }
73}
74
75macro_rules! impl_arithmetic {
77 ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
78 impl $trait for Zq {
79 type Output = Self;
80
81 fn $method(self, rhs: Self) -> Self::Output {
82 Self::new(self.value.$op(rhs.value))
83 }
84 }
85
86 impl $assign_trait for Zq {
87 fn $assign_method(&mut self, rhs: Self) {
88 self.value = self.value.$op(rhs.value);
89 }
90 }
91
92 impl $trait<Zq> for &Zq {
93 type Output = Zq;
94
95 fn $method(self, rhs: Zq) -> Self::Output {
96 Zq::new(self.value.$op(rhs.value))
97 }
98 }
99 };
100}
101
102impl_arithmetic!(Add, AddAssign, add, add_assign, wrapping_add);
103impl_arithmetic!(Sub, SubAssign, sub, sub_assign, wrapping_sub);
104impl_arithmetic!(Mul, MulAssign, mul, mul_assign, wrapping_mul);
105
106impl From<u32> for Zq {
107 fn from(value: u32) -> Self {
108 Self::new(value)
109 }
110}
111
112impl fmt::Display for Zq {
113 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114 write!(f, "{} (mod 2^32)", self.value)
116 }
117}
118
119#[derive(Clone, Copy, Debug)]
120pub struct UniformZq(UniformInt<u32>);
121
122impl UniformSampler for UniformZq {
123 type X = Zq;
124
125 fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
126 where
127 B1: SampleBorrow<Self::X> + Sized,
128 B2: SampleBorrow<Self::X> + Sized,
129 {
130 UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
131 }
132 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
133 where
134 B1: SampleBorrow<Self::X> + Sized,
135 B2: SampleBorrow<Self::X> + Sized,
136 {
137 UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
138 }
139 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
140 self.0.sample(rng).into()
141 }
142}
143
144impl SampleUniform for Zq {
145 type Sampler = UniformZq;
146}
147
148impl Neg for Zq {
150 type Output = Zq;
151
152 fn neg(self) -> Zq {
156 if self.value == 0 {
158 self
159 } else {
160 Zq::MAX + Zq::ONE - self
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    // Unit tests for `Zq` arithmetic, helpers, formatting and sampling.
    use super::*;

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::MAX;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "u32::MAX + 1 should wrap to 0");
        assert_eq!((b - a).value, 2, "1 - u32::MAX should wrap to 2 (mod 2^32)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::MAX;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        assert_eq!((a * two).value, 0, "2^31 * 2 should wrap to 0");
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_reference_lhs_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        assert_eq!(&a + b, Zq::new(8));
        assert_eq!(&a - b, Zq::new(2));
        assert_eq!(&a * b, Zq::new(15));
        assert_eq!(&b - a, Zq::new(u32::MAX - 1));
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = 5_u32.into();
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        let result = small - large;
        assert_eq!(result.value, u32::MAX - 1, "3 - 5 should wrap to 2^32 - 2");

        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, u32::MAX - 4, "10 -= 15 should wrap to 2^32 - 5");

        let a = Zq::MAX;
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 1,
            "(-1) * 2 should be -2 ≡ 2^32 - 2"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::MAX;

        assert_eq!(format!("{}", a), "5 (mod 2^32)");
        assert_eq!(format!("{}", max), "4294967295 (mod 2^32)");
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::MAX, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }

    #[test]
    fn test_is_zero_and_to_u128() {
        assert!(Zq::ZERO.is_zero());
        assert!(Zq::default().is_zero());
        assert!(!Zq::ONE.is_zero());
        assert_eq!(Zq::MAX.to_u128(), u128::from(u32::MAX));
    }

    #[test]
    fn test_scale_by() {
        assert_eq!(Zq::new(10).scale_by(Zq::new(3)), Zq::new(3));
        assert_eq!(Zq::MAX.scale_by(Zq::TWO), Zq::new(u32::MAX / 2));
        assert_eq!(Zq::new(7).scale_by(Zq::ONE), Zq::new(7));
    }

    #[test]
    #[should_panic(expected = "cannot scale by zero")]
    fn test_scale_by_zero_panics() {
        let _ = Zq::ONE.scale_by(Zq::ZERO);
    }

    #[test]
    fn test_centered_mod() {
        // Odd bound: remainders 0..=2 stay, 3 and 4 map to -2 and -1.
        let five = Zq::new(5);
        assert_eq!(Zq::new(2).centered_mod(five), Zq::new(2));
        assert_eq!(Zq::new(3).centered_mod(five), -Zq::new(2));
        assert_eq!(Zq::new(4).centered_mod(five), -Zq::ONE);
        assert_eq!(Zq::new(7).centered_mod(five), Zq::new(2));

        // Even bound: the midpoint itself stays positive.
        let four = Zq::new(4);
        assert_eq!(Zq::TWO.centered_mod(four), Zq::TWO);
        assert_eq!(Zq::new(3).centered_mod(four), -Zq::ONE);
    }

    #[test]
    #[should_panic(expected = "cannot get centered representative modulo for zero bound")]
    fn test_centered_mod_zero_bound_panics() {
        let _ = Zq::ONE.centered_mod(Zq::ZERO);
    }

    #[test]
    fn test_uniform_sampler() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;

        // An empty half-open range is rejected.
        assert!(UniformZq::new(Zq::new(5), Zq::new(5)).is_err());

        // A single-element inclusive range always yields that element.
        let single = UniformZq::new_inclusive(Zq::new(5), Zq::new(5)).unwrap();
        let mut rng = StdRng::seed_from_u64(42);
        assert_eq!(single.sample(&mut rng), Zq::new(5));

        let range = UniformZq::new(Zq::new(10), Zq::new(20)).unwrap();
        for _ in 0..100 {
            let x = range.sample(&mut rng);
            assert!(Zq::new(10) <= x && x < Zq::new(20));
        }
    }
}