1pub mod collect_scalars;
4pub mod constraint_counter;
5pub mod degree_counter;
6pub mod do_nothing_builder;
7pub mod dummy_semiring;
8pub mod ideal;
9pub mod ideal_collector;
10pub mod lookup_types;
11
12use crypto_primitives::Semiring;
13use std::borrow::Cow;
14use zinc_poly::{
15 mle::DenseMultilinearExtension,
16 univariate::{binary::BinaryPoly, dense::DensePolynomial},
17};
18use zinc_utils::{UNCHECKED, add, from_ref::FromRef, mul_by_scalar::MulByScalar, sub};
19
20use crate::ideal::{Ideal, IdealCheck};
21
22pub use lookup_types::{LookupColumnSpec, LookupTableType};
23
/// Receiver for the constraints a [`Uair`] emits from its `constrain` /
/// `constrain_general` methods.
///
/// What a builder does with each assertion is up to the implementation.
pub trait ConstraintBuilder {
    /// Expression type in which constraints are written.
    type Expr: Semiring;

    /// Ideal type accepted by [`ConstraintBuilder::assert_in_ideal`]; it must
    /// implement `IdealCheck<Self::Expr>`.
    type Ideal: Ideal + IdealCheck<Self::Expr>;

    /// Asserts that `expr` belongs to `ideal`.
    fn assert_in_ideal(&mut self, expr: Self::Expr, ideal: &Self::Ideal);

    /// Asserts that `expr` is zero.
    fn assert_zero(&mut self, expr: Self::Expr);
}
42
/// A virtual column obtained by shifting an existing trace column.
///
/// `source_col` is a flat column index over `binary_poly || arbitrary_poly || int`.
/// `shift_amount` is guaranteed non-zero by [`ShiftSpec::new`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ShiftSpec {
    /// Flat index of the column being shifted.
    source_col: usize,
    /// Amount to shift by; never zero.
    shift_amount: usize,
}

impl ShiftSpec {
    /// Creates a spec that shifts column `source_col` by `shift_amount`.
    ///
    /// # Panics
    ///
    /// Panics if `shift_amount` is zero.
    pub fn new(source_col: usize, shift_amount: usize) -> Self {
        if shift_amount == 0 {
            panic!("shift must be non-zero");
        }
        Self {
            source_col,
            shift_amount,
        }
    }

    /// Flat index of the shifted column.
    pub fn source_col(&self) -> usize {
        let Self { source_col, .. } = self;
        *source_col
    }

    /// The (non-zero) shift amount.
    pub fn shift_amount(&self) -> usize {
        let Self { shift_amount, .. } = self;
        *shift_amount
    }
}
77
/// A bit-level operation, parameterised by a bit count.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BitOp {
    /// Rotation by the given count.
    Rot(usize),
    /// Right shift by the given count.
    ShR(usize),
}

impl BitOp {
    /// The count carried by the operation, whichever variant it is.
    pub fn count(&self) -> usize {
        // Both variants carry a single `usize`, so one or-pattern covers them all.
        let (BitOp::Rot(amount) | BitOp::ShR(amount)) = *self;
        amount
    }
}
119
120#[derive(Clone, Debug, PartialEq, Eq, Hash)]
129pub struct BitOpSpec {
130 source_col: usize,
133 op: BitOp,
135}
136
137impl BitOpSpec {
138 pub fn new(source_col: usize, op: BitOp) -> Self {
139 assert!(op.count() > 0, "bit-op count must be non-zero");
140 Self { source_col, op }
141 }
142
143 pub fn source_col(&self) -> usize {
144 self.source_col
145 }
146
147 pub fn op(&self) -> BitOp {
148 self.op
149 }
150}
151
/// Number of columns of each kind in a trace row.
///
/// Elsewhere in this module rows are addressed with a flat column index laid
/// out as `binary_poly || arbitrary_poly || int`, so [`ColumnLayout::cols`] is
/// the width of such a row.
#[derive(Clone, Debug, Default)]
pub struct ColumnLayout {
    /// Number of binary_poly columns.
    num_binary_poly_cols: usize,
    /// Number of arbitrary_poly columns.
    num_arbitrary_poly_cols: usize,
    /// Number of int columns.
    num_int_cols: usize,
}

impl ColumnLayout {
    /// Creates a layout with the given per-kind column counts.
    pub fn new(
        num_binary_poly_cols: usize,
        num_arbitrary_poly_cols: usize,
        num_int_cols: usize,
    ) -> Self {
        Self {
            num_binary_poly_cols,
            num_arbitrary_poly_cols,
            num_int_cols,
        }
    }

    /// Number of binary_poly columns.
    pub fn num_binary_poly_cols(&self) -> usize {
        self.num_binary_poly_cols
    }

    /// Number of arbitrary_poly columns.
    pub fn num_arbitrary_poly_cols(&self) -> usize {
        self.num_arbitrary_poly_cols
    }

    /// Number of int columns.
    pub fn num_int_cols(&self) -> usize {
        self.num_int_cols
    }

    /// Largest of the three per-kind column counts (0 for the default layout).
    pub fn max_cols(&self) -> usize {
        self.num_binary_poly_cols
            .max(self.num_arbitrary_poly_cols)
            .max(self.num_int_cols)
    }

    /// Total number of columns across all three kinds.
    ///
    /// # Panics
    ///
    /// Panics if the total does not fit in a `usize`. The addition is checked
    /// explicitly so the result can never silently wrap in builds compiled
    /// without overflow checks.
    pub fn cols(&self) -> usize {
        self.num_binary_poly_cols
            .checked_add(self.num_arbitrary_poly_cols)
            .and_then(|sum| sum.checked_add(self.num_int_cols))
            .expect("ColumnLayout::cols overflowed usize")
    }
}
209
/// Declares a newtype around [`ColumnLayout`] named `$name`, forwarding the
/// layout accessors. Any leading attributes (including doc comments) are
/// attached to the generated struct.
macro_rules! column_layout_wrapper {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        #[derive(Clone, Debug, Default)]
        pub struct $name(ColumnLayout);

        impl $name {
            /// Creates a layout with the given per-kind column counts.
            pub fn new(
                num_binary_poly_cols: usize,
                num_arbitrary_poly_cols: usize,
                num_int_cols: usize,
            ) -> Self {
                let inner = ColumnLayout::new(
                    num_binary_poly_cols,
                    num_arbitrary_poly_cols,
                    num_int_cols,
                );
                Self(inner)
            }

            /// Number of binary_poly columns.
            pub fn num_binary_poly_cols(&self) -> usize {
                self.as_column_layout().num_binary_poly_cols()
            }

            /// Number of arbitrary_poly columns.
            pub fn num_arbitrary_poly_cols(&self) -> usize {
                self.as_column_layout().num_arbitrary_poly_cols()
            }

            /// Number of int columns.
            pub fn num_int_cols(&self) -> usize {
                self.as_column_layout().num_int_cols()
            }

            /// Largest of the three per-kind column counts.
            pub fn max_cols(&self) -> usize {
                self.as_column_layout().max_cols()
            }

            /// Total number of columns across all kinds.
            pub fn cols(&self) -> usize {
                self.as_column_layout().cols()
            }

            /// Borrows the wrapped [`ColumnLayout`].
            pub fn as_column_layout(&self) -> &ColumnLayout {
                let Self(inner) = self;
                inner
            }
        }
    };
}
230
column_layout_wrapper!(
    /// Layout of every trace column (public and witness together).
    TotalColumnLayout
);
column_layout_wrapper!(
    /// Layout of the public columns; within each kind they come before the
    /// witness columns.
    PublicColumnLayout
);
column_layout_wrapper!(
    /// Layout of the virtual ("down") columns created by shifts and bit-ops.
    VirtualColumnLayout
);
column_layout_wrapper!(
    /// Layout of the witness columns: per kind, total minus public.
    WitnessColumnLayout
);
239
/// Column-level description of a UAIR: trace layouts, the shift and bit-op
/// specs that define the virtual "down" columns, and lookup specifications.
#[derive(Clone, Debug)]
pub struct UairSignature {
    /// Layout of all trace columns.
    total_cols: TotalColumnLayout,
    /// Layout of the public columns; each count is at most the matching
    /// `total_cols` count (checked in `new`).
    public_cols: PublicColumnLayout,
    /// Per kind, `total_cols - public_cols` (computed in `new`).
    witness_cols: WitnessColumnLayout,
    /// Shift specs, sorted by `source_col` in `new`.
    shifts: Vec<ShiftSpec>,
    /// Bit-op specs on binary_poly columns; empty unless set through
    /// `with_bit_op_specs`, which keeps the caller's order.
    bit_op_specs: Vec<BitOpSpec>,
    /// Layout of the down columns: per kind, one column per shift whose source
    /// has that kind, plus one binary_poly column per bit-op spec.
    down_cols: VirtualColumnLayout,
    /// Lookup specifications, stored as passed to `new`.
    lookup_specs: Vec<LookupColumnSpec>,
}
269
270impl UairSignature {
271 pub fn new(
273 total_cols: TotalColumnLayout,
274 public_cols: PublicColumnLayout,
275 mut shifts: Vec<ShiftSpec>,
276 lookup_specs: Vec<LookupColumnSpec>,
277 ) -> Self {
278 for (name, pub_n, tot_n) in [
279 (
280 "binary_poly",
281 public_cols.num_binary_poly_cols(),
282 total_cols.num_binary_poly_cols(),
283 ),
284 (
285 "arbitrary_poly",
286 public_cols.num_arbitrary_poly_cols(),
287 total_cols.num_arbitrary_poly_cols(),
288 ),
289 ("int", public_cols.num_int_cols(), total_cols.num_int_cols()),
290 ] {
291 assert!(
292 pub_n <= tot_n,
293 "public {name}_cols ({pub_n}) > total ({tot_n})"
294 );
295 }
296
297 let num_cols = total_cols.cols();
298 for spec in &shifts {
299 assert!(
300 spec.source_col() < num_cols,
301 "ShiftSpec source_col {} out of range (total_cols = {}). \
302 source_col uses flat indexing: binary_poly || arbitrary_poly || int.",
303 spec.source_col(),
304 num_cols,
305 );
306 }
307
308 shifts.sort_by_key(|spec| spec.source_col());
309 let down_cols = Self::compute_down_layout(&total_cols, &shifts, &[]);
310 let witness_cols = WitnessColumnLayout::new(
311 sub!(
312 total_cols.num_binary_poly_cols(),
313 public_cols.num_binary_poly_cols()
314 ),
315 sub!(
316 total_cols.num_arbitrary_poly_cols(),
317 public_cols.num_arbitrary_poly_cols()
318 ),
319 sub!(total_cols.num_int_cols(), public_cols.num_int_cols()),
320 );
321
322 Self {
323 total_cols,
324 public_cols,
325 shifts,
326 bit_op_specs: Vec::new(),
327 down_cols,
328 witness_cols,
329 lookup_specs,
330 }
331 }
332
333 pub fn with_bit_op_specs(mut self, cell_width: usize, bit_op_specs: Vec<BitOpSpec>) -> Self {
359 let binary_poly_end = self.total_cols.num_binary_poly_cols();
360 for spec in &bit_op_specs {
361 assert!(
362 spec.source_col() < binary_poly_end,
363 "BitOpSpec source_col {} is not a binary_poly column \
364 (binary_poly_end = {}). Bit-ops are only defined on the \
365 cell ring F_2[X]/(X^W).",
366 spec.source_col(),
367 binary_poly_end,
368 );
369 assert!(
370 spec.op().count() < cell_width,
371 "BitOpSpec count {} out of range (must satisfy 0 < count < {})",
372 spec.op().count(),
373 cell_width,
374 );
375 }
376 self.bit_op_specs = bit_op_specs;
377 self.down_cols =
378 Self::compute_down_layout(&self.total_cols, &self.shifts, &self.bit_op_specs);
379 self
380 }
381
382 pub fn lookup_specs(&self) -> &[LookupColumnSpec] {
383 &self.lookup_specs
384 }
385
386 fn compute_down_layout(
387 total_cols: &TotalColumnLayout,
388 shifts: &[ShiftSpec],
389 bit_op_specs: &[BitOpSpec],
390 ) -> VirtualColumnLayout {
391 let binary_poly_end = total_cols.num_binary_poly_cols();
392 let arbitrary_poly_end = add!(binary_poly_end, total_cols.num_arbitrary_poly_cols());
393 let mut num_binary_poly = 0usize;
394 let mut num_arbitrary_poly = 0usize;
395 let mut num_int = 0usize;
396 for spec in shifts {
397 if spec.source_col() < binary_poly_end {
398 num_binary_poly = add!(num_binary_poly, 1);
399 } else if spec.source_col() < arbitrary_poly_end {
400 num_arbitrary_poly = add!(num_arbitrary_poly, 1);
401 } else {
402 num_int = add!(num_int, 1);
403 }
404 }
405 num_binary_poly = add!(num_binary_poly, bit_op_specs.len());
406 VirtualColumnLayout::new(num_binary_poly, num_arbitrary_poly, num_int)
407 }
408
409 pub fn total_cols(&self) -> &TotalColumnLayout {
410 &self.total_cols
411 }
412
413 pub fn public_cols(&self) -> &PublicColumnLayout {
414 &self.public_cols
415 }
416
417 pub fn witness_cols(&self) -> &WitnessColumnLayout {
419 &self.witness_cols
420 }
421
422 pub fn shifts(&self) -> &[ShiftSpec] {
423 &self.shifts
424 }
425
426 pub fn bit_op_specs(&self) -> &[BitOpSpec] {
430 &self.bit_op_specs
431 }
432
433 pub fn down_cols(&self) -> &VirtualColumnLayout {
435 &self.down_cols
436 }
437
438 pub fn dummy_rows<T: Clone>(&self, val: T) -> (Vec<T>, Vec<T>) {
442 let up_size = self.total_cols.cols();
443 let down_size = self.down_cols.cols();
444 (vec![val.clone(); up_size], vec![val; down_size])
445 }
446}
447
448#[derive(Debug, Clone, Default)]
456pub struct UairTrace<
457 'a,
458 PolyCoeff: Clone,
459 Int: Clone,
460 const BINARY_POLY_DEGREE_PLUS_ONE: usize,
461 const ARBITRARY_POLY_DEGREE_PLUS_ONE: usize,
462> {
463 pub binary_poly: Cow<'a, [DenseMultilinearExtension<BinaryPoly<BINARY_POLY_DEGREE_PLUS_ONE>>]>,
464 pub arbitrary_poly: Cow<
465 'a,
466 [DenseMultilinearExtension<DensePolynomial<PolyCoeff, ARBITRARY_POLY_DEGREE_PLUS_ONE>>],
467 >,
468 pub int: Cow<'a, [DenseMultilinearExtension<Int>]>,
469}
470
471impl<PolyCoeff: Clone, Int: Clone, const DB: usize, const DA: usize>
472 UairTrace<'static, PolyCoeff, Int, DB, DA>
473{
474 pub fn public(&self, sig: &UairSignature) -> UairTrace<'_, PolyCoeff, Int, DB, DA> {
477 let p = sig.public_cols();
478 UairTrace {
479 binary_poly: Cow::Borrowed(&self.binary_poly[0..p.num_binary_poly_cols()]),
480 arbitrary_poly: Cow::Borrowed(&self.arbitrary_poly[0..p.num_arbitrary_poly_cols()]),
481 int: Cow::Borrowed(&self.int[0..p.num_int_cols()]),
482 }
483 }
484
485 pub fn witness(&self, sig: &UairSignature) -> UairTrace<'_, PolyCoeff, Int, DB, DA> {
488 let p = sig.public_cols();
489 UairTrace {
490 binary_poly: Cow::Borrowed(&self.binary_poly[p.num_binary_poly_cols()..]),
491 arbitrary_poly: Cow::Borrowed(&self.arbitrary_poly[p.num_arbitrary_poly_cols()..]),
492 int: Cow::Borrowed(&self.int[p.num_int_cols()..]),
493 }
494 }
495}
496
497#[derive(Clone, Copy)]
505pub struct TraceRow<'a, Expr> {
506 pub binary_poly: &'a [Expr],
507 pub arbitrary_poly: &'a [Expr],
508 pub int: &'a [Expr],
509}
510
511impl<'a, Expr> TraceRow<'a, Expr> {
512 #[allow(clippy::arithmetic_side_effects)]
516 pub fn from_slice_with_layout(row: &'a [Expr], layout: &ColumnLayout) -> Self {
517 let num_binary_poly = layout.num_binary_poly_cols();
518 let num_arbitrary_poly = layout.num_arbitrary_poly_cols();
519 Self {
520 binary_poly: &row[0..num_binary_poly],
521 arbitrary_poly: &row[num_binary_poly..num_binary_poly + num_arbitrary_poly],
522 int: &row[num_binary_poly + num_arbitrary_poly..],
523 }
524 }
525}
526
527pub trait Uair: Clone {
537 type Ideal: Ideal;
545
546 type Scalar: Semiring;
552
553 fn signature() -> UairSignature;
560
561 fn constrain_general<B, FromR, MulByScalar, IFromR>(
579 b: &mut B,
580 up: TraceRow<B::Expr>,
581 down: TraceRow<B::Expr>,
582 from_ref: FromR,
583 mbs: MulByScalar,
584 ideal_from_ref: IFromR,
585 ) where
586 B: ConstraintBuilder,
587 FromR: Fn(&Self::Scalar) -> B::Expr,
588 MulByScalar: Fn(&B::Expr, &Self::Scalar) -> Option<B::Expr>,
589 IFromR: Fn(&Self::Ideal) -> B::Ideal;
590
591 fn constrain<B>(b: &mut B, up: TraceRow<B::Expr>, down: TraceRow<B::Expr>)
594 where
595 B: ConstraintBuilder,
596 B::Expr: FromRef<Self::Scalar> + for<'b> MulByScalar<&'b Self::Scalar>,
597 B::Ideal: FromRef<Self::Ideal>,
598 {
599 Self::constrain_general(
600 b,
601 up,
602 down,
603 B::Expr::from_ref,
604 |x, y| B::Expr::mul_by_scalar::<UNCHECKED>(x, y),
605 B::Ideal::from_ref,
606 )
607 }
608}
609
#[cfg(test)]
mod tests {
    use super::*;

    /// Total layout (2, 1, 1) with one shift per column kind: col 0 is
    /// binary_poly, col 2 is arbitrary_poly, col 3 is int.
    fn signature_with_mixed_shifts() -> UairSignature {
        UairSignature::new(
            TotalColumnLayout::new(2, 1, 1),
            PublicColumnLayout::new(0, 0, 0),
            vec![
                ShiftSpec::new(0, 1),
                ShiftSpec::new(2, 1),
                ShiftSpec::new(3, 1),
            ],
            vec![],
        )
    }

    #[test]
    fn bit_op_specs_extend_binary_down_layout() {
        let specs = vec![
            BitOpSpec::new(1, BitOp::ShR(3)),
            BitOpSpec::new(0, BitOp::Rot(2)),
        ];
        let sig = signature_with_mixed_shifts().with_bit_op_specs(8, specs.clone());

        assert_eq!(sig.bit_op_specs(), specs);
        assert_eq!(sig.bit_op_specs()[0].source_col(), 1);
        assert_eq!(sig.bit_op_specs()[0].op(), BitOp::ShR(3));
        assert_eq!(sig.bit_op_specs()[0].op().count(), 3);
        assert_eq!(sig.down_cols().num_binary_poly_cols(), 3);
        assert_eq!(sig.down_cols().num_arbitrary_poly_cols(), 1);
        assert_eq!(sig.down_cols().num_int_cols(), 1);
    }

    #[test]
    fn empty_bit_op_specs_keep_shift_only_down_layout() {
        let sig = signature_with_mixed_shifts().with_bit_op_specs(8, vec![]);

        assert!(sig.bit_op_specs().is_empty());
        assert_eq!(sig.down_cols().num_binary_poly_cols(), 1);
        assert_eq!(sig.down_cols().num_arbitrary_poly_cols(), 1);
        assert_eq!(sig.down_cols().num_int_cols(), 1);
    }

    #[test]
    #[should_panic(expected = "bit-op count must be non-zero")]
    fn bit_op_spec_rejects_zero_count() {
        let _ = BitOpSpec::new(0, BitOp::Rot(0));
    }

    #[test]
    #[should_panic(expected = "is not a binary_poly column")]
    fn bit_op_specs_reject_non_binary_source() {
        let _ = signature_with_mixed_shifts()
            .with_bit_op_specs(8, vec![BitOpSpec::new(2, BitOp::ShR(1))]);
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn bit_op_specs_reject_count_at_cell_width() {
        let _ = signature_with_mixed_shifts()
            .with_bit_op_specs(8, vec![BitOpSpec::new(0, BitOp::Rot(8))]);
    }

    #[test]
    #[should_panic(expected = "shift must be non-zero")]
    fn shift_spec_rejects_zero_amount() {
        let _ = ShiftSpec::new(0, 0);
    }

    #[test]
    fn new_sorts_shifts_by_source_col() {
        let sig = UairSignature::new(
            TotalColumnLayout::new(2, 1, 1),
            PublicColumnLayout::new(0, 0, 0),
            vec![
                ShiftSpec::new(3, 1),
                ShiftSpec::new(0, 2),
                ShiftSpec::new(2, 1),
            ],
            vec![],
        );
        let cols: Vec<usize> = sig.shifts().iter().map(ShiftSpec::source_col).collect();
        assert_eq!(cols, vec![0, 2, 3]);
        assert_eq!(sig.shifts()[0].shift_amount(), 2);
    }

    #[test]
    fn witness_layout_is_total_minus_public() {
        let sig = UairSignature::new(
            TotalColumnLayout::new(2, 1, 1),
            PublicColumnLayout::new(1, 0, 1),
            vec![],
            vec![],
        );
        assert_eq!(sig.witness_cols().num_binary_poly_cols(), 1);
        assert_eq!(sig.witness_cols().num_arbitrary_poly_cols(), 1);
        assert_eq!(sig.witness_cols().num_int_cols(), 0);
    }

    #[test]
    #[should_panic(expected = "public int_cols (2) > total (1)")]
    fn new_rejects_public_exceeding_total() {
        let _ = UairSignature::new(
            TotalColumnLayout::new(2, 1, 1),
            PublicColumnLayout::new(0, 0, 2),
            vec![],
            vec![],
        );
    }

    #[test]
    #[should_panic(expected = "out of range (total_cols = 4)")]
    fn new_rejects_out_of_range_shift() {
        let _ = UairSignature::new(
            TotalColumnLayout::new(2, 1, 1),
            PublicColumnLayout::new(0, 0, 0),
            vec![ShiftSpec::new(4, 1)],
            vec![],
        );
    }

    #[test]
    fn dummy_rows_match_up_and_down_widths() {
        let sig = signature_with_mixed_shifts()
            .with_bit_op_specs(8, vec![BitOpSpec::new(0, BitOp::Rot(1))]);
        let (up, down) = sig.dummy_rows(7u8);
        assert_eq!(up, vec![7u8; 4]);
        // One down column per shift (3) plus one per bit-op spec (1).
        assert_eq!(down, vec![7u8; 4]);
    }

    #[test]
    fn trace_row_splits_flat_row_by_layout() {
        let row = [10, 20, 30, 40];
        let layout = ColumnLayout::new(2, 1, 1);
        let tr = TraceRow::from_slice_with_layout(&row, &layout);
        assert_eq!(tr.binary_poly, &[10, 20]);
        assert_eq!(tr.arbitrary_poly, &[30]);
        assert_eq!(tr.int, &[40]);
    }
}