labrador/ring/
zq.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
3use rand::prelude::*;
4use std::fmt;
5use std::iter::Sum;
/// Represents an element in the ring Z/qZ where q = 2^32 - 1 (`u32::MAX`).
///
/// Elements are stored as a `u32` representative. Arithmetic widens to `u64`
/// and reduces modulo q explicitly (see [`Zq::Q`]); it does not rely on
/// native `u32` wrapping, which would instead compute modulo 2^32.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Default)]
pub struct Zq {
    /// Representative of the element, expected to be canonical, i.e. in
    /// [0, q-1]. The derived `PartialEq`/`Ord` compare this raw value.
    value: u32,
}
13
14impl Zq {
15    /// Modulus q = 2^32 - 1
16    #[allow(clippy::as_conversions)]
17    pub const Q: u64 = u32::MAX as u64;
18    /// Zero element (additive identity)
19    pub const ZERO: Self = Self::new(0);
20    /// Multiplicative identity
21    pub const ONE: Self = Self::new(1);
22    /// Two
23    pub const TWO: Self = Self::new(2);
24    /// Maximum element
25    pub const NEG_ONE: Self = Self::new(u32::MAX - 1);
26
27    /// Creates a new Zq element from a raw u32 value.
28    /// No explicit modulo needed as u32 automatically wraps
29    pub const fn new(value: u32) -> Self {
30        Self { value }
31    }
32
33    pub fn to_u128(&self) -> u128 {
34        u128::from(self.value)
35    }
36
37    pub const fn is_zero(&self) -> bool {
38        self.value == 0
39    }
40
41    pub fn get_value(&self) -> u32 {
42        self.value
43    }
44
45    /// Returns the centered representative modulo the given bound
46    /// Result is guaranteed to be in (-bound/2, bound/2]
47    ///
48    /// # Panics
49    ///
50    /// Panics if `bound` is zero.
51    pub(crate) fn centered_mod(&self, bound: Self) -> Self {
52        assert!(
53            bound != Zq::ZERO,
54            "cannot get centered representative modulo for zero bound"
55        );
56        let bounded_coeff = Self::new(self.value % bound.value);
57        let half_bound = bound.scale_by(Self::TWO);
58
59        if bounded_coeff > half_bound {
60            bounded_coeff - bound
61        } else {
62            bounded_coeff
63        }
64    }
65
66    /// Scales by other Zq.
67    ///
68    /// Effectively it is a floor division of internal values.
69    /// But for the ring of integers there is no defined division
70    /// operation.
71    ///
72    /// # Panics
73    ///
74    /// Panics if `bound` is zero.
75    pub(crate) fn scale_by(&self, rhs: Self) -> Self {
76        assert!(rhs != Zq::ZERO, "cannot scale by zero");
77        Self::new(self.value / rhs.value)
78    }
79
80    #[allow(clippy::as_conversions)]
81    fn add_op(self, rhs: Zq) -> Zq {
82        let sum = (self.value as u64 + rhs.value as u64) % Zq::Q;
83        Zq::new(sum as u32)
84    }
85
86    #[allow(clippy::as_conversions)]
87    fn sub_op(self, rhs: Zq) -> Zq {
88        let sub = (self.value as u64 + Zq::Q - rhs.value as u64) % Zq::Q;
89        Zq::new(sub as u32)
90    }
91
92    #[allow(clippy::as_conversions)]
93    fn mul_op(self, b: Zq) -> Zq {
94        let prod = (self.value as u64 * b.value as u64) % Zq::Q;
95        Zq::new(prod as u32)
96    }
97}
98
99// Macro to generate arithmetic trait implementations
100macro_rules! impl_arithmetic {
101    ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
102        impl $trait for Zq {
103            type Output = Self;
104
105            fn $method(self, rhs: Self) -> Self::Output {
106                self.$op(rhs)
107            }
108        }
109
110        impl $assign_trait for Zq {
111            fn $assign_method(&mut self, rhs: Self) {
112                *self = self.$op(rhs);
113            }
114        }
115
116        impl $trait<Zq> for &Zq {
117            type Output = Zq;
118
119            fn $method(self, rhs: Zq) -> Self::Output {
120                self.$op(rhs)
121            }
122        }
123
124        impl $trait<&Zq> for &Zq {
125            type Output = Zq;
126
127            fn $method(self, rhs: &Zq) -> Self::Output {
128                self.$op(*rhs)
129            }
130        }
131    };
132}
133
134impl_arithmetic!(Add, AddAssign, add, add_assign, add_op);
135impl_arithmetic!(Sub, SubAssign, sub, sub_assign, sub_op);
136impl_arithmetic!(Mul, MulAssign, mul, mul_assign, mul_op);
137
138// Implement the Neg trait for Zq.
139impl Neg for Zq {
140    type Output = Zq;
141
142    /// Returns the additive inverse of the field element.
143    ///
144    /// Wrap around (q - a) mod q.
145    fn neg(self) -> Zq {
146        // If the value is zero, its inverse is itself.
147        if self.value == 0 {
148            self
149        } else {
150            #[allow(clippy::as_conversions)]
151            Zq::new(Zq::Q as u32 - self.get_value())
152        }
153    }
154}
155
156impl fmt::Display for Zq {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        // Shows value with modulus for clarity
159        write!(f, "{} (mod {})", self.value, Zq::Q)
160    }
161}
162
163#[derive(Clone, Copy, Debug)]
164pub struct UniformZq(UniformInt<u32>);
165
166impl UniformSampler for UniformZq {
167    type X = Zq;
168
169    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
170    where
171        B1: SampleBorrow<Self::X> + Sized,
172        B2: SampleBorrow<Self::X> + Sized,
173    {
174        UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
175    }
176    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
177    where
178        B1: SampleBorrow<Self::X> + Sized,
179        B2: SampleBorrow<Self::X> + Sized,
180    {
181        UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
182    }
183    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
184        Self::X::new(self.0.sample(rng))
185    }
186}
187
188impl SampleUniform for Zq {
189    type Sampler = UniformZq;
190}
191
192impl Sum for Zq {
193    // Accumulate using the addition operator
194    fn sum<I>(iter: I) -> Self
195    where
196        I: Iterator<Item = Zq>,
197    {
198        iter.fold(Zq::ZERO, |acc, x| acc + x)
199    }
200}
201
202pub fn add_assign_two_zq_vectors(first: &mut [Zq], second: Vec<Zq>) {
203    first
204        .iter_mut()
205        .zip(second)
206        .for_each(|(first_coeff, second_coeff)| *first_coeff += second_coeff);
207}
208
#[cfg(test)]
mod tests {
    //! Unit tests for `Zq`. The modulus is q = 2^32 - 1 (`u32::MAX`), so
    //! `Zq::NEG_ONE` is stored as `u32::MAX - 1` and `-k` as `u32::MAX - k`.
    use super::*;

    #[test]
    fn test_to_u128() {
        let a = Zq::new(10);
        let b = a.to_u128();
        assert_eq!(b, 10u128);
    }

    #[test]
    fn test_is_zero() {
        let a = Zq::new(0);
        let b = Zq::new(10);
        assert!(a.is_zero());
        assert!(!b.is_zero());
    }

    #[test]
    fn test_get_value() {
        let a = Zq::new(1000);
        assert_eq!(a.get_value(), 1000u32);
    }

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        // Addition
        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        // Subtraction
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        // Multiplication
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_reference_operators() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        assert_eq!(&a + b, Zq::new(8));
        assert_eq!(&a - &b, Zq::new(2));
        assert_eq!(&a * &b, Zq::new(15));
    }

    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::NEG_ONE;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "(q - 1) + 1 should wrap to 0");
        assert_eq!((b - a).value, 2, "1 - (q - 1) should wrap to 2 (mod q)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let neg_one = Zq::NEG_ONE;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - neg_one).value, 2);
        assert_eq!((two - neg_one).value, 3);
        assert_eq!((neg_one - neg_one).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        // 2^32 = q + 1 for q = 2^32 - 1, so the product reduces to 1.
        assert_eq!(
            (a * two).value,
            1,
            "2^31 * 2 = 2^32 should reduce to 1 (mod q)"
        );
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = Zq::new(5);
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        // 3 - 5 = -2, represented as q - 2 = u32::MAX - 2
        let result = small - large;
        assert_eq!(result.value, u32::MAX - 2, "3 - 5 should wrap to q - 2");

        // Compound assignment with a negative result
        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, u32::MAX - 5, "10 -= 15 should wrap to q - 5");

        // Negative value in multiplication
        let a = Zq::NEG_ONE; // Represents -1 modulo q = 2^32 - 1
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 2,
            "(-1) * 2 should be -2 ≡ q - 2"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::NEG_ONE;
        assert_eq!(format!("{}", a), format!("5 (mod {})", Zq::Q));
        assert_eq!(format!("{}", max), format!("4294967294 (mod {})", Zq::Q));
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::NEG_ONE, Zq::ZERO - Zq::ONE);
        assert_eq!(Zq::NEG_ONE.value, u32::MAX - 1);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
        assert_eq!(-Zq::ONE, Zq::NEG_ONE);
    }

    #[test]
    fn test_centered_mod() {
        let four = Zq::new(4);
        let five = Zq::new(5);

        // Even bound: residues map into (-2, 2].
        assert_eq!(Zq::new(6).centered_mod(four), Zq::TWO);
        assert_eq!(Zq::new(7).centered_mod(four), Zq::NEG_ONE);
        // Odd bound: residues map into (-2.5, 2.5].
        assert_eq!(Zq::new(12).centered_mod(five), Zq::TWO);
        assert_eq!(Zq::new(13).centered_mod(five), -Zq::TWO);
    }

    #[test]
    #[should_panic(expected = "cannot get centered representative modulo for zero bound")]
    fn test_centered_mod_zero_bound_panics() {
        let _ = Zq::ONE.centered_mod(Zq::ZERO);
    }

    #[test]
    fn test_scale_by() {
        assert_eq!(Zq::new(9).scale_by(Zq::TWO), Zq::new(4));
        assert_eq!(Zq::new(1).scale_by(Zq::TWO), Zq::ZERO);
    }

    #[test]
    #[should_panic(expected = "cannot scale by zero")]
    fn test_scale_by_zero_panics() {
        let _ = Zq::ONE.scale_by(Zq::ZERO);
    }

    #[test]
    fn test_sum() {
        let empty: Vec<Zq> = Vec::new();
        assert_eq!(empty.into_iter().sum::<Zq>(), Zq::ZERO);

        // 1 + 2 + (q - 1) = q + 2 ≡ 2
        let total: Zq = vec![Zq::ONE, Zq::TWO, Zq::NEG_ONE].into_iter().sum();
        assert_eq!(total, Zq::TWO);
    }

    #[test]
    fn test_add_assign_two_zq_vectors() {
        let mut first = vec![Zq::new(1), Zq::new(2)];
        add_assign_two_zq_vectors(&mut first, vec![Zq::new(3), Zq::NEG_ONE]);
        assert_eq!(first, vec![Zq::new(4), Zq::ONE]);
    }
}
364        assert_eq!(neg_b, Zq::ZERO);
365    }
366}