labrador/ring/
zq.rs

1use crate::ring::Norms;
2use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
3use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
4use rand::prelude::*;
5use std::fmt;
6use std::iter::Sum;
7
/// Element of the ring **Z/qZ** with `q = 2^32 − 1` (see [`Zq::Q`]).
///
/// The modulus is `2^32 − 1`, *not* `2^32`, so native `u32` wrapping arithmetic does
/// not implement this ring; the operators reduce modulo `q` explicitly.
/// `q = 3 · 5 · 17 · 257 · 65537` is composite, so this is a ring, not a field.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord, Default)]
pub struct Zq {
    /// Representative of the residue class, expected to lie in `0..Zq::Q`
    /// (i.e. `0..=2^32 − 2`). Equality and ordering are derived from this raw value,
    /// so they are only meaningful for such canonical representatives.
    value: u32,
}
15
16impl Zq {
17    /// Modulus `q = 2^32 − 1`
18    #[allow(clippy::as_conversions)]
19    pub const Q: u32 = u32::MAX;
20
21    // ------- constants -------
22    pub const ZERO: Self = Self::new(0);
23    pub const ONE: Self = Self::new(1);
24    pub const TWO: Self = Self::new(2);
25    // -1 or Maximum possible value. Equals `q - 1` or ` 2^32 − 2`
26    pub const NEG_ONE: Self = Self::new(u32::MAX - 1);
27
28    /// Creates a new Zq element from a raw u32 value.
29    /// No explicit modulo needed as u32 automatically wraps
30    pub const fn new(value: u32) -> Self {
31        Self { value }
32    }
33
34    pub fn to_u128(&self) -> u128 {
35        u128::from(self.value)
36    }
37
38    pub fn get_value(&self) -> u32 {
39        self.value
40    }
41
42    pub const fn is_zero(&self) -> bool {
43        self.value == 0
44    }
45
46    /// Returns `1` iff the element is in `(q-1/2, q)`
47    #[allow(clippy::as_conversions)]
48    pub fn is_larger_than_half(&self) -> bool {
49        self.value > (Self::Q - 1) / 2
50    }
51
52    /// Centered representative in `(-q/2, q/2]`.
53    #[allow(clippy::as_conversions)]
54    pub(crate) fn centered_mod(&self) -> i128 {
55        let bound = Self::Q as i128;
56        let value = self.value as i128;
57
58        if value > (bound - 1) / 2 {
59            value - bound
60        } else {
61            value
62        }
63    }
64
65    /// Floor division by another `Zq` value (*not* a field inverse!, just dividing the values).
66    pub(crate) fn div_floor_by(&self, rhs: u32) -> Self {
67        assert_ne!(rhs, 0, "division by zero");
68        Self::new(self.value / rhs)
69    }
70
71    /// Decompose the element to #num_parts number of parts,
72    /// where each part's infinity norm is less than or equal to bound/2
73    pub(crate) fn decompose(&self, bound: Self, num_parts: usize) -> Vec<Zq> {
74        assert!(bound >= Self::TWO, "base must be ≥ 2");
75        assert_ne!(num_parts, 0, "num_parts cannot be zero");
76
77        let mut parts = vec![Self::ZERO; num_parts];
78        let half_bound = bound.div_floor_by(2);
79        let mut abs_self = match self.is_larger_than_half() {
80            true => -(*self),
81            false => *self,
82        };
83
84        for part in &mut parts {
85            let mut remainder = Self::new(abs_self.value % bound.value);
86            if remainder > half_bound {
87                remainder -= bound;
88            }
89            *part = match self.is_larger_than_half() {
90                true => -remainder,
91                false => remainder,
92            };
93            abs_self = Self::new((abs_self - remainder).value / bound.value);
94            if abs_self == Self::ZERO {
95                break;
96            }
97        }
98        parts
99    }
100
101    #[allow(clippy::as_conversions)]
102    fn add_op(self, rhs: Zq) -> Zq {
103        let sum = (self.value as u64 + rhs.value as u64) % Zq::Q as u64;
104        Zq::new(sum as u32)
105    }
106
107    #[allow(clippy::as_conversions)]
108    fn sub_op(self, rhs: Zq) -> Zq {
109        let sub = (self.value as u64 + Zq::Q as u64 - rhs.value as u64) % Zq::Q as u64;
110        Zq::new(sub as u32)
111    }
112
113    #[allow(clippy::as_conversions)]
114    fn mul_op(self, b: Zq) -> Zq {
115        let prod = (self.value as u64 * b.value as u64) % Zq::Q as u64;
116        Zq::new(prod as u32)
117    }
118}
119
120// Macro to generate arithmetic trait implementations
121macro_rules! impl_arithmetic {
122    ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
123        impl $trait for Zq {
124            type Output = Self;
125
126            fn $method(self, rhs: Self) -> Self::Output {
127                self.$op(rhs)
128            }
129        }
130
131        impl $assign_trait for Zq {
132            fn $assign_method(&mut self, rhs: Self) {
133                *self = self.$op(rhs);
134            }
135        }
136
137        impl $trait<Zq> for &Zq {
138            type Output = Zq;
139
140            fn $method(self, rhs: Zq) -> Self::Output {
141                self.$op(rhs)
142            }
143        }
144
145        impl $trait<&Zq> for &Zq {
146            type Output = Zq;
147
148            fn $method(self, rhs: &Zq) -> Self::Output {
149                self.$op(*rhs)
150            }
151        }
152    };
153}
154
155impl_arithmetic!(Add, AddAssign, add, add_assign, add_op);
156impl_arithmetic!(Sub, SubAssign, sub, sub_assign, sub_op);
157impl_arithmetic!(Mul, MulAssign, mul, mul_assign, mul_op);
158
159// Implement the Neg trait for Zq.
160impl Neg for Zq {
161    type Output = Zq;
162
163    /// Returns the additive inverse of the field element.
164    ///
165    /// Wrap around (q - a) mod q.
166    fn neg(self) -> Zq {
167        // If the value is zero, its inverse is itself.
168        if self.value == 0 {
169            self
170        } else {
171            #[allow(clippy::as_conversions)]
172            Zq::new(Zq::Q - self.get_value())
173        }
174    }
175}
176
177impl fmt::Display for Zq {
178    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179        // Shows value with modulus for clarity
180        write!(f, "{} (mod {})", self.value, Zq::Q)
181    }
182}
183
184#[derive(Clone, Copy, Debug)]
185pub struct UniformZq(UniformInt<u32>);
186
187impl UniformSampler for UniformZq {
188    type X = Zq;
189
190    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
191    where
192        B1: SampleBorrow<Self::X> + Sized,
193        B2: SampleBorrow<Self::X> + Sized,
194    {
195        UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
196    }
197    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
198    where
199        B1: SampleBorrow<Self::X> + Sized,
200        B2: SampleBorrow<Self::X> + Sized,
201    {
202        UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
203    }
204    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
205        Self::X::new(self.0.sample(rng))
206    }
207}
208
209impl SampleUniform for Zq {
210    type Sampler = UniformZq;
211}
212
213impl Sum for Zq {
214    // Accumulate using the addition operator
215    fn sum<I>(iter: I) -> Self
216    where
217        I: Iterator<Item = Zq>,
218    {
219        iter.fold(Zq::ZERO, |acc, x| acc + x)
220    }
221}
222
223/// Adds `rhs` into `lhs` component‑wise.
224pub fn add_assign_two_zq_vectors(lhs: &mut [Zq], rhs: Vec<Zq>) {
225    debug_assert_eq!(lhs.len(), rhs.len(), "vector length mismatch");
226    lhs.iter_mut().zip(rhs).for_each(|(l, r)| *l += r);
227}
228
229// Implement l2 and infinity norms for a slice of Zq elements
230impl Norms for [Zq] {
231    type NormType = u128;
232
233    #[allow(clippy::as_conversions)]
234    fn l2_norm_squared(&self) -> Self::NormType {
235        self.iter().fold(0u128, |acc, coeff| {
236            let c = coeff.centered_mod();
237            acc + (c * c) as u128
238        })
239    }
240
241    #[allow(clippy::as_conversions)]
242    fn linf_norm(&self) -> Self::NormType {
243        self.iter()
244            .map(|coeff| coeff.centered_mod().unsigned_abs())
245            .max()
246            .unwrap_or(0)
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_u128() {
        let a = Zq::new(10);
        let b = a.to_u128();
        assert_eq!(b, 10u128);
    }

    #[test]
    fn test_is_zero() {
        let a = Zq::new(0);
        let b = Zq::new(10);
        assert!(a.is_zero());
        assert!(!b.is_zero());
    }

    #[test]
    fn test_get_value() {
        let a = Zq::new(1000);
        assert_eq!(a.get_value(), 1000u32);
    }

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        // Addition
        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        // Subtraction
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        // Multiplication
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    // The modulus is q = 2^32 - 1 (= u32::MAX), so NEG_ONE is q - 1 = u32::MAX - 1.
    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::NEG_ONE;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "(q - 1) + 1 should wrap to 0 (mod q)");
        assert_eq!((b - a).value, 2, "1 - (q - 1) should wrap to 2 (mod q)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::NEG_ONE;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        // 2^31 * 2 = 2^32 = q + 1, which reduces to 1 modulo q = 2^32 - 1.
        assert_eq!(
            (a * two).value,
            1,
            "2^31 * 2 = 2^32 should reduce to 1 (mod q)"
        );
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = Zq::new(5);
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        // Underflow: 3 - 5 = -2 ≡ q - 2 = u32::MAX - 2
        let result = small - large;
        assert_eq!(result.value, u32::MAX - 2, "3 - 5 should wrap to q - 2");

        // Compound assignment: 10 - 15 = -5 ≡ q - 5 = u32::MAX - 5
        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, u32::MAX - 5, "10 -= 15 should wrap to q - 5");

        // Negative value in multiplication
        let a = Zq::NEG_ONE; // Represents -1, i.e. q - 1
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 2,
            "(-1) * 2 should be -2 ≡ q - 2"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::NEG_ONE;
        assert_eq!(format!("{a}"), format!("5 (mod {})", Zq::Q));
        assert_eq!(format!("{max}"), format!("4294967294 (mod {})", Zq::Q));
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::NEG_ONE, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }

    #[test]
    fn test_centered_mod() {
        let a = -Zq::new(1);
        assert_eq!(-1, a.centered_mod());

        let a = Zq::new(4294967103);
        assert_eq!(a, -Zq::new(192));
        assert_eq!(-192, a.centered_mod());
    }
}
418
#[cfg(test)]
mod norm_tests {
    use super::*;

    /// Maps signed test coefficients to `Zq`, encoding `-k` as `-Zq::new(k)`.
    fn to_zq(coeffs: &[i32]) -> Vec<Zq> {
        coeffs
            .iter()
            .map(|&c| {
                let magnitude = Zq::new(c.unsigned_abs());
                if c < 0 {
                    -magnitude
                } else {
                    magnitude
                }
            })
            .collect()
    }

    #[test]
    fn test_l2_norm() {
        // 1² + 2² + … + 7² = 140
        let zq_vector = to_zq(&[1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(zq_vector.l2_norm_squared(), 140);
    }

    #[test]
    fn test_l2_norm_with_negative_values() {
        // Negating coefficients must not change their squared magnitudes.
        let zq_vector = to_zq(&[1, 2, 3, -4, -5, -6, -7]);
        assert_eq!(zq_vector.l2_norm_squared(), 140);
    }

    #[test]
    fn test_linf_norm() {
        let cases: [(&[i32], u128); 3] = [
            (&[1, 200, 300, 40, -5, -6, -700_000], 700_000),
            (&[1_000_000, 200, 300, 40, -5, -6, -999_999], 1_000_000),
            (&[1, 2, 3, -4, 0, -3, -2, -1], 4),
        ];
        for (coeffs, expected) in cases {
            assert_eq!(to_zq(coeffs).linf_norm(), expected);
        }
    }
}
495
#[cfg(test)]
mod decomposition_tests {
    use crate::ring::{zq::Zq, Norms};

    /// Recombines balanced digits: returns `Σ parts[i] · base^i (mod q)`.
    fn recompose(parts: &[Zq], base: Zq) -> Zq {
        let mut power = Zq::ONE;
        let mut result = Zq::ZERO;
        for &part in parts {
            result += part * power;
            power *= base;
        }
        result
    }

    #[test]
    fn test_zq_decomposition() {
        let (base, parts) = (Zq::new(12), 10);
        let pos_zq = Zq::new(29);
        let neg_zq = -Zq::new(29);

        let pos_decomposed = pos_zq.decompose(base, parts);
        let neg_decomposed = neg_zq.decompose(base, parts);

        // 29 = 5 + 2·12
        let mut expected_pos = vec![Zq::ZERO; parts];
        expected_pos[0] = Zq::new(5);
        expected_pos[1] = Zq::new(2);
        assert_eq!(pos_decomposed, expected_pos);

        // -29 = -5 + (-2)·12
        let mut expected_neg = vec![Zq::ZERO; parts];
        expected_neg[0] = -Zq::new(5);
        expected_neg[1] = -Zq::new(2);
        assert_eq!(neg_decomposed, expected_neg);
    }

    #[test]
    fn test_zq_recomposition() {
        let (base, parts) = (Zq::new(1802), 10);
        let neg_zq = -Zq::new(16200);

        let decomposed = neg_zq.decompose(base, parts);
        assert_eq!(recompose(&decomposed, base), neg_zq);
    }

    #[test]
    fn test_zq_recomposition_positive() {
        let (base, parts) = (Zq::new(1802), 10);
        let pos_zq = Zq::new(23071);

        let decomposed = pos_zq.decompose(base, parts);
        assert_eq!(recompose(&decomposed, base), pos_zq);
    }

    #[test]
    fn test_linf_norm() {
        let (base, parts) = (Zq::new(1802), 10);
        let pos_zq = Zq::new(16200);

        let decomposed = pos_zq.decompose(base, parts);
        // Every balanced digit must satisfy |digit| <= base / 2 = 901.
        assert!(decomposed.linf_norm() <= 901);
    }
}