labrador/ring/
rq.rs

// This file is part of the polynomial ring operations module.
//
//
// Currently implemented functions include:
// - Polynomial addition:            +
// - Polynomial multiplication:      * (modulo X^D + 1)
// - Inner (dot) product of coeffs:  inner_product()
// - Polynomial subtraction:         -
// - Polynomial negation:            unary - / neg()
// - Scalar multiplication:          scalar_mul()
// - Polynomial evaluation:          eval()
// - Zero check:                     is_zero()
// - Polynomial equality check:      is_equal()
// - Get the coefficients:           get_coefficients()
// - Random polynomials:             random(), random_ternary()
// - Base-B decomposition:           decompose()
// - Message encoding:               encode_message()
// - Coefficient bound check:        check_bounds()
//
// Further operations and optimizations will be added in future versions.
19
20// We use the Zq ring
21use crate::ring::zq::Zq;
22use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
23use rand::distr::{Distribution, Uniform};
24use rand::{CryptoRng, Rng};
25use std::iter::Sum;
26
27/// This module provides implementations for various operations
28/// in the polynomial ring R = Z_q\[X\] / (X^d + 1).
29#[derive(Debug, Clone, PartialEq, Eq)]
30pub struct Rq<const D: usize> {
31    coeffs: [Zq; D],
32}
33
34impl<const D: usize> Rq<D> {
35    /// Constructor for the polynomial ring
36    pub const fn new(coeffs: [Zq; D]) -> Self {
37        Rq { coeffs }
38    }
39    /// Get the coefficients as a vector
40    pub fn get_coefficients(&self) -> &[Zq; D] {
41        &self.coeffs
42    }
43
44    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, Zq> {
45        self.coeffs.iter_mut()
46    }
47
48    /// Polynomial addition
49    fn addition(&self, other: &Self) -> Self {
50        let mut result = [Zq::ZERO; D];
51        for (r, (a, b)) in result
52            .iter_mut()
53            .zip(self.coeffs.iter().zip(other.coeffs.iter()))
54        {
55            *r = *a + *b;
56        }
57        Rq::new(result)
58    }
59
60    /// Polynomial subtraction
61    fn subtraction(&self, other: &Self) -> Self {
62        let mut result = [Zq::ZERO; D];
63        for (r, (a, b)) in result
64            .iter_mut()
65            .zip(self.coeffs.iter().zip(other.coeffs.iter()))
66        {
67            *r = *a - *b;
68        }
69        Rq::new(result)
70    }
71
72    /// Polynomial multiplication modulo x^D + 1
73    fn multiplication(&self, other: &Self) -> Self {
74        let mut result = [Zq::ZERO; D];
75        let mut out_of_field = [Zq::ZERO; D];
76        for (i, &self_coeff) in self.coeffs.iter().enumerate() {
77            for (j, &other_coeff) in other.coeffs.iter().enumerate() {
78                if i + j < D {
79                    result[i + j] += self_coeff * other_coeff;
80                } else {
81                    out_of_field[(i + j) % D] += self_coeff * other_coeff;
82                }
83            }
84        }
85        // Process excess terms with sign adjustment
86        for i in (0..D).rev() {
87            let m = i / D;
88            let r = i % D;
89            let sign = if (m + 1) % 2 == 0 { 1 } else { -1 };
90            if sign == 1 {
91                result[r] += out_of_field[i];
92            } else {
93                result[r] -= out_of_field[i];
94            }
95        }
96        Rq::new(result)
97    }
98
99    /// Dot product between coefficients
100    pub fn inner_product(&self, other: &Self) -> Zq {
101        self.coeffs
102            .iter()
103            .zip(other.coeffs.iter())
104            .map(|(&a, &b)| a * b)
105            .fold(Zq::ZERO, |acc, x| acc + x)
106    }
107
108    /// Scalar multiplication
109    pub fn scalar_mul(&self, s: Zq) -> Self {
110        let mut result = [Zq::ZERO; D];
111        for (i, &coeff) in self.coeffs.iter().enumerate() {
112            result[i] = s * (coeff);
113        }
114        Rq::new(result)
115    }
116
117    /// Evaluate the polynomial at a specific point
118    pub fn eval(&self, x: Zq) -> Zq {
119        let mut result = Zq::ZERO;
120        for coeff in self.coeffs.iter().rev() {
121            result = result * x + *coeff;
122        }
123
124        result
125    }
126
127    /// Check if Polynomial == 0
128    pub fn is_zero(&self) -> bool {
129        self.coeffs.iter().all(|&coeff| coeff == Zq::ZERO)
130    }
131
132    /// Check if two polynomials are equal
133    pub fn is_equal(&self, other: &Self) -> bool {
134        self.coeffs == other.coeffs
135    }
136
137    /// Generate random polynomial with a provided cryptographically secure RNG
138    pub fn random<R: Rng + CryptoRng>(rng: &mut R) -> Self {
139        let uniform = Uniform::new_inclusive(Zq::ZERO, Zq::MAX).unwrap();
140        let mut coeffs = [Zq::ZERO; D];
141        coeffs.iter_mut().for_each(|c| *c = uniform.sample(rng));
142        Self { coeffs }
143    }
144
145    /// Generate random small polynomial with secure RNG implementation
146    pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R) -> Self {
147        let mut coeffs = [Zq::ZERO; D];
148
149        for coeff in coeffs.iter_mut() {
150            // Explicitly sample from {-1, 0, 1} with equal probability
151            let val = match rng.random_range(0..3) {
152                0 => Zq::MAX,  // -1 mod q
153                1 => Zq::ZERO, // 0
154                2 => Zq::ONE,  // 1
155                _ => unreachable!(),
156            };
157            *coeff = val;
158        }
159
160        Rq::new(coeffs)
161    }
162
163    /// Decomposes a polynomial into base-B representation:
164    /// p = p⁽⁰⁾ + p⁽¹⁾·B + p⁽²⁾·B² + ... + p⁽ᵗ⁻¹⁾·B^(t-1)
165    /// Where each p⁽ⁱ⁾ has small coefficients, using centered representatives
166    pub fn decompose(&self, base: Zq, num_parts: usize) -> Vec<Self> {
167        let mut parts = Vec::with_capacity(num_parts);
168        let mut current = self.clone();
169
170        for i in 0..num_parts {
171            if i == num_parts - 1 {
172                parts.push(current.clone());
173            } else {
174                // Extract low part (mod base, centered around 0)
175                let mut low_coeffs = [Zq::ZERO; D];
176
177                for (j, coeff) in current.get_coefficients().iter().enumerate() {
178                    low_coeffs[j] = coeff.centered_mod(base);
179                }
180
181                let low_part = Self::new(low_coeffs);
182                parts.push(low_part.clone());
183
184                // Update current
185                current -= low_part;
186
187                // Scale by base
188                let mut scaled_coeffs = [Zq::ZERO; D];
189                for (j, coeff) in current.get_coefficients().iter().enumerate() {
190                    scaled_coeffs[j] = coeff.scale_by(base);
191                }
192                current = Self::new(scaled_coeffs);
193            }
194        }
195
196        parts
197    }
198
199    /// Encode message into polynomial with small coefficients.
200    ///
201    /// # Arguments
202    /// * `message` - A slice of booleans representing a binary message
203    ///
204    /// # Returns
205    /// * `Some(Rq)` - A polynomial where each coefficient is 0 or 1 based on the message bits
206    /// * `None` - If the message length exceeds the polynomial degree D
207    ///
208    /// # Format
209    /// * Each boolean is encoded as a coefficient: false -> 0, true -> 1
210    /// * Message bits are mapped to coefficients in order (index 0 -> constant term)
211    /// * Remaining coefficients (if message is shorter than D) are set to 0
212    pub fn encode_message(message: &[bool]) -> Option<Self> {
213        if message.len() > D {
214            return None;
215        }
216
217        let mut coeffs = [Zq::ZERO; D];
218        for (i, &bit) in message.iter().enumerate() {
219            coeffs[i] = Zq::new(u32::from(bit));
220        }
221        Some(Rq::new(coeffs))
222    }
223
224    /// Iterator over coefficients
225    pub fn iter(&self) -> std::slice::Iter<'_, Zq> {
226        self.coeffs.iter()
227    }
228
229    /// Check if polynomial coefficients are within bounds
230    pub fn check_bounds(&self, bound: Zq) -> bool {
231        self.iter().all(|coeff| coeff <= &bound || coeff >= &-bound)
232    }
233
234    pub const fn zero() -> Self {
235        Self::new([Zq::ZERO; D])
236    }
237}
238
239macro_rules! impl_arithmetic {
240    ($trait:ident, $assign_trait:ident, $method:ident, $assign_method:ident, $op_method:ident) => {
241        impl<const D: usize> $trait for Rq<{ D }> {
242            type Output = Self;
243
244            fn $method(self, rhs: Self) -> Self::Output {
245                self.$op_method(&rhs)
246            }
247        }
248
249        impl<const D: usize> $assign_trait for Rq<{ D }> {
250            fn $assign_method(&mut self, rhs: Self) {
251                let result = self.$op_method(&rhs);
252                self.coeffs = result.coeffs;
253            }
254        }
255    };
256}
257
258impl_arithmetic!(Add, AddAssign, add, add_assign, addition);
259impl_arithmetic!(Sub, SubAssign, sub, sub_assign, subtraction);
260impl_arithmetic!(Mul, MulAssign, mul, mul_assign, multiplication);
261
262impl<const D: usize> From<Vec<Zq>> for Rq<D> {
263    fn from(vec: Vec<Zq>) -> Self {
264        let mut temp = [Zq::ZERO; D];
265        // Process excess terms with sign adjustment
266        for i in (0..vec.len()).rev() {
267            let m = i / D;
268            let r = i % D;
269            let sign = if m % 2 == 0 { 1 } else { -1 };
270            if sign == 1 {
271                temp[r] += vec[i];
272            } else {
273                temp[r] -= vec[i];
274            }
275        }
276        Rq::new(temp)
277    }
278}
279
280impl Sum for Zq {
281    // Accumulate using the addition operator
282    fn sum<I>(iter: I) -> Self
283    where
284        I: Iterator<Item = Zq>,
285    {
286        iter.fold(Zq::ZERO, |acc, x| acc + x)
287    }
288}
289
290// Implementing the Neg trait
291impl<const D: usize> Neg for Rq<D> {
292    type Output = Self;
293
294    /// Polynomial negation
295    fn neg(self) -> Self {
296        let mut result = [Zq::ZERO; D];
297        for (i, &coeff) in self.coeffs.iter().enumerate() {
298            result[i] = Zq::ZERO - coeff;
299        }
300        Rq::new(result)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // Test new() and polynomial creation
    #[test]
    fn test_new_and_create_poly() {
        let poly = Rq::new([Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)]);
        assert_eq!(poly.coeffs, [Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)]);

        // Direct conversion
        let poly_from_vec_direct: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        assert_eq!(
            poly_from_vec_direct.coeffs,
            [Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)]
        );
        // Wrapping around: X^4 ≡ -1, so the fifth coefficient cancels the first
        let poly_from_vec_wrapping: Rq<4> =
            vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4), Zq::ONE].into();
        assert_eq!(
            poly_from_vec_wrapping.coeffs,
            [Zq::ZERO, Zq::new(2), Zq::new(3), Zq::new(4)]
        );
        // Filling up with zeros
        let poly_from_vec_zeros: Rq<4> = vec![Zq::ONE, Zq::new(2)].into();
        assert_eq!(
            poly_from_vec_zeros.coeffs,
            [Zq::ONE, Zq::new(2), Zq::ZERO, Zq::ZERO]
        );
        // High-Degree Term Reduction: X^6 = (X^2)^3 ≡ (-1)^3 = -1 in Rq<2>
        let poly_high_degree_reduction: Rq<2> = vec![
            Zq::ONE,
            Zq::new(2),
            Zq::ZERO,
            Zq::ZERO,
            Zq::ZERO,
            Zq::ZERO,
            Zq::ONE,
        ]
        .into();
        assert_eq!(poly_high_degree_reduction.coeffs, [Zq::ZERO, Zq::new(2)]);
    }

    // Test addition of polynomials
    #[test]
    fn test_add() {
        // Within bounds
        let poly1: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly2: Rq<4> = vec![Zq::new(4), Zq::new(3), Zq::new(2), Zq::ONE].into();
        let result = poly1 + poly2;
        assert_eq!(
            result.coeffs,
            [Zq::new(5), Zq::new(5), Zq::new(5), Zq::new(5)]
        );

        // Outside of bounds
        let poly3: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly4: Rq<4> = vec![Zq::MAX, Zq::new(3), Zq::MAX, Zq::ONE].into();
        let result2 = poly3 + poly4;
        assert_eq!(
            result2.coeffs,
            [Zq::ZERO, Zq::new(5), Zq::new(2), Zq::new(5)]
        );
        // Addition with zero polynomial
        let poly5: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly6: Rq<4> = vec![Zq::ZERO].into();
        let result3 = poly5 + poly6;
        assert_eq!(
            result3.coeffs,
            [Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)]
        );
        // Addition with high coefficients
        let poly7: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::MAX].into();
        let poly8: Rq<4> = vec![Zq::MAX, Zq::MAX, Zq::MAX, Zq::MAX].into();
        let result4 = poly7 + poly8;
        assert_eq!(
            result4.coeffs,
            [
                Zq::ZERO,
                Zq::ONE,
                Zq::new(2),
                Zq::new(u32::MAX.wrapping_add(u32::MAX))
            ]
        );
    }

    // Test multiplication of polynomials
    #[test]
    fn test_mul() {
        // Multiplication with wrapping
        let poly1: Rq<3> = vec![Zq::ONE, Zq::ONE, Zq::new(2)].into();
        let poly2: Rq<3> = vec![Zq::ONE, Zq::ONE].into();
        let result = poly1 * poly2;
        assert_eq!(result.coeffs, [Zq::MAX, Zq::new(2), Zq::new(3)]);

        // Multiplication with zero polynomial
        let poly3: Rq<3> = vec![Zq::ONE, Zq::ONE, Zq::new(2)].into();
        let poly4: Rq<3> = vec![Zq::ZERO].into();
        let result2 = poly3 * poly4;
        assert_eq!(result2.coeffs, [Zq::ZERO, Zq::ZERO, Zq::ZERO]);

        // Multiplication with wrapping higher order
        let poly5: Rq<3> = vec![Zq::ONE, Zq::ONE, Zq::new(2)].into();
        let poly6: Rq<3> = vec![Zq::ONE, Zq::ONE, Zq::new(7), Zq::new(5)].into();
        let result3 = poly5 * poly6;
        assert_eq!(
            result3.coeffs,
            [Zq::new(u32::MAX - 12), Zq::new(u32::MAX - 16), Zq::ZERO]
        );
    }

    // Test that monomial products wrap negacyclically modulo X^4 + 1
    #[test]
    fn test_mul_negacyclic_wrap() {
        let x: Rq<4> = vec![Zq::ZERO, Zq::ONE].into();
        let x2: Rq<4> = vec![Zq::ZERO, Zq::ZERO, Zq::ONE].into();
        let x3: Rq<4> = vec![Zq::ZERO, Zq::ZERO, Zq::ZERO, Zq::ONE].into();

        // X * X^3 = X^4 ≡ -1
        assert_eq!(
            (x * x3.clone()).coeffs,
            [Zq::MAX, Zq::ZERO, Zq::ZERO, Zq::ZERO]
        );
        // X^2 * X^3 = X^5 ≡ -X
        assert_eq!(
            (x2 * x3).coeffs,
            [Zq::ZERO, Zq::MAX, Zq::ZERO, Zq::ZERO]
        );
    }

    // Test subtraction of polynomials
    #[test]
    fn test_sub() {
        // within bounds
        let poly1: Rq<4> = vec![Zq::new(5), Zq::new(10), Zq::new(15), Zq::new(20)].into();
        let poly2: Rq<4> = vec![Zq::new(2), Zq::new(4), Zq::new(6), Zq::new(8)].into();
        let result = poly1 - poly2;
        assert_eq!(
            result.coeffs,
            [Zq::new(3), Zq::new(6), Zq::new(9), Zq::new(12)]
        );

        // Outside of bounds
        let poly3: Rq<4> = vec![Zq::ONE, Zq::ONE, Zq::new(3), Zq::new(2)].into();
        let poly4: Rq<4> = vec![Zq::new(2), Zq::new(4), Zq::new(6), Zq::new(8)].into();
        let result2 = poly3 - poly4;
        assert_eq!(
            result2.coeffs,
            [
                Zq::MAX,
                Zq::new(u32::MAX - 2),
                Zq::new(u32::MAX - 2),
                Zq::new(u32::MAX - 5)
            ]
        );
        // Subtraction with zero polynomial
        let poly5: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly6: Rq<4> = vec![Zq::ZERO].into();
        let result3 = poly6.clone() - poly5.clone();
        let result4 = poly5.clone() - poly6.clone();
        assert_eq!(
            result3.coeffs,
            [
                Zq::MAX,
                Zq::new(u32::MAX - 1),
                Zq::new(u32::MAX - 2),
                Zq::new(u32::MAX - 3)
            ]
        );
        assert_eq!(
            result4.coeffs,
            [Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)]
        );
    }

    // Test negation of polynomial
    #[test]
    fn test_neg() {
        let poly: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let result = -poly;
        assert_eq!(
            result.coeffs,
            [
                Zq::MAX,
                Zq::new(u32::MAX - 1),
                Zq::new(u32::MAX - 2),
                Zq::new(u32::MAX - 3)
            ]
        );
    }

    // Test scalar multiplication
    #[test]
    fn test_scalar_mul() {
        let poly: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let result = poly.scalar_mul(Zq::new(2));
        assert_eq!(
            result.coeffs,
            [Zq::new(2), Zq::new(4), Zq::new(6), Zq::new(8)]
        );
    }

    // Test polynomial evaluation
    #[test]
    fn test_eval() {
        // 1 + 2·2 + 3·2² + 4·2³ = 1 + 4 + 12 + 32 = 49
        let poly: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let result = poly.eval(Zq::new(2));
        assert_eq!(result, Zq::new(49));
    }

    // Test equality check
    #[test]
    fn test_is_equal() {
        let poly1: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly2: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let poly3: Rq<4> = vec![Zq::new(4), Zq::new(3), Zq::new(2), Zq::ONE].into();
        assert!(poly1.is_equal(&poly2));
        assert!(!poly1.is_equal(&poly3));
    }

    // Test zero polynomial check
    #[test]
    fn test_is_zero_poly() {
        let zero_poly: Rq<4> = vec![Zq::ZERO; 4].into();
        let non_zero_poly: Rq<4> = vec![Zq::ONE, Zq::ZERO, Zq::ZERO, Zq::ZERO].into();
        assert!(zero_poly.is_zero());
        assert!(!non_zero_poly.is_zero());
    }

    // Test coefficient bound checks with non-zero bounds
    #[test]
    fn test_check_bounds() {
        // Centered coefficients 1, -1, 0, 2
        let poly: Rq<4> = vec![Zq::ONE, Zq::MAX, Zq::ZERO, Zq::new(2)].into();
        assert!(poly.check_bounds(Zq::new(2)));
        assert!(!poly.check_bounds(Zq::ONE));

        // -2 is within bound 2 but not within bound 1
        let negative: Rq<4> = vec![-Zq::new(2), Zq::ZERO, Zq::ZERO, Zq::ZERO].into();
        assert!(negative.check_bounds(Zq::new(2)));
        assert!(!negative.check_bounds(Zq::ONE));
    }

    #[test]
    fn test_encode_message() {
        // Test successful encoding
        let message = vec![true, false, true, false];
        let encoded = Rq::<4>::encode_message(&message).unwrap();
        assert_eq!(encoded.coeffs, [Zq::ONE, Zq::ZERO, Zq::ONE, Zq::ZERO]);

        // Test message shorter than degree
        let short_message = vec![true, false];
        let encoded_short = Rq::<4>::encode_message(&short_message).unwrap();
        assert_eq!(
            encoded_short.coeffs,
            [Zq::ONE, Zq::ZERO, Zq::ZERO, Zq::ZERO]
        );

        // Test message too long
        let long_message = vec![true; 5];
        assert!(Rq::<4>::encode_message(&long_message).is_none());

        // Test empty message
        let empty_message: Vec<bool> = vec![];
        let encoded_empty = Rq::<4>::encode_message(&empty_message).unwrap();
        assert!(encoded_empty.is_zero());
    }

    // Test coefficient extraction
    #[test]
    fn test_get_coefficient() {
        let poly: Rq<4> = vec![Zq::ONE, Zq::ZERO, Zq::new(5), Zq::MAX].into();
        let vec = vec![Zq::ONE, Zq::ZERO, Zq::new(5), Zq::MAX];
        assert_eq!(poly.get_coefficients().to_vec(), vec);

        let poly_zero: Rq<4> = vec![Zq::ZERO, Zq::ZERO, Zq::ZERO, Zq::ZERO].into();
        let vec_zero = vec![Zq::ZERO, Zq::ZERO, Zq::ZERO, Zq::ZERO];
        assert_eq!(poly_zero.get_coefficients().to_vec(), vec_zero);
    }

    #[test]
    fn test_base2_decomposition() {
        // Test case 1: Base 2 decomposition
        let poly: Rq<4> = vec![Zq::new(5), Zq::new(3), Zq::new(7), Zq::new(1)].into();
        let parts = poly.decompose(Zq::TWO, 2);

        // Part 0: remainders mod 2 (no centering needed for base 2)
        assert_eq!(
            parts[0].coeffs,
            [
                Zq::ONE, // 5 mod 2 = 1
                Zq::ONE, // 3 mod 2 = 1
                Zq::ONE, // 7 mod 2 = 1
                Zq::ONE, // 1 mod 2 = 1
            ]
        );

        // Part 1: quotients after division by 2
        assert_eq!(
            parts[1].coeffs,
            [
                Zq::new(2), // 5 div 2 = 2
                Zq::ONE,    // 3 div 2 = 1
                Zq::new(3), // 7 div 2 = 3
                Zq::ZERO,   // 1 div 2 = 0
            ]
        );

        // Verify Base 2 reconstruction coefficient by coefficient
        for i in 0..4 {
            let expected = poly.coeffs[i];
            let actual = parts[0].coeffs[i] + parts[1].coeffs[i] * Zq::TWO;
            assert_eq!(actual, expected, "Base 2: Coefficient {} mismatch", i);
        }
    }

    #[test]
    fn test_base3_decomposition() {
        // Test case: Base 3 decomposition with centering
        let specific_poly: Rq<4> = vec![Zq::new(8), Zq::new(11), Zq::new(4), Zq::new(15)].into();
        let parts = specific_poly.decompose(Zq::new(3), 2);

        // Part 0: centered remainders mod 3
        assert_eq!(
            parts[0].coeffs,
            [
                Zq::MAX,  // 8 mod 3 = 2 -> -1 (centered)
                Zq::MAX,  // 11 mod 3 = 2 -> -1 (centered)
                Zq::ONE,  // 4 mod 3 = 1 -> 1 (centered)
                Zq::ZERO, // 15 mod 3 = 0 -> 0 (centered)
            ]
        );

        // Part 1: quotients
        assert_eq!(
            parts[1].coeffs,
            [
                Zq::new(3), // (8 + 1) div 3 = 3
                Zq::new(4), // (11 + 1) div 3 = 4
                Zq::ONE,    // 4 div 3 = 1
                Zq::new(5), // 15 div 3 = 5
            ]
        );

        // Verify Base 3 reconstruction coefficient by coefficient
        for i in 0..4 {
            let expected = specific_poly.coeffs[i];
            let p0 = parts[0].coeffs[i];
            let p1 = parts[1].coeffs[i];
            let actual = p0 + p1 * Zq::new(3);
            assert_eq!(actual, expected, "Base 3: Coefficient {} mismatch", i);
        }
    }

    #[test]
    fn test_decomposition_edge_cases() {
        // Test zero polynomial
        let zero_poly: Rq<4> = vec![Zq::ZERO; 4].into();
        let parts = zero_poly.decompose(Zq::TWO, 2);
        assert!(
            parts.iter().all(|p| p.is_zero()),
            "Zero polynomial decomposition failed"
        );

        // Test single part decomposition
        let simple_poly: Rq<4> = vec![Zq::ONE, Zq::new(2), Zq::new(3), Zq::new(4)].into();
        let parts = simple_poly.decompose(Zq::TWO, 1);
        assert_eq!(parts.len(), 1, "Single part decomposition length incorrect");
        assert_eq!(
            parts[0], simple_poly,
            "Single part decomposition value incorrect"
        );
    }

    #[test]
    fn test_large_base_decomposition() {
        // Test decomposition with larger bases (8 and 16)
        let poly: Rq<4> = vec![Zq::new(120), Zq::new(33), Zq::new(255), Zq::new(19)].into();

        // Base 8 decomposition
        let parts_base8 = poly.decompose(Zq::new(8), 2);

        // Part 0: centered remainders mod 8
        assert_eq!(
            parts_base8[0].coeffs,
            [
                Zq::ZERO,   // 120 mod 8 = 0 -> 0 (centered)
                Zq::ONE,    // 33 mod 8 = 1 -> 1 (centered)
                Zq::MAX,    // 255 mod 8 = 7 -> -1 (centered)
                Zq::new(3), // 19 mod 8 = 3 -> 3 (centered)
            ]
        );

        // Part 1: quotients
        assert_eq!(
            parts_base8[1].coeffs,
            [
                Zq::new(15), // 120 div 8 = 15
                Zq::new(4),  // 33 div 8 = 4
                Zq::new(32), // (255 + 1) div 8 = 32
                Zq::new(2),  // 19 div 8 = 2
            ]
        );

        // Verify reconstruction coefficient by coefficient
        for i in 0..4 {
            let expected = poly.coeffs[i];
            let p0 = parts_base8[0].coeffs[i];
            let p1 = parts_base8[1].coeffs[i];
            let actual = p0 + p1 * Zq::new(8);
            assert_eq!(actual, expected, "Base 8: Coefficient {} mismatch", i);
        }

        // Base 16 decomposition
        let parts_base16 = poly.decompose(Zq::new(16), 2);

        // Verify reconstruction for base 16
        for i in 0..4 {
            let expected = poly.coeffs[i];
            let p0 = parts_base16[0].coeffs[i];
            let p1 = parts_base16[1].coeffs[i];
            let actual = p0 + p1 * Zq::new(16);
            assert_eq!(actual, expected, "Base 16: Coefficient {} mismatch", i);
        }
    }

    #[test]
    fn test_multi_part_decomposition() {
        // Test with more than 2 parts
        let poly: Rq<4> = vec![Zq::new(123), Zq::new(456), Zq::new(789), Zq::new(101112)].into();

        // Decompose into 3 parts with base 4
        let parts = poly.decompose(Zq::new(4), 3);
        assert_eq!(parts.len(), 3, "Should have 3 parts");

        // Test reconstruction with all 3 parts
        let reconstructed = parts[0].clone()
            + parts[1].clone().scalar_mul(Zq::new(4))
            + parts[2].clone().scalar_mul(Zq::new(16)); // 4²

        // Verify reconstruction coefficient by coefficient
        for i in 0..4 {
            assert_eq!(
                reconstructed.coeffs[i], poly.coeffs[i],
                "3-part base 4: Coefficient {} mismatch",
                i
            );
        }
    }

    #[test]
    fn test_centering_properties() {
        // Test that centering works correctly for various values
        // Using base 5 which has half_base = 2
        let values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let poly: Rq<11> = values
            .iter()
            .map(|&v| Zq::new(v))
            .collect::<Vec<Zq>>()
            .into();

        let parts = poly.decompose(Zq::new(5), 2);

        // Expected centered values for each coefficient:
        // 0 mod 5 = 0 -> 0
        // 1 mod 5 = 1 -> 1
        // 2 mod 5 = 2 -> 2 (at threshold)
        // 3 mod 5 = 3 -> -2 (centered)
        // 4 mod 5 = 4 -> -1 (centered)
        // 5 mod 5 = 0 -> 0
        // ... and so on
        let expected_centered = [
            Zq::ZERO,    // 0 centered
            Zq::ONE,     // 1 centered
            Zq::new(2),  // 2 centered (at threshold)
            -Zq::new(2), // 3 centered to -2
            -Zq::ONE,    // 4 centered to -1
            Zq::ZERO,    // 5 centered
            Zq::ONE,     // 6 centered
            Zq::new(2),  // 7 centered
            -Zq::new(2), // 8 centered to -2
            -Zq::ONE,    // 9 centered to -1
            Zq::ZERO,    // 10 centered
        ];

        for (i, &expected) in expected_centered.iter().enumerate() {
            assert_eq!(
                parts[0].coeffs[i], expected,
                "Base 5 centering: Coefficient {} incorrectly centered",
                i
            );
        }
    }

    #[test]
    fn test_extreme_values() {
        // Test with values near the extremes of the Zq range
        let poly: Rq<3> = vec![Zq::ZERO, Zq::MAX, Zq::MAX - Zq::ONE].into();

        // Decompose with base 3
        let parts = poly.decompose(Zq::new(3), 2);

        // Verify reconstruction
        let reconstructed = parts[0].clone() + parts[1].clone().scalar_mul(Zq::new(3));

        for i in 0..3 {
            assert_eq!(
                reconstructed.coeffs[i], poly.coeffs[i],
                "Extreme values: Coefficient {} mismatch",
                i
            );
        }

        // u32::MAX = 4294967295 = 1431655765 * 3 + 0,
        // so u32::MAX mod 3 = 0, which stays 0 (no centering needed).
        assert_eq!(parts[0].coeffs[1], Zq::ZERO); // Remainder after division by 3
        assert_eq!(parts[1].coeffs[1], Zq::new(1431655765)); // Quotient

        // u32::MAX - 1 = 4294967294, and 4294967294 mod 3 = 2, which is
        // centered to -1 (stored as Zq::MAX). Subtracting -1 leaves
        // 4294967295, whose quotient by 3 is again 1431655765.
        assert_eq!(parts[0].coeffs[2], Zq::MAX); // u32::MAX - 1 is the third coefficient
        assert_eq!(parts[1].coeffs[2], Zq::new(1431655765)); // Should be same quotient
    }

    #[test]
    fn test_decomposition_properties() {
        // Test the algebraic property that all coefficients in first part should be small
        let poly: Rq<8> = vec![
            Zq::new(100),
            Zq::new(200),
            Zq::new(300),
            Zq::new(400),
            Zq::new(500),
            Zq::new(600),
            Zq::new(700),
            Zq::new(800),
        ]
        .into();

        for base in [2, 3, 4, 5, 8, 10, 16].iter() {
            let parts = poly.decompose(Zq::new(*base), 2);
            let half_base = Zq::new(*base).scale_by(Zq::TWO);

            // Check that all coefficients in first part are properly "small"
            for coeff in parts[0].coeffs.iter() {
                // In centered representation, all coefficients should be <= half_base
                let abs_coeff = if *coeff > Zq::new(u32::MAX / 2) {
                    Zq::ZERO - *coeff // Handle negative values (represented as large positive ones)
                } else {
                    *coeff
                };

                assert!(
                    abs_coeff <= half_base,
                    "Base {}: First part coefficient {} exceeds half-base {}",
                    base,
                    coeff,
                    half_base
                );
            }

            // Verify reconstruction
            let reconstructed = parts[0].clone() + parts[1].clone().scalar_mul(Zq::new(*base));
            assert_eq!(reconstructed, poly, "Base {}: Reconstruction failed", base);
        }
    }
}