prover/
r1cs.rs

1//! R1CS Parser for Circom binary format
2//!
3//! Parses the `.r1cs` file format produced by the Circom compiler.
4//! This allows us to replay constraints during proof generation.
5//!
6//! # File Format
//! The R1CS binary format consists of:
//! - A file header: the magic bytes `r1cs`, a `u32` version, and a `u32` section count
//! - Typed, length-prefixed sections for header info, constraints, and wire-to-label mappings
10//!
11//! # Reference
//! <https://github.com/iden3/r1csfile/blob/master/doc/r1cs_bin_format.md>
13
14use alloc::{format, vec::Vec};
15
16use ark_bn254::Fr;
17use ark_ff::PrimeField;
18use wasm_bindgen::JsValue;
19
20/// A term in a linear combination: coefficient * wire
21#[derive(Clone, Debug)]
22pub struct Term {
23    /// Wire index (variable index in the constraint system)
24    pub wire_id: u32,
25    /// Coefficient as a field element
26    pub coefficient: Fr,
27}
28
29/// A linear combination: sum of (coefficient * wire)
30#[derive(Clone, Debug, Default)]
31pub struct LinearCombination {
32    /// The terms in this linear combination
33    pub terms: Vec<Term>,
34}
35
36/// A single R1CS constraint: A * B = C
37/// Where A, B, C are linear combinations
38#[derive(Clone, Debug)]
39pub struct Constraint {
40    /// Linear combination A
41    pub a: LinearCombination,
42    /// Linear combination B  
43    pub b: LinearCombination,
44    /// Linear combination C
45    pub c: LinearCombination,
46}
47
48/// Parsed R1CS file
49#[derive(Clone, Debug)]
50pub struct R1CS {
51    /// Number of wires (variables) in the circuit
52    pub num_wires: u32,
53    /// Number of public outputs
54    pub num_pub_out: u32,
55    /// Number of public inputs
56    pub num_pub_in: u32,
57    /// Number of private inputs
58    pub num_prv_in: u32,
59    /// Total number of public inputs (outputs + inputs, excluding constant 1)
60    pub num_public: u32,
61    /// The constraints
62    pub constraints: Vec<Constraint>,
63}
64
65impl R1CS {
66    /// Parse R1CS from binary data
67    pub fn parse(data: &[u8]) -> Result<Self, JsValue> {
68        let mut cursor = Cursor::new(data);
69
70        // Read and verify magic number "r1cs"
71        let magic = cursor.read_bytes(4)?;
72        if magic != b"r1cs" {
73            return Err(JsValue::from_str("Invalid R1CS magic number"));
74        }
75
76        // Version (should be 1)
77        let version = cursor.read_u32_le()?;
78        if version != 1 {
79            return Err(JsValue::from_str(&format!(
80                "Unsupported R1CS version: {}",
81                version
82            )));
83        }
84
85        // Number of sections
86        let num_sections = cursor.read_u32_le()?;
87
88        let mut header: Option<R1CSHeader> = None;
89        let mut constraints_data: Option<(usize, usize)> = None; // (start, size)
90
91        // First pass: collect section locations
92        for _ in 0..num_sections {
93            let section_type = cursor.read_u32_le()?;
94            let section_size = cursor.read_u64_le()?;
95            let section_start = cursor.position;
96
97            let section_size_usize = usize::try_from(section_size)
98                .map_err(|_| JsValue::from_str("Section size too large"))?;
99
100            match section_type {
101                1 => {
102                    // Header section
103                    header = Some(Self::parse_header(&mut cursor)?);
104                }
105                2 => {
106                    // Constraints section - save location for later
107                    constraints_data = Some((section_start, section_size_usize));
108                    cursor.skip(section_size_usize)?;
109                }
110                3 => {
111                    // Wire2LabelId section - skip
112                    cursor.skip(section_size_usize)?;
113                }
114                _ => {
115                    // Unknown section - skip
116                    cursor.skip(section_size_usize)?;
117                }
118            }
119
120            // Ensure we consumed exactly section_size bytes
121            let consumed = cursor
122                .position
123                .checked_sub(section_start)
124                .ok_or_else(|| JsValue::from_str("Invalid cursor position"))?;
125            if consumed < section_size_usize {
126                let remaining = section_size_usize
127                    .checked_sub(consumed)
128                    .ok_or_else(|| JsValue::from_str("Invalid remaining bytes calculation"))?;
129                cursor.skip(remaining)?;
130            }
131        }
132
133        // Now parse constraints with header available
134        let header = header.ok_or_else(|| JsValue::from_str("Missing R1CS header section"))?;
135
136        let constraints = if let Some((start, _size)) = constraints_data {
137            cursor.position = start;
138            Self::parse_constraints(&mut cursor, &header)?
139        } else {
140            Vec::new()
141        };
142
143        // num_public = num_pub_out + num_pub_in (not including the constant 1 wire)
144        let num_public = header
145            .num_pub_out
146            .checked_add(header.num_pub_in)
147            .ok_or_else(|| JsValue::from_str("Overflow calculating num_public"))?;
148
149        Ok(R1CS {
150            num_wires: header.num_wires,
151            num_pub_out: header.num_pub_out,
152            num_pub_in: header.num_pub_in,
153            num_prv_in: header.num_prv_in,
154            num_public,
155            constraints,
156        })
157    }
158
159    fn parse_header(cursor: &mut Cursor) -> Result<R1CSHeader, JsValue> {
160        // Field size in bytes (should be 32 for BN254)
161        let field_size = cursor.read_u32_le()?;
162        if field_size != 32 {
163            return Err(JsValue::from_str(&format!(
164                "Unsupported field size: {} (expected 32)",
165                field_size
166            )));
167        }
168
169        // Prime (skip - we assume BN254)
170        cursor.skip(field_size as usize)?;
171
172        let num_wires = cursor.read_u32_le()?;
173        let num_pub_out = cursor.read_u32_le()?;
174        let num_pub_in = cursor.read_u32_le()?;
175        let num_prv_in = cursor.read_u32_le()?;
176        let _num_labels = cursor.read_u64_le()?;
177        let num_constraints = cursor.read_u32_le()?;
178
179        Ok(R1CSHeader {
180            field_size,
181            num_wires,
182            num_pub_out,
183            num_pub_in,
184            num_prv_in,
185            num_constraints,
186        })
187    }
188
189    fn parse_constraints(
190        cursor: &mut Cursor,
191        header: &R1CSHeader,
192    ) -> Result<Vec<Constraint>, JsValue> {
193        let mut constraints = Vec::with_capacity(header.num_constraints as usize);
194
195        for _ in 0..header.num_constraints {
196            let a = Self::parse_linear_combination(cursor, header.field_size)?;
197            let b = Self::parse_linear_combination(cursor, header.field_size)?;
198            let c = Self::parse_linear_combination(cursor, header.field_size)?;
199
200            constraints.push(Constraint { a, b, c });
201        }
202
203        Ok(constraints)
204    }
205
206    fn parse_linear_combination(
207        cursor: &mut Cursor,
208        field_size: u32,
209    ) -> Result<LinearCombination, JsValue> {
210        let num_terms = cursor.read_u32_le()?;
211        let mut terms = Vec::with_capacity(num_terms as usize);
212
213        for _ in 0..num_terms {
214            let wire_id = cursor.read_u32_le()?;
215            let coeff_bytes = cursor.read_bytes(field_size as usize)?;
216            let coefficient = Fr::from_le_bytes_mod_order(coeff_bytes);
217
218            terms.push(Term {
219                wire_id,
220                coefficient,
221            });
222        }
223
224        Ok(LinearCombination { terms })
225    }
226
227    /// Get total number of constraints
228    pub fn num_constraints(&self) -> usize {
229        self.constraints.len()
230    }
231}
232
/// Values from the R1CS header section that the rest of the parser needs.
struct R1CSHeader {
    /// Byte width of each encoded field element.
    field_size: u32,
    /// Total wire (variable) count.
    num_wires: u32,
    /// Count of public outputs.
    num_pub_out: u32,
    /// Count of public inputs.
    num_pub_in: u32,
    /// Count of private inputs.
    num_prv_in: u32,
    /// Number of constraints the constraints section is declared to hold.
    num_constraints: u32,
}
242
243/// Simple cursor for reading binary data
244struct Cursor<'a> {
245    data: &'a [u8],
246    position: usize,
247}
248
249impl<'a> Cursor<'a> {
250    fn new(data: &'a [u8]) -> Self {
251        Cursor { data, position: 0 }
252    }
253
254    fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], JsValue> {
255        let end = self
256            .position
257            .checked_add(n)
258            .ok_or_else(|| JsValue::from_str("Overflow in cursor position"))?;
259        if end > self.data.len() {
260            return Err(JsValue::from_str("Unexpected end of R1CS data"));
261        }
262        let slice = &self.data[self.position..end];
263        self.position = end;
264        Ok(slice)
265    }
266
267    fn read_u32_le(&mut self) -> Result<u32, JsValue> {
268        let bytes = self.read_bytes(4)?;
269        Ok(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
270    }
271
272    fn read_u64_le(&mut self) -> Result<u64, JsValue> {
273        let bytes = self.read_bytes(8)?;
274        Ok(u64::from_le_bytes([
275            bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
276        ]))
277    }
278
279    fn skip(&mut self, n: usize) -> Result<(), JsValue> {
280        let end = self
281            .position
282            .checked_add(n)
283            .ok_or_else(|| JsValue::from_str("Overflow in cursor skip"))?;
284        if end > self.data.len() {
285            return Err(JsValue::from_str("Unexpected end of R1CS data"));
286        }
287        self.position = end;
288        Ok(())
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    // Only success paths are exercised: constructing a `JsValue` error is not
    // supported when these tests run on a non-wasm host.

    fn push_u32(buf: &mut Vec<u8>, value: u32) {
        buf.extend_from_slice(&value.to_le_bytes());
    }

    fn push_u64(buf: &mut Vec<u8>, value: u64) {
        buf.extend_from_slice(&value.to_le_bytes());
    }

    /// Appends a 32-byte little-endian field element holding a small integer.
    fn push_field_element(buf: &mut Vec<u8>, value: u64) {
        buf.extend_from_slice(&value.to_le_bytes());
        buf.extend_from_slice(&[0u8; 24]);
    }

    /// Appends the BN254 scalar-field modulus as 32 little-endian bytes.
    fn push_bn254_prime(buf: &mut Vec<u8>) {
        for limb in <Fr as ark_ff::Field>::characteristic() {
            buf.extend_from_slice(&limb.to_le_bytes());
        }
    }

    /// Builds a minimal R1CS file with 3 wires and the single constraint
    /// `(1 * w1) * (5 * w2) = 0`. The constraints section is placed before the
    /// header, and a Wire2LabelId section follows, to exercise section ordering
    /// and skipping.
    fn minimal_r1cs() -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(b"r1cs");
        push_u32(&mut buf, 1); // version
        push_u32(&mut buf, 3); // number of sections

        // Constraints section: two 40-byte single-term LCs plus one empty LC.
        push_u32(&mut buf, 2);
        push_u64(&mut buf, 84);
        push_u32(&mut buf, 1); // A: one term
        push_u32(&mut buf, 1);
        push_field_element(&mut buf, 1);
        push_u32(&mut buf, 1); // B: one term
        push_u32(&mut buf, 2);
        push_field_element(&mut buf, 5);
        push_u32(&mut buf, 0); // C: no terms

        // Header section: 4 + 32 + 4 * 4 + 8 + 4 = 64 bytes.
        push_u32(&mut buf, 1);
        push_u64(&mut buf, 64);
        push_u32(&mut buf, 32); // field size
        push_bn254_prime(&mut buf);
        push_u32(&mut buf, 3); // wires
        push_u32(&mut buf, 1); // public outputs
        push_u32(&mut buf, 1); // public inputs
        push_u32(&mut buf, 0); // private inputs
        push_u64(&mut buf, 3); // labels
        push_u32(&mut buf, 1); // constraints

        // Wire2LabelId section (one u64 per wire), which the parser skips.
        push_u32(&mut buf, 3);
        push_u64(&mut buf, 24);
        for label in 0..3u64 {
            push_u64(&mut buf, label);
        }

        buf
    }

    #[test]
    fn test_cursor_reads() {
        let data = [0x72, 0x31, 0x63, 0x73, 0x01, 0x00, 0x00, 0x00]; // "r1cs" + version 1 as little-endian u32
        let mut cursor = Cursor::new(&data);

        let magic = cursor.read_bytes(4).expect("should read magic bytes");
        assert_eq!(magic, b"r1cs");

        let version = cursor.read_u32_le().expect("should read version");
        assert_eq!(version, 1);
    }

    #[test]
    fn test_cursor_skip_and_u64() {
        let data = [0xAA, 0xBB, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let mut cursor = Cursor::new(&data);

        cursor.skip(2).expect("should skip two bytes");
        let value = cursor.read_u64_le().expect("should read u64");
        assert_eq!(value, 0x0807_0605_0403_0201);
        assert_eq!(cursor.position, data.len());
    }

    #[test]
    fn test_parse_minimal_r1cs() {
        let data = minimal_r1cs();
        let r1cs = R1CS::parse(&data).expect("minimal R1CS should parse");

        assert_eq!(r1cs.num_wires, 3);
        assert_eq!(r1cs.num_pub_out, 1);
        assert_eq!(r1cs.num_pub_in, 1);
        assert_eq!(r1cs.num_prv_in, 0);
        assert_eq!(r1cs.num_public, 2);
        assert_eq!(r1cs.num_constraints(), 1);

        let constraint = &r1cs.constraints[0];
        assert_eq!(constraint.a.terms.len(), 1);
        assert_eq!(constraint.a.terms[0].wire_id, 1);
        assert_eq!(constraint.a.terms[0].coefficient, Fr::from(1u64));
        assert_eq!(constraint.b.terms.len(), 1);
        assert_eq!(constraint.b.terms[0].wire_id, 2);
        assert_eq!(constraint.b.terms[0].coefficient, Fr::from(5u64));
        assert!(constraint.c.terms.is_empty());
    }
}