1use std::ops::{Add, Mul, Sub};
2
3use crate::ring::rq::Rq;
4use crate::ring::rq_vector::RqVector;
5use crate::ring::zq::Zq;
6use rand::distr::{Distribution, Uniform};
7use rand::{CryptoRng, Rng};
8use rustfft::{num_complex::Complex, FftPlanner};
9
10#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Default)]
12pub struct PolyRing {
13 coeffs: Vec<Zq>,
14}
15impl PolyRing {
16 pub const DEGREE_BOUND: usize = 64;
18
19 pub fn new(coeffs: Vec<Zq>) -> Self {
20 assert!(
21 coeffs.len() <= Self::DEGREE_BOUND,
22 "Polynomial degree should be less than {}",
23 Self::DEGREE_BOUND
24 );
25 Self { coeffs }
26 }
27 pub fn zero(degree: usize) -> Self {
28 Self::new(vec![Zq::ZERO; degree])
29 }
30
31 pub fn zero_poly() -> Self {
32 Self::new(vec![Zq::ZERO; 1])
33 }
34
35 pub fn len(&self) -> usize {
36 self.coeffs.len()
37 }
38
39 pub fn is_empty(&self) -> bool {
40 self.coeffs.is_empty()
41 }
42
43 pub fn get_coeffs(&self) -> &Vec<Zq> {
44 &self.coeffs
45 }
46
47 pub fn iter(&self) -> impl Iterator<Item = &Zq> {
48 self.coeffs.iter()
49 }
50
51 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Zq> {
52 self.coeffs.iter_mut()
53 }
54
55 pub fn inner_product(&self, other: &Self) -> Zq {
57 self.coeffs
58 .iter()
59 .zip(other.coeffs.iter())
60 .map(|(a, b)| *a * *b)
61 .sum()
62 }
63
64 pub fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self {
66 let uniform = Uniform::new_inclusive(Zq::ZERO, Zq::MAX).unwrap();
67 let mut coeffs = Vec::with_capacity(n);
68 coeffs.extend((0..n).map(|_| uniform.sample(rng)));
69 Self { coeffs }
70 }
71
72 pub fn random_ternary<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self {
74 let mut coeffs = vec![Zq::ZERO; n];
75
76 for coeff in coeffs.iter_mut() {
77 let val = match rng.random_range(0..3) {
79 0 => Zq::MAX, 1 => Zq::ZERO, 2 => Zq::ONE, _ => unreachable!(),
83 };
84 *coeff = val;
85 }
86
87 Self::new(coeffs)
88 }
89
90 pub fn conjugate_automorphism(&self) -> PolyRing {
92 let q_minus_1 = Zq::MAX;
93 let mut new_coeffs = vec![Zq::ZERO; PolyRing::DEGREE_BOUND];
94 for (i, new_coeff) in new_coeffs
95 .iter_mut()
96 .enumerate()
97 .take(PolyRing::DEGREE_BOUND)
98 {
99 if i < self.get_coeffs().len() {
100 if i == 0 {
101 *new_coeff = self.get_coeffs()[i];
102 } else {
103 *new_coeff = self.get_coeffs()[i] * q_minus_1;
104 }
105 } else {
106 *new_coeff = Zq::ZERO;
107 }
108 }
109 let reversed_coefficients = new_coeffs
110 .iter()
111 .take(1)
112 .cloned()
113 .chain(new_coeffs.iter().skip(1).rev().cloned())
114 .collect::<Vec<Zq>>();
115
116 PolyRing::new(reversed_coefficients)
117 }
118
119 #[allow(clippy::as_conversions)]
127 pub fn operator_norm(&self) -> f64 {
128 let coeffs = self.get_coeffs();
129 let n = coeffs.len();
130 let mut planner = FftPlanner::new();
131 let fft = planner.plan_fft_forward(n);
132
133 let mut buffer: Vec<Complex<f64>> = coeffs
135 .iter()
136 .map(|&x| {
137 let half = Zq::MAX.scale_by(Zq::TWO);
138 let converted_value = if x > half {
139 x.to_u128() as f64 - Zq::MAX.to_u128() as f64 - 1.0
140 } else {
141 x.to_u128() as f64
142 };
143 Complex {
144 re: converted_value,
145 im: 0.0,
146 }
147 })
148 .collect();
149
150 fft.process(&mut buffer);
152
153 buffer
155 .iter()
156 .map(|c| c.norm())
157 .fold(0.0, |max, x| max.max(x))
158 }
159
160 pub fn decompose(&self, base: Zq, num_parts: usize) -> PolyVector {
164 let mut parts = Vec::with_capacity(num_parts);
165 let mut current = self.clone();
166
167 for i in 0..num_parts {
168 if i == num_parts - 1 {
169 parts.push(current.clone());
170 } else {
171 let mut low_coeffs = vec![Zq::ZERO; self.len()];
173
174 for (j, coeff) in current.get_coeffs().iter().enumerate() {
175 low_coeffs[j] = coeff.centered_mod(base);
176 }
177
178 let low_part = Self::new(low_coeffs);
179 parts.push(low_part.clone());
180
181 current = ¤t - &low_part;
183
184 let mut scaled_coeffs = vec![Zq::ZERO; self.len()];
186 for (j, coeff) in current.get_coeffs().iter().enumerate() {
187 scaled_coeffs[j] = coeff.scale_by(base);
188 }
189 current = Self::new(scaled_coeffs);
190 }
191 }
192
193 PolyVector::new(parts)
194 }
195}
196
197impl<const D: usize> From<PolyRing> for Rq<D> {
198 fn from(zqs: PolyRing) -> Self {
199 zqs.get_coeffs().clone().into()
200 }
201}
202
203impl FromIterator<Zq> for PolyRing {
204 fn from_iter<T: IntoIterator<Item = Zq>>(iter: T) -> Self {
205 let coeffs: Vec<Zq> = iter.into_iter().collect();
206 PolyRing::new(coeffs)
207 }
208}
209
210impl Add<&PolyRing> for &PolyRing {
211 type Output = PolyRing;
212 fn add(self, other: &PolyRing) -> PolyRing {
214 let max_degree = self.get_coeffs().len().max(other.get_coeffs().len());
215 let mut coeffs = vec![Zq::ZERO; max_degree];
216 for (i, coeff) in coeffs.iter_mut().enumerate().take(max_degree) {
217 if i < self.get_coeffs().len() {
218 *coeff += self.get_coeffs()[i];
219 }
220 if i < other.get_coeffs().len() {
221 *coeff += other.get_coeffs()[i];
222 }
223 }
224 PolyRing::new(coeffs)
225 }
226}
227
228impl Sub<&PolyRing> for &PolyRing {
229 type Output = PolyRing;
230 fn sub(self, other: &PolyRing) -> PolyRing {
232 let max_degree = self.get_coeffs().len().max(other.get_coeffs().len());
233 let mut coeffs = vec![Zq::ZERO; max_degree];
234 for (i, coeff) in coeffs.iter_mut().enumerate().take(max_degree) {
235 if i < self.get_coeffs().len() {
236 *coeff += self.get_coeffs()[i];
237 }
238 if i < other.get_coeffs().len() {
239 *coeff -= other.get_coeffs()[i];
240 }
241 }
242 PolyRing::new(coeffs)
243 }
244}
245
246impl Mul<&PolyRing> for &PolyRing {
247 type Output = PolyRing;
248 fn mul(self, other: &PolyRing) -> PolyRing {
250 let mut result_coefficients =
252 vec![Zq::new(0); self.get_coeffs().len() + other.get_coeffs().len() - 1];
253 for (i, &coeff1) in self.get_coeffs().iter().enumerate() {
254 for (j, &coeff2) in other.get_coeffs().iter().enumerate() {
255 result_coefficients[i + j] += coeff1 * coeff2;
256 }
257 }
258
259 if result_coefficients.len() > PolyRing::DEGREE_BOUND {
261 let q_minus_1 = Zq::MAX;
262 let (left, right) = result_coefficients.split_at_mut(PolyRing::DEGREE_BOUND);
263 for (i, &overflow) in right.iter().enumerate() {
264 left[i] += overflow * q_minus_1;
265 }
266 result_coefficients.truncate(PolyRing::DEGREE_BOUND);
267 }
268 PolyRing::new(result_coefficients)
269 }
270}
271
272impl Mul<&Zq> for &PolyRing {
273 type Output = PolyRing;
274 fn mul(self, other: &Zq) -> PolyRing {
276 PolyRing::new(self.coeffs.iter().map(|c| c * *other).collect())
277 }
278}
279
280#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Default)]
281pub struct PolyVector {
282 elements: Vec<PolyRing>,
283}
284
285impl PolyVector {
286 pub fn new(elements: Vec<PolyRing>) -> Self {
287 Self { elements }
288 }
289
290 pub fn zero() -> Self {
291 Self {
292 elements: vec![PolyRing::zero(0)],
293 }
294 }
295
296 pub fn get_elements(&self) -> &Vec<PolyRing> {
297 &self.elements
298 }
299
300 pub fn len(&self) -> usize {
301 self.elements.len()
302 }
303
304 pub fn is_empty(&self) -> bool {
305 self.elements.is_empty()
306 }
307
308 pub fn iter(&self) -> impl Iterator<Item = &PolyRing> {
309 self.elements.iter()
310 }
311
312 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut PolyRing> {
313 self.elements.iter_mut()
314 }
315
316 pub fn random(n: usize, m: usize) -> Self {
318 let mut vector = PolyVector::new(vec![]);
319 vector.elements = (0..n)
320 .map(|_| PolyRing::random(&mut rand::rng(), m))
321 .collect();
322 vector
323 }
324
325 pub fn random_ternary(n: usize, m: usize) -> Self {
327 let mut vector = PolyVector::new(vec![]);
328 vector.elements = (0..n)
329 .map(|_| PolyRing::random_ternary(&mut rand::rng(), m))
330 .collect();
331 vector
332 }
333
334 pub fn inner_product_poly_vector(&self, other: &PolyVector) -> PolyRing {
335 self.iter().zip(other.iter()).map(|(a, b)| a * b).fold(
336 PolyRing::zero(self.get_elements()[0].get_coeffs().len()),
337 |acc, val| &acc + &val,
338 )
339 }
340
341 pub fn compute_norm_squared(&self) -> Zq {
343 self.elements
344 .iter()
345 .flat_map(|poly| poly.get_coeffs()) .map(|coeff| *coeff * *coeff)
347 .sum()
348 }
349
350 pub fn concatenate_coefficients(&self, s: usize) -> ZqVector {
352 let total_coeffs = self.get_elements().len() * s;
353 let mut concatenated_coeffs: Vec<Zq> = Vec::with_capacity(total_coeffs);
354 for rq in self.get_elements() {
356 let coeffs = rq.get_coeffs();
357 concatenated_coeffs.extend_from_slice(coeffs);
358 }
359
360 ZqVector {
361 coeffs: concatenated_coeffs,
362 }
363 }
364
365 pub fn decompose(&self, b: Zq, parts: usize) -> Vec<PolyVector> {
366 self.iter()
367 .map(|i| PolyRing::decompose(i, b, parts))
368 .collect()
369 }
370}
371
372impl<const N: usize, const D: usize> From<PolyVector> for RqVector<N, D> {
373 fn from(polys: PolyVector) -> Self {
374 let mut rq_vector = RqVector::zero();
375 for (i, poly) in polys.elements.iter().enumerate() {
376 rq_vector[i] = poly.get_coeffs().clone().into();
377 }
378 rq_vector
379 }
380}
381
382impl FromIterator<PolyRing> for PolyVector {
383 fn from_iter<T: IntoIterator<Item = PolyRing>>(iter: T) -> Self {
384 let mut elements = Vec::new();
385 for item in iter {
386 elements.push(item);
387 }
388 PolyVector::new(elements)
389 }
390}
391
392impl Add<&PolyVector> for &PolyVector {
393 type Output = PolyVector;
394 fn add(self, other: &PolyVector) -> PolyVector {
396 self.iter().zip(other.iter()).map(|(a, b)| a + b).collect()
397 }
398}
399
400impl Mul<&Zq> for &PolyVector {
401 type Output = PolyVector;
402 fn mul(self, other: &Zq) -> PolyVector {
404 self.iter().map(|a| a * other).collect()
405 }
406}
407
408impl Mul<&PolyRing> for &PolyVector {
409 type Output = PolyVector;
410 fn mul(self, other: &PolyRing) -> PolyVector {
412 self.iter().map(|s| s * other).collect()
413 }
414}
415
416impl Mul<&Vec<PolyVector>> for &PolyVector {
417 type Output = PolyVector;
418 fn mul(self, other: &Vec<PolyVector>) -> PolyVector {
420 other
421 .iter()
422 .map(|o| o.inner_product_poly_vector(self))
423 .collect()
424 }
425}
426
427#[derive(Debug, Clone, PartialEq, Eq)]
430pub struct ZqVector {
431 coeffs: Vec<Zq>,
432}
433impl ZqVector {
434 pub fn new(coeffs: Vec<Zq>) -> Self {
435 Self { coeffs }
436 }
437
438 pub fn zero(len: usize) -> Self {
439 Self::new(vec![Zq::ZERO; len])
440 }
441
442 pub fn get_coeffs(&self) -> &Vec<Zq> {
443 &self.coeffs
444 }
445
446 pub fn len(&self) -> usize {
447 self.coeffs.len()
448 }
449
450 pub fn is_empty(&self) -> bool {
451 self.coeffs.is_empty()
452 }
453
454 pub fn random<R: Rng + CryptoRng>(rng: &mut R, n: usize) -> Self {
455 let uniform = Uniform::new_inclusive(Zq::ZERO, Zq::MAX).unwrap();
456 let mut coeffs = Vec::with_capacity(n);
457 coeffs.extend((0..n).map(|_| uniform.sample(rng)));
458 Self { coeffs }
459 }
460
461 pub fn iter(&self) -> impl Iterator<Item = &Zq> {
462 self.coeffs.iter()
463 }
464
465 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Zq> {
466 self.coeffs.iter_mut()
467 }
468
469 pub fn inner_product(&self, other: &Self) -> Zq {
471 self.coeffs
472 .iter()
473 .zip(other.coeffs.iter())
474 .map(|(&a, &b)| a * b)
475 .fold(Zq::ZERO, |acc, x| acc + x)
476 }
477
478 pub fn conjugate_automorphism(&self) -> ZqVector {
480 let q_minus_1 = Zq::MAX;
481 let mut new_coeffs = vec![Zq::ZERO; self.get_coeffs().len()];
482 for (i, new_coeff) in new_coeffs
483 .iter_mut()
484 .enumerate()
485 .take(self.get_coeffs().len())
486 {
487 if i < self.get_coeffs().len() {
488 if i == 0 {
489 *new_coeff = self.get_coeffs()[i];
490 } else {
491 *new_coeff = self.get_coeffs()[i] * q_minus_1;
492 }
493 } else {
494 *new_coeff = Zq::ZERO;
495 }
496 }
497 let reversed_coefficients = new_coeffs
498 .iter()
499 .take(1)
500 .cloned()
501 .chain(new_coeffs.iter().skip(1).rev().cloned())
502 .collect::<Vec<Zq>>();
503
504 ZqVector::new(reversed_coefficients)
505 }
506}
507
508impl<const D: usize> From<ZqVector> for Rq<D> {
509 fn from(zqs: ZqVector) -> Self {
510 zqs.get_coeffs().clone().into()
511 }
512}
513
514impl FromIterator<Zq> for ZqVector {
515 fn from_iter<T: IntoIterator<Item = Zq>>(iter: T) -> Self {
516 let coeffs: Vec<Zq> = iter.into_iter().collect();
517 ZqVector::new(coeffs)
518 }
519}
520
521impl Add<&ZqVector> for &ZqVector {
522 type Output = ZqVector;
523 fn add(self, other: &ZqVector) -> ZqVector {
525 let max_degree = self.get_coeffs().len().max(other.get_coeffs().len());
526 let mut coeffs = vec![Zq::ZERO; max_degree];
527 for (i, coeff) in coeffs.iter_mut().enumerate().take(max_degree) {
528 if i < self.get_coeffs().len() {
529 *coeff += self.get_coeffs()[i];
530 }
531 if i < other.get_coeffs().len() {
532 *coeff += other.get_coeffs()[i];
533 }
534 }
535 ZqVector::new(coeffs)
536 }
537}
538
539impl Mul<&ZqVector> for &ZqVector {
544 type Output = ZqVector;
545 fn mul(self, other: &ZqVector) -> ZqVector {
546 let mut result_coefficients =
547 vec![Zq::new(0); self.get_coeffs().len() + other.get_coeffs().len() - 1];
548 for (i, &coeff1) in self.get_coeffs().iter().enumerate() {
549 for (j, &coeff2) in other.get_coeffs().iter().enumerate() {
550 result_coefficients[i + j] += coeff1 * coeff2;
551 }
552 }
553
554 if result_coefficients.len() > self.get_coeffs().len() {
555 let q_minus_1 = Zq::MAX;
556 let (left, right) = result_coefficients.split_at_mut(self.get_coeffs().len());
557 for (i, &overflow) in right.iter().enumerate() {
558 left[i] += overflow * q_minus_1;
559 }
560 result_coefficients.truncate(self.get_coeffs().len());
561 }
562 ZqVector::new(result_coefficients)
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// The constant term of sigma_{-1}(a) * b equals the coefficient inner
    /// product <a, b>.
    #[test]
    fn test_conjugate_automorphism() {
        let a: PolyRing = PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)]);
        let b: PolyRing = PolyRing::new(vec![Zq::new(4), Zq::new(5), Zq::new(6)]);
        let expected_inner = a.inner_product(&b);

        let product = &a.conjugate_automorphism() * &b;
        assert_eq!(product.len(), PolyRing::DEGREE_BOUND);

        let coeffs = product.get_coeffs();
        assert_eq!(coeffs[0], Zq::from(32));
        assert_eq!(coeffs[1], Zq::from(17));
        assert_eq!(coeffs[2], Zq::new(6));
        assert_eq!(coeffs[0], expected_inner);
    }

    #[test]
    fn test_polyring_to_rq() {
        let poly = PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)]);
        let converted: Rq<3> = poly.into();
        assert_eq!(converted, Rq::new([Zq::ONE, Zq::TWO, Zq::new(3)]));
    }

    #[test]
    fn test_zqvector_to_rq() {
        let values = ZqVector::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)]);
        let converted: Rq<3> = values.into();
        assert_eq!(converted, Rq::new([Zq::ONE, Zq::TWO, Zq::new(3)]));
    }

    #[test]
    fn test_polyvector_to_rqvector() {
        let polys = PolyVector::new(vec![PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)])]);
        let converted: RqVector<1, 3> = polys.into();
        let expected: RqVector<1, 3> =
            RqVector::from(vec![Rq::new([Zq::ONE, Zq::TWO, Zq::new(3)])]);
        assert_eq!(converted, expected);
    }

    #[test]
    fn test_poly_vector_scalar_mul() {
        let polys = PolyVector::new(vec![PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)])]);
        let factor = Zq::new(2);
        let scaled = &polys * &factor;
        let expected =
            PolyVector::new(vec![PolyRing::new(vec![Zq::TWO, Zq::new(4), Zq::new(6)])]);
        assert_eq!(scaled, expected)
    }

    /// Mismatched polynomial lengths are zero-padded when summed.
    #[test]
    fn test_add_poly_vector() {
        let lhs = PolyVector::new(vec![PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)])]);
        let rhs = PolyVector::new(vec![PolyRing::new(vec![Zq::new(4), Zq::new(5)])]);
        let sum = &lhs + &rhs;
        let expected = PolyVector::new(vec![PolyRing::new(vec![
            Zq::new(5),
            Zq::new(7),
            Zq::new(3),
        ])]);
        assert_eq!(sum, expected)
    }

    /// Subtraction wraps modulo q: 1 - 4 = -3 = MAX - 2.
    #[test]
    fn test_sub_polyring() {
        let minuend = PolyRing::new(vec![Zq::ONE, Zq::TWO, Zq::new(3)]);
        let subtrahend = PolyRing::new(vec![Zq::new(4), Zq::new(5)]);
        let difference = &minuend - &subtrahend;
        let expected = PolyRing::new(vec![Zq::MAX - Zq::TWO, Zq::MAX - Zq::TWO, Zq::new(3)]);
        assert_eq!(difference, expected)
    }
}