1pub mod collect_scalars;
4pub mod constraint_counter;
5pub mod degree_counter;
6pub mod do_nothing_builder;
7pub mod dummy_semiring;
8pub mod ideal;
9pub mod ideal_collector;
10pub mod lookup_types;
11
12use crypto_primitives::Semiring;
13use std::borrow::Cow;
14use zinc_poly::{
15 mle::DenseMultilinearExtension,
16 univariate::{binary::BinaryPoly, dense::DensePolynomial},
17};
18use zinc_utils::{UNCHECKED, add, from_ref::FromRef, mul_by_scalar::MulByScalar, sub};
19
20use crate::ideal::{Ideal, IdealCheck};
21
22pub use lookup_types::{LookupColumnSpec, LookupTableType};
23
24pub trait ConstraintBuilder {
27 type Expr: Semiring;
32 type Ideal: Ideal + IdealCheck<Self::Expr>;
34
35 fn assert_in_ideal(&mut self, expr: Self::Expr, ideal: &Self::Ideal);
37
38 fn assert_zero(&mut self, expr: Self::Expr);
41}
42
/// Describes a shift applied to a single source column.
///
/// The pair is immutable once built; use the accessors to read it back.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ShiftSpec {
    /// Flat index of the column being shifted. `UairSignature::new` checks it
    /// against the total column count, using the flat ordering
    /// `binary_poly || arbitrary_poly || int`.
    source_col: usize,
    /// Size of the shift. Never zero: [`ShiftSpec::new`] rejects `0`.
    shift_amount: usize,
}

impl ShiftSpec {
    /// Builds a shift of `source_col` by `shift_amount`.
    ///
    /// # Panics
    ///
    /// Panics with `"shift must be non-zero"` when `shift_amount` is `0`.
    pub fn new(source_col: usize, shift_amount: usize) -> Self {
        // A zero shift would just be the source column itself.
        assert!(shift_amount != 0, "shift must be non-zero");
        Self {
            source_col,
            shift_amount,
        }
    }

    /// Flat index of the column this shift reads from.
    pub fn source_col(&self) -> usize {
        self.source_col
    }

    /// Non-zero amount by which the source column is shifted.
    pub fn shift_amount(&self) -> usize {
        self.shift_amount
    }
}
77
/// Number of columns of each of the three column types: binary polynomial,
/// arbitrary polynomial and integer.
///
/// The default layout has zero columns of every type.
#[derive(Clone, Debug, Default)]
pub struct ColumnLayout {
    /// Count of binary-polynomial columns.
    num_binary_poly_cols: usize,
    /// Count of arbitrary-polynomial columns.
    num_arbitrary_poly_cols: usize,
    /// Count of integer columns.
    num_int_cols: usize,
}

impl ColumnLayout {
    /// Creates a layout from the three per-type column counts.
    pub fn new(
        num_binary_poly_cols: usize,
        num_arbitrary_poly_cols: usize,
        num_int_cols: usize,
    ) -> Self {
        Self {
            num_binary_poly_cols,
            num_arbitrary_poly_cols,
            num_int_cols,
        }
    }

    /// Number of binary-polynomial columns.
    pub fn num_binary_poly_cols(&self) -> usize {
        self.num_binary_poly_cols
    }

    /// Number of arbitrary-polynomial columns.
    pub fn num_arbitrary_poly_cols(&self) -> usize {
        self.num_arbitrary_poly_cols
    }

    /// Number of integer columns.
    pub fn num_int_cols(&self) -> usize {
        self.num_int_cols
    }

    /// Largest of the three per-type column counts.
    pub fn max_cols(&self) -> usize {
        let per_type = [
            self.num_binary_poly_cols,
            self.num_arbitrary_poly_cols,
            self.num_int_cols,
        ];
        // The array always holds three entries, so `max` cannot return `None`.
        per_type
            .iter()
            .copied()
            .max()
            .expect("the iterator is not empty")
    }

    /// Total number of columns across all three types.
    // Plain `+` is used on purpose here.
    // NOTE(review): assumes column counts are far too small for the sum to
    // overflow `usize` — confirm.
    #[allow(clippy::arithmetic_side_effects)]
    pub fn cols(&self) -> usize {
        let poly_cols = self.num_binary_poly_cols + self.num_arbitrary_poly_cols;
        poly_cols + self.num_int_cols
    }
}
135
136macro_rules! column_layout_wrapper {
137 ($(#[$meta:meta])* $name:ident) => {
138 $(#[$meta])*
139 #[derive(Clone, Debug, Default)]
140 pub struct $name(ColumnLayout);
141
142 impl $name {
143 pub fn new(num_binary_poly_cols: usize, num_arbitrary_poly_cols: usize, num_int_cols: usize) -> Self {
144 Self(ColumnLayout::new(num_binary_poly_cols, num_arbitrary_poly_cols, num_int_cols))
145 }
146
147 pub fn num_binary_poly_cols(&self) -> usize { self.0.num_binary_poly_cols() }
148 pub fn num_arbitrary_poly_cols(&self) -> usize { self.0.num_arbitrary_poly_cols() }
149 pub fn num_int_cols(&self) -> usize { self.0.num_int_cols() }
150 pub fn max_cols(&self) -> usize { self.0.max_cols() }
151 pub fn cols(&self) -> usize { self.0.cols() }
152 pub fn as_column_layout(&self) -> &ColumnLayout { &self.0 }
153 }
154 };
155}
156
157column_layout_wrapper!(TotalColumnLayout);
159column_layout_wrapper!(PublicColumnLayout);
161column_layout_wrapper!(VirtualColumnLayout);
163column_layout_wrapper!(WitnessColumnLayout);
165
166#[derive(Clone, Debug)]
176pub struct UairSignature {
177 total_cols: TotalColumnLayout,
179 public_cols: PublicColumnLayout,
181 witness_cols: WitnessColumnLayout,
183 shifts: Vec<ShiftSpec>,
185 down_cols: VirtualColumnLayout,
187 lookup_specs: Vec<LookupColumnSpec>,
190}
191
192impl UairSignature {
193 pub fn new(
195 total_cols: TotalColumnLayout,
196 public_cols: PublicColumnLayout,
197 mut shifts: Vec<ShiftSpec>,
198 lookup_specs: Vec<LookupColumnSpec>,
199 ) -> Self {
200 for (name, pub_n, tot_n) in [
201 (
202 "binary_poly",
203 public_cols.num_binary_poly_cols(),
204 total_cols.num_binary_poly_cols(),
205 ),
206 (
207 "arbitrary_poly",
208 public_cols.num_arbitrary_poly_cols(),
209 total_cols.num_arbitrary_poly_cols(),
210 ),
211 ("int", public_cols.num_int_cols(), total_cols.num_int_cols()),
212 ] {
213 assert!(
214 pub_n <= tot_n,
215 "public {name}_cols ({pub_n}) > total ({tot_n})"
216 );
217 }
218
219 let num_cols = total_cols.cols();
220 for spec in &shifts {
221 assert!(
222 spec.source_col() < num_cols,
223 "ShiftSpec source_col {} out of range (total_cols = {}). \
224 source_col uses flat indexing: binary_poly || arbitrary_poly || int.",
225 spec.source_col(),
226 num_cols,
227 );
228 }
229
230 shifts.sort_by_key(|spec| spec.source_col());
231 let down_cols = Self::compute_down_layout(&total_cols, &shifts);
232 let witness_cols = WitnessColumnLayout::new(
233 sub!(
234 total_cols.num_binary_poly_cols(),
235 public_cols.num_binary_poly_cols()
236 ),
237 sub!(
238 total_cols.num_arbitrary_poly_cols(),
239 public_cols.num_arbitrary_poly_cols()
240 ),
241 sub!(total_cols.num_int_cols(), public_cols.num_int_cols()),
242 );
243
244 Self {
245 total_cols,
246 public_cols,
247 shifts,
248 down_cols,
249 witness_cols,
250 lookup_specs,
251 }
252 }
253
254 pub fn lookup_specs(&self) -> &[LookupColumnSpec] {
255 &self.lookup_specs
256 }
257
258 fn compute_down_layout(
259 total_cols: &TotalColumnLayout,
260 shifts: &[ShiftSpec],
261 ) -> VirtualColumnLayout {
262 let binary_poly_end = total_cols.num_binary_poly_cols();
263 let arbitrary_poly_end = add!(binary_poly_end, total_cols.num_arbitrary_poly_cols());
264 let mut num_binary_poly = 0usize;
265 let mut num_arbitrary_poly = 0usize;
266 let mut num_int = 0usize;
267 for spec in shifts {
268 if spec.source_col() < binary_poly_end {
269 num_binary_poly = add!(num_binary_poly, 1);
270 } else if spec.source_col() < arbitrary_poly_end {
271 num_arbitrary_poly = add!(num_arbitrary_poly, 1);
272 } else {
273 num_int = add!(num_int, 1);
274 }
275 }
276 VirtualColumnLayout::new(num_binary_poly, num_arbitrary_poly, num_int)
277 }
278
279 pub fn total_cols(&self) -> &TotalColumnLayout {
280 &self.total_cols
281 }
282
283 pub fn public_cols(&self) -> &PublicColumnLayout {
284 &self.public_cols
285 }
286
287 pub fn witness_cols(&self) -> &WitnessColumnLayout {
289 &self.witness_cols
290 }
291
292 pub fn shifts(&self) -> &[ShiftSpec] {
293 &self.shifts
294 }
295
296 pub fn down_cols(&self) -> &VirtualColumnLayout {
298 &self.down_cols
299 }
300
301 pub fn dummy_rows<T: Clone>(&self, val: T) -> (Vec<T>, Vec<T>) {
305 let up_size = self.total_cols.cols();
306 let down_size = self.down_cols.cols();
307 (vec![val.clone(); up_size], vec![val; down_size])
308 }
309}
310
311#[derive(Debug, Clone, Default)]
319pub struct UairTrace<'a, PolyCoeff: Clone, Int: Clone, const D: usize> {
320 pub binary_poly: Cow<'a, [DenseMultilinearExtension<BinaryPoly<D>>]>,
321 pub arbitrary_poly: Cow<'a, [DenseMultilinearExtension<DensePolynomial<PolyCoeff, D>>]>,
322 pub int: Cow<'a, [DenseMultilinearExtension<Int>]>,
323}
324
325impl<PolyCoeff: Clone, Int: Clone, const D: usize> UairTrace<'static, PolyCoeff, Int, D> {
326 pub fn public(&self, sig: &UairSignature) -> UairTrace<'_, PolyCoeff, Int, D> {
329 let p = sig.public_cols();
330 UairTrace {
331 binary_poly: Cow::Borrowed(&self.binary_poly[0..p.num_binary_poly_cols()]),
332 arbitrary_poly: Cow::Borrowed(&self.arbitrary_poly[0..p.num_arbitrary_poly_cols()]),
333 int: Cow::Borrowed(&self.int[0..p.num_int_cols()]),
334 }
335 }
336
337 pub fn witness(&self, sig: &UairSignature) -> UairTrace<'_, PolyCoeff, Int, D> {
340 let p = sig.public_cols();
341 UairTrace {
342 binary_poly: Cow::Borrowed(&self.binary_poly[p.num_binary_poly_cols()..]),
343 arbitrary_poly: Cow::Borrowed(&self.arbitrary_poly[p.num_arbitrary_poly_cols()..]),
344 int: Cow::Borrowed(&self.int[p.num_int_cols()..]),
345 }
346 }
347}
348
349#[derive(Clone, Copy)]
357pub struct TraceRow<'a, Expr> {
358 pub binary_poly: &'a [Expr],
359 pub arbitrary_poly: &'a [Expr],
360 pub int: &'a [Expr],
361}
362
363impl<'a, Expr> TraceRow<'a, Expr> {
364 #[allow(clippy::arithmetic_side_effects)]
368 pub fn from_slice_with_layout(row: &'a [Expr], layout: &ColumnLayout) -> Self {
369 let num_binary_poly = layout.num_binary_poly_cols();
370 let num_arbitrary_poly = layout.num_arbitrary_poly_cols();
371 Self {
372 binary_poly: &row[0..num_binary_poly],
373 arbitrary_poly: &row[num_binary_poly..num_binary_poly + num_arbitrary_poly],
374 int: &row[num_binary_poly + num_arbitrary_poly..],
375 }
376 }
377}
378
379pub trait Uair: Clone {
389 type Ideal: Ideal;
397
398 type Scalar: Semiring;
404
405 fn signature() -> UairSignature;
412
413 fn constrain_general<B, FromR, MulByScalar, IFromR>(
431 b: &mut B,
432 up: TraceRow<B::Expr>,
433 down: TraceRow<B::Expr>,
434 from_ref: FromR,
435 mbs: MulByScalar,
436 ideal_from_ref: IFromR,
437 ) where
438 B: ConstraintBuilder,
439 FromR: Fn(&Self::Scalar) -> B::Expr,
440 MulByScalar: Fn(&B::Expr, &Self::Scalar) -> Option<B::Expr>,
441 IFromR: Fn(&Self::Ideal) -> B::Ideal;
442
443 fn constrain<B>(b: &mut B, up: TraceRow<B::Expr>, down: TraceRow<B::Expr>)
446 where
447 B: ConstraintBuilder,
448 B::Expr: FromRef<Self::Scalar> + for<'b> MulByScalar<&'b Self::Scalar>,
449 B::Ideal: FromRef<Self::Ideal>,
450 {
451 Self::constrain_general(
452 b,
453 up,
454 down,
455 B::Expr::from_ref,
456 |x, y| B::Expr::mul_by_scalar::<UNCHECKED>(x, y),
457 B::Ideal::from_ref,
458 )
459 }
460}