// labrador/ring/zq.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use rand::distr::uniform::{Error, SampleBorrow, SampleUniform, UniformInt, UniformSampler};
3use rand::prelude::*;
4use std::fmt;
/// An element of the ring Z/qZ with q = 2^32.
///
/// The canonical representative lives in a plain `u32`, so reduction modulo q
/// costs nothing: every `u32` is already in [0, q - 1], and the arithmetic
/// operators on this type are built from wrapping `u32` operations.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Zq {
    /// Canonical representative in [0, 2^32 - 1].
    value: u32,
}
12
13impl Zq {
14    /// Modulus q = 2^32 (stored as 0 in u32 due to wrapping behavior)
15    pub const Q: u32 = u32::MAX.wrapping_add(1);
16    /// Zero element (additive identity)
17    pub const ZERO: Self = Self::new(0);
18    /// Multiplicative identity
19    pub const ONE: Self = Self::new(1);
20    /// Two
21    pub const TWO: Self = Self::new(2);
22    /// Maximum element
23    pub const MAX: Self = Self::new(u32::MAX);
24
25    /// Creates a new Zq element from a raw u32 value.
26    /// No explicit modulo needed as u32 automatically wraps
27    pub const fn new(value: u32) -> Self {
28        Self { value }
29    }
30
31    pub fn to_u128(&self) -> u128 {
32        u128::from(self.value)
33    }
34
35    pub const fn is_zero(&self) -> bool {
36        self.value == 0
37    }
38
39    #[allow(clippy::as_conversions)]
40    pub fn get_value(&self) -> usize {
41        self.value as usize
42    }
43
44    /// Returns the centered representative modulo the given bound
45    /// Result is guaranteed to be in (-bound/2, bound/2]
46    ///
47    /// # Panics
48    ///
49    /// Panics if `bound` is zero.
50    pub(crate) fn centered_mod(&self, bound: Self) -> Self {
51        assert!(
52            bound != Zq::ZERO,
53            "cannot get centered representative modulo for zero bound"
54        );
55        let bounded_coeff = Self::new(self.value % bound.value);
56        let half_bound = bound.scale_by(Self::TWO);
57
58        if bounded_coeff > half_bound {
59            bounded_coeff - bound
60        } else {
61            bounded_coeff
62        }
63    }
64
65    /// Scales by other Zq.
66    ///
67    /// Effectively it is a floor division of internal values.
68    /// But for the ring of integers there is no defined division
69    /// operation.
70    ///
71    /// # Panics
72    ///
73    /// Panics if `bound` is zero.
74    pub(crate) fn scale_by(&self, rhs: Self) -> Self {
75        assert!(rhs != Zq::ZERO, "cannot scale by zero");
76        Self::new(self.value / rhs.value)
77    }
78}
79
80// Macro to generate arithmetic trait implementations
81macro_rules! impl_arithmetic {
82    ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op:ident) => {
83        impl $trait for Zq {
84            type Output = Self;
85
86            fn $method(self, rhs: Self) -> Self::Output {
87                Self::new(self.value.$op(rhs.value))
88            }
89        }
90
91        impl $assign_trait for Zq {
92            fn $assign_method(&mut self, rhs: Self) {
93                self.value = self.value.$op(rhs.value);
94            }
95        }
96
97        impl $trait<Zq> for &Zq {
98            type Output = Zq;
99
100            fn $method(self, rhs: Zq) -> Self::Output {
101                Zq::new(self.value.$op(rhs.value))
102            }
103        }
104    };
105}
106
107impl_arithmetic!(Add, AddAssign, add, add_assign, wrapping_add);
108impl_arithmetic!(Sub, SubAssign, sub, sub_assign, wrapping_sub);
109impl_arithmetic!(Mul, MulAssign, mul, mul_assign, wrapping_mul);
110
111impl From<u32> for Zq {
112    fn from(value: u32) -> Self {
113        Self::new(value)
114    }
115}
116
117impl fmt::Display for Zq {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        // Shows value with modulus for clarity
120        write!(f, "{} (mod 2^32)", self.value)
121    }
122}
123
124#[derive(Clone, Copy, Debug)]
125pub struct UniformZq(UniformInt<u32>);
126
127impl UniformSampler for UniformZq {
128    type X = Zq;
129
130    fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
131    where
132        B1: SampleBorrow<Self::X> + Sized,
133        B2: SampleBorrow<Self::X> + Sized,
134    {
135        UniformInt::<u32>::new(low.borrow().value, high.borrow().value).map(UniformZq)
136    }
137    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
138    where
139        B1: SampleBorrow<Self::X> + Sized,
140        B2: SampleBorrow<Self::X> + Sized,
141    {
142        UniformInt::<u32>::new_inclusive(low.borrow().value, high.borrow().value).map(UniformZq)
143    }
144    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
145        self.0.sample(rng).into()
146    }
147}
148
149impl SampleUniform for Zq {
150    type Sampler = UniformZq;
151}
152
153// Implement the Neg trait for Zq.
154impl Neg for Zq {
155    type Output = Zq;
156
157    /// Returns the additive inverse of the field element.
158    ///
159    /// Wrap around (q - a) mod q.
160    fn neg(self) -> Zq {
161        // If the value is zero, its inverse is itself.
162        if self.value == 0 {
163            self
164        } else {
165            Zq::MAX + Zq::ONE - self
166        }
167    }
168}
169
170pub trait ZqVector {
171    fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self;
172    fn conjugate_automorphism(&self) -> Self;
173    fn multiply(&self, other: &Self) -> Self;
174    fn add(&self, other: &Self) -> Self;
175    fn inner_product(&self, other: &Self) -> Zq;
176}
177
178impl ZqVector for Vec<Zq> {
179    fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self {
180        // you can re‑use the UniformZq defined above
181        let uniform = UniformZq::new_inclusive(Zq::ZERO, Zq::MAX).unwrap();
182        (0..n).map(|_| uniform.sample(rng)).collect()
183    }
184
185    /// Add two ZqVector with flexible degree
186    fn add(&self, other: &Self) -> Self {
187        let max_degree = self.len().max(other.len());
188        let mut coeffs = vec![Zq::ZERO; max_degree];
189        for (i, coeff) in coeffs.iter_mut().enumerate().take(max_degree) {
190            if i < self.len() {
191                *coeff += self[i];
192            }
193            if i < other.len() {
194                *coeff += other[i];
195            }
196        }
197        coeffs
198    }
199
200    /// Note: This is a key performance bottleneck. The multiplication here is primarily used in: Prover.check_projection()
201    /// which verifies the condition: p_j? = ct(sum(<σ−1(pi_i^(j)), s_i>))
202    /// Each ZqVector involved has a length of 2*lambda (default: 256).
203    /// Consider optimizing this operation by applying NTT-based multiplication to improve performance.
204    fn multiply(&self, other: &Vec<Zq>) -> Vec<Zq> {
205        let mut result_coefficients = vec![Zq::new(0); self.len() + other.len() - 1];
206        for (i, &coeff1) in self.iter().enumerate() {
207            for (j, &coeff2) in other.iter().enumerate() {
208                result_coefficients[i + j] += coeff1 * coeff2;
209            }
210        }
211
212        if result_coefficients.len() > self.len() {
213            let q_minus_1 = Zq::MAX;
214            let (left, right) = result_coefficients.split_at_mut(self.len());
215            for (i, &overflow) in right.iter().enumerate() {
216                left[i] += overflow * q_minus_1;
217            }
218            result_coefficients.truncate(self.len());
219        }
220        result_coefficients
221    }
222
223    /// Dot product between coefficients
224    fn inner_product(&self, other: &Self) -> Zq {
225        self.iter()
226            .zip(other.iter())
227            .map(|(&a, &b)| a * b)
228            .fold(Zq::ZERO, |acc, x| acc + x)
229    }
230
231    /// Compute the conjugate automorphism \sigma_{-1} of vector based on B) Constraints..., Page 21.
232    fn conjugate_automorphism(&self) -> Vec<Zq> {
233        let q_minus_1 = Zq::MAX;
234        let mut new_coeffs = vec![Zq::ZERO; self.len()];
235        for (i, new_coeff) in new_coeffs.iter_mut().enumerate().take(self.len()) {
236            if i < self.len() {
237                if i == 0 {
238                    *new_coeff = self[i];
239                } else {
240                    *new_coeff = self[i] * q_minus_1;
241                }
242            } else {
243                *new_coeff = Zq::ZERO;
244            }
245        }
246        let reversed_coefficients = new_coeffs
247            .iter()
248            .take(1)
249            .cloned()
250            .chain(new_coeffs.iter().skip(1).rev().cloned())
251            .collect::<Vec<Zq>>();
252
253        reversed_coefficients
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_arithmetic() {
        let a = Zq::new(5);
        let b = Zq::new(3);

        // Addition
        assert_eq!((a + b).value, 8, "5 + 3 should be 8");
        // Subtraction
        assert_eq!((a - b).value, 2, "5 - 3 should be 2");
        // Multiplication
        assert_eq!((a * b).value, 15, "5 * 3 should be 15");
    }

    #[test]
    fn test_wrapping_arithmetic() {
        let a = Zq::MAX;
        let b = Zq::ONE;

        assert_eq!((a + b).value, 0, "u32::MAX + 1 should wrap to 0");
        assert_eq!((b - a).value, 2, "1 - u32::MAX should wrap to 2 (mod 2^32)");
    }

    #[test]
    fn test_subtraction_edge_cases() {
        let max = Zq::MAX;
        let one = Zq::ONE;
        let two = Zq::TWO;

        assert_eq!((one - max).value, 2);
        assert_eq!((two - max).value, 3);
        assert_eq!((max - max).value, 0);
    }

    #[test]
    fn test_multiplication_wrapping() {
        let a = Zq::new(1 << 31);
        let two = Zq::TWO;

        // Multiplication wraps when exceeding u32 range
        assert_eq!((a * two).value, 0, "2^31 * 2 should wrap to 0");
    }

    #[test]
    fn test_assignment_operators() {
        let mut a = Zq::new(5);
        let b = Zq::new(3);

        a += b;
        assert_eq!(a.value, 8, "5 += 3 should be 8");

        a -= b;
        assert_eq!(a.value, 5, "8 -= 3 should be 5");

        a *= b;
        assert_eq!(a.value, 15, "5 *= 3 should be 15");
    }

    #[test]
    fn test_conversion_from_u32() {
        let a: Zq = 5_u32.into();
        assert_eq!(a.value, 5, "Conversion from u32 should preserve value");
    }

    #[test]
    fn test_negative_arithmetic() {
        let small = Zq::new(3);
        let large = Zq::new(5);

        // Test underflow handling (3 - 5 in u32 terms)
        let result = small - large;
        assert_eq!(result.value, u32::MAX - 1, "3 - 5 should wrap to 2^32 - 2");

        // Test compound negative operations
        let mut x = Zq::new(10);
        x -= Zq::new(15);
        assert_eq!(x.value, u32::MAX - 4, "10 -= 15 should wrap to 2^32 - 5");

        // Test negative equivalent value in multiplication
        let a = Zq::MAX; // Represents -1 in mod 2^32 arithmetic
        let b = Zq::TWO;
        assert_eq!(
            (a * b).value,
            u32::MAX - 1,
            "(-1) * 2 should be -2 ≡ 2^32 - 2"
        );
    }

    #[test]
    fn test_display_implementation() {
        let a = Zq::new(5);
        let max = Zq::MAX;

        assert_eq!(format!("{}", a), "5 (mod 2^32)");
        assert_eq!(format!("{}", max), "4294967295 (mod 2^32)");
    }

    #[test]
    fn test_maximum_element() {
        assert_eq!(Zq::MAX, Zq::ZERO - Zq::ONE);
    }

    #[test]
    fn test_ord() {
        let a = Zq::new(100);
        let b = Zq::new(200);
        let c = Zq::new(100);
        let d = Zq::new(400);

        let res_1 = a.cmp(&b);
        let res_2 = a.cmp(&c);
        let res_3 = d.cmp(&b);
        assert!(res_1.is_lt());
        assert!(res_2.is_eq());
        assert!(res_3.is_gt());
        assert_eq!(a, c);
        assert!(a < b);
        assert!(d > b);
    }

    #[test]
    fn test_neg() {
        let a = Zq::new(100);
        let b = Zq::ZERO;
        let neg_a: Zq = -a;
        let neg_b: Zq = -b;

        assert_eq!(neg_a + a, Zq::ZERO);
        assert_eq!(neg_b, Zq::ZERO);
    }

    #[test]
    fn test_centered_mod() {
        let five = Zq::new(5);
        assert_eq!(Zq::new(7).centered_mod(five), Zq::new(2));
        assert_eq!(Zq::new(4).centered_mod(five), Zq::ZERO - Zq::ONE);
        assert_eq!(Zq::new(3).centered_mod(five), Zq::ZERO - Zq::TWO);

        // For an even bound, bound/2 itself stays positive: range is (-b/2, b/2].
        let four = Zq::new(4);
        assert_eq!(Zq::new(2).centered_mod(four), Zq::TWO);
        assert_eq!(Zq::new(3).centered_mod(four), Zq::MAX);
    }

    #[test]
    #[should_panic(expected = "cannot get centered representative modulo for zero bound")]
    fn test_centered_mod_zero_bound_panics() {
        let _ = Zq::ONE.centered_mod(Zq::ZERO);
    }

    #[test]
    fn test_scale_by_is_floor_division() {
        assert_eq!(Zq::new(7).scale_by(Zq::TWO), Zq::new(3));
    }

    #[test]
    fn test_vector_add_pads_shorter_operand() {
        let a = vec![Zq::new(1), Zq::new(2), Zq::new(3)];
        let b = vec![Zq::new(10), Zq::new(20)];
        let expected = vec![Zq::new(11), Zq::new(22), Zq::new(3)];
        assert_eq!(<Vec<Zq> as ZqVector>::add(&a, &b), expected);
        assert_eq!(<Vec<Zq> as ZqVector>::add(&b, &a), expected);
    }

    #[test]
    fn test_vector_multiply_negacyclic() {
        // (1 + 2X)(3 + 4X) = 3 + 10X + 8X^2 ≡ -5 + 10X (mod X^2 + 1)
        let a = vec![Zq::new(1), Zq::new(2)];
        let b = vec![Zq::new(3), Zq::new(4)];
        assert_eq!(a.multiply(&b), vec![-Zq::new(5), Zq::new(10)]);
    }

    #[test]
    fn test_vector_inner_product() {
        let a = vec![Zq::new(1), Zq::new(2), Zq::new(3)];
        let b = vec![Zq::new(4), Zq::new(5), Zq::new(6)];
        assert_eq!(a.inner_product(&b), Zq::new(32));
    }

    #[test]
    fn test_vector_conjugate_automorphism() {
        let a = vec![Zq::new(1), Zq::new(2), Zq::new(3)];
        assert_eq!(
            a.conjugate_automorphism(),
            vec![Zq::new(1), -Zq::new(3), -Zq::new(2)]
        );
        assert!(Vec::<Zq>::new().conjugate_automorphism().is_empty());
    }
}
390}