// labrador/ring/zq.rs

use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
use rand::prelude::*;
use std::fmt;

/// An element of the ring Z/qZ with q = 2^32.
///
/// The element is stored as a plain `u32`. Reduction modulo q is implicit:
/// the arithmetic operators defined for this type use wrapping `u32` operations.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Default)]
pub struct Zq {
    /// Representative in `[0, q - 1]`; every `u32` is already reduced modulo 2^32.
    value: u32,
}

13impl Zq {
14    /// Modulus q = 2^32 (stored as 0 in u32 due to wrapping behavior)
15    pub const Q: u32 = u32::MAX.wrapping_add(1);
16    /// Zero element (additive identity)
17    pub const ZERO: Self = Self::new(0);
18    /// Multiplicative identity
19    pub const ONE: Self = Self::new(1);
20    /// Two
21    pub const TWO: Self = Self::new(2);
22    /// Maximum element
23    pub const MAX: Self = Self::new(u32::MAX);
24
25    /// Creates a new Zq element from a raw u32 value.
26    /// No explicit modulo needed as u32 automatically wraps
27    pub const fn new(value: u32) -> Self {
28        Self { value }
29    }
30
31    pub fn to_u128(&self) -> u128 {
32        u128::from(self.value)
33    }
34
35    pub const fn is_zero(&self) -> bool {
36        self.value == 0
37    }
38
39    /// Returns the centered representative modulo the given bound
40    /// Result is guaranteed to be in (-bound/2, bound/2]
41    ///
42    /// # Panics
43    ///
44    /// Panics if `bound` is zero.
45    pub(crate) fn centered_mod(&self, bound: Self) -> Self {
46        assert!(
47            bound != Zq::ZERO,
48            "cannot get centered representative modulo for zero bound"
49        );
50        let bounded_coeff = Self::new(self.value % bound.value);
51        let half_bound = bound.scale_by(Self::TWO);
52
53        if bounded_coeff > half_bound {
54            bounded_coeff - bound
55        } else {
56            bounded_coeff
57        }
58    }
59
60    /// Scales by other Zq.
61    ///
62    /// Effectively it is a floor division of internal values.
63    /// But for the ring of integers there is no defined division
64    /// operation.
65    ///
66    /// # Panics
67    ///
68    /// Panics if `bound` is zero.
69    pub(crate) fn scale_by(&self, rhs: Self) -> Self {
70        assert!(rhs != Zq::ZERO, "cannot scale by zero");
71        Self::new(self.value / rhs.value)
72    }
73}
74
75// Macro to generate arithmetic trait implementations
76macro_rules! impl_arithmetic {
77    ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
78        impl $trait for Zq {
79            type Output = Self;
80
81            fn $method(self, rhs: Self) -> Self::Output {
82                Self::new(self.value.$op(rhs.value))
83            }
84        }
85
86        impl $assign_trait for Zq {
87            fn $assign_method(&mut self, rhs: Self) {
88                self.value = self.value.$op(rhs.value);
89            }
90        }
91
92        impl $trait<Zq> for &Zq {
93            type Output = Zq;
94
95            fn $method(self, rhs: Zq) -> Self::Output {
96                Zq::new(self.value.$op(rhs.value))
97            }
98        }
99    };
100}
101
102impl_arithmetic!(Add, AddAssign, add, add_assign, wrapping_add);
103impl_arithmetic!(Sub, SubAssign, sub, sub_assign, wrapping_sub);
104impl_arithmetic!(Mul, MulAssign, mul, mul_assign, wrapping_mul);
105
106impl From<u32> for Zq {
107    fn from(value: u32) -> Self {
108        Self::new(value)
109    }
110}
111
112impl fmt::Display for Zq {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        // Shows value with modulus for clarity
115        write!(f, "{} (mod 2^32)", self.value)
116    }
117}
118
119#[derive(Clone, Copy, Debug)]
120pub struct UniformZq(UniformInt<u32>);
121
122impl UniformSampler for UniformZq {
123    type X = Zq;
124
125    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
126    where
127        B1: SampleBorrow<Self::X> + Sized,
128        B2: SampleBorrow<Self::X> + Sized,
129    {
130        UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
131    }
132    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
133    where
134        B1: SampleBorrow<Self::X> + Sized,
135        B2: SampleBorrow<Self::X> + Sized,
136    {
137        UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
138    }
139    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
140        self.0.sample(rng).into()
141    }
142}
143
144impl SampleUniform for Zq {
145    type Sampler = UniformZq;
146}
147
148// Implement the Neg trait for Zq.
149impl Neg for Zq {
150    type Output = Zq;
151
152    /// Returns the additive inverse of the field element.
153    ///
154    /// Wrap around (q - a) mod q.
155    fn neg(self) -> Zq {
156        // If the value is zero, its inverse is itself.
157        if self.value == 0 {
158            self
159        } else {
160            Zq::MAX + Zq::ONE - self
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        // Addition
        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        // Subtraction
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        // Multiplication
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::MAX;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "u32::MAX + 1 should wrap to 0");
        assert_eq!((b - a).value, 2, "1 - u32::MAX should wrap to 2 (mod 2^32)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::MAX;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        // Multiplication wraps when exceeding u32 range
        assert_eq!((a * two).value, 0, "2^31 * 2 should wrap to 0");
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_reference_operands() {
        let a = Zq::new(7);
        let b = Zq::new(3);

        // The left operand may be borrowed
        assert_eq!(&a + b, Zq::new(10));
        assert_eq!(&a - b, Zq::new(4));
        assert_eq!(&a * b, Zq::new(21));
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = 5_u32.into();
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        // Test underflow handling (3 - 5 in u32 terms)
        let result = small - large;
        assert_eq!(result.value, u32::MAX - 1, "3 - 5 should wrap to 2^32 - 2");

        // Test compound negative operations
        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, u32::MAX - 4, "10 -= 15 should wrap to 2^32 - 5");

        // Test negative equivalent value in multiplication
        let a = Zq::MAX; // Represents -1 in mod 2^32 arithmetic
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 1,
            "(-1) * 2 should be -2 ≡ 2^32 - 2"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::MAX;

        assert_eq!(a.to_string(), "5 (mod 2^32)");
        assert_eq!(max.to_string(), "4294967295 (mod 2^32)");
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::MAX, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }

    #[test]
    fn test_neg_wraps() {
        assert_eq!(-Zq::ONE, Zq::MAX);
        assert_eq!(-Zq::MAX, Zq::ONE);
    }

    #[test]
    fn test_scale_by() {
        // Floor division of the stored values
        assert_eq!(Zq::new(10).scale_by(Zq::new(3)), Zq::new(3));
        assert_eq!(Zq::MAX.scale_by(Zq::MAX), Zq::ONE);
        assert_eq!(Zq::ONE.scale_by(Zq::TWO), Zq::ZERO);
    }

    #[test]
    #[should_panic(expected = "cannot scale by zero")]
    fn test_scale_by_zero_panics() {
        let _ = Zq::ONE.scale_by(Zq::ZERO);
    }

    #[test]
    fn test_centered_mod() {
        let bound = Zq::new(5);

        // Lower half stays as is
        assert_eq!(Zq::new(7).centered_mod(bound), Zq::new(2));
        // Upper half maps to the negative representative: 8 mod 5 = 3 -> -2
        assert_eq!(Zq::new(8).centered_mod(bound), -Zq::new(2));
        // Even bound: bound/2 itself is kept, bound/2 + 1 becomes negative
        assert_eq!(Zq::new(2).centered_mod(Zq::new(4)), Zq::new(2));
        assert_eq!(Zq::new(3).centered_mod(Zq::new(4)), -Zq::ONE);
    }

    #[test]
    #[should_panic(expected = "cannot get centered representative modulo for zero bound")]
    fn test_centered_mod_zero_bound_panics() {
        let _ = Zq::ONE.centered_mod(Zq::ZERO);
    }
}