witness/
lib.rs

//! Witness Generation WASM Module
//!
//! Uses ark-circom to compute witnesses for Circom circuits in the browser.
//! Outputs witness bytes (32-byte little-endian field elements) compatible
//! with the prover module.
5
6use ark_bn254::Fr;
7use ark_circom::{WitnessCalculator as ArkWitnessCalculator, circom::R1CSFile};
8use num_bigint::{BigInt, Sign};
9// These are part of the reduced STD that is browser compatible
10use std::{collections::HashMap, io::Cursor, string::String, vec::Vec};
11use wasm_bindgen::prelude::*;
12use wasmer::{Module, Store};
13
/// Modulus `p` of the BN254 (alt_bn128) scalar field, written in decimal.
///
/// Circuit inputs are mapped into the canonical range `[0, p)` against this
/// value before witness generation.
const BN254_FIELD_MODULUS: &str =
    "21888242871839275222246405745257275088548364400416034343698204186575808495617";
17
18/// Initialize the WASM module
19#[wasm_bindgen(start)]
20pub fn init() {
21    console_error_panic_hook::set_once();
22}
23
24/// Get module version
25#[wasm_bindgen]
26pub fn version() -> String {
27    String::from(env!("CARGO_PKG_VERSION"))
28}
29
30/// Witness calculator instance
31#[wasm_bindgen]
32pub struct WitnessCalculator {
33    /// Wasmer store for the circuit WASM instance
34    store: Store,
35    /// Internal ark-circom witness calculator
36    calculator: ArkWitnessCalculator,
37    /// Number of variables in the witness
38    witness_size: u32,
39    /// Number of public inputs (does not include public outputs or the constant
40    /// signal 1)
41    num_public_inputs: u32,
42}
43
44#[wasm_bindgen]
45impl WitnessCalculator {
46    /// Create a new WitnessCalculator from circuit WASM and R1CS bytes
47    ///
48    /// # Arguments
49    /// * `circuit_wasm` - The compiled circuit WASM bytes
50    /// * `r1cs_bytes` - The R1CS constraint system bytes
51    #[wasm_bindgen(constructor)]
52    pub fn new(circuit_wasm: &[u8], r1cs_bytes: &[u8]) -> Result<WitnessCalculator, JsValue> {
53        // Parse R1CS from bytes
54        let cursor = Cursor::new(r1cs_bytes);
55        let r1cs_file: R1CSFile<Fr> = R1CSFile::new(cursor)
56            .map_err(|e| JsValue::from_str(&format!("Failed to parse R1CS: {}", e)))?;
57
58        let witness_size = r1cs_file.header.n_wires;
59        let num_public_inputs = r1cs_file.header.n_pub_in;
60
61        // Create wasmer store and load circuit module from bytes
62        let mut store = Store::default();
63        let module = Module::new(&store, circuit_wasm)
64            .map_err(|e| JsValue::from_str(&format!("Failed to load circuit WASM: {}", e)))?;
65
66        // Create witness calculator from module
67        let calculator = ArkWitnessCalculator::from_module(&mut store, module)
68            .map_err(|e| JsValue::from_str(&format!("Failed to init witness calc: {}", e)))?;
69
70        Ok(WitnessCalculator {
71            store,
72            calculator,
73            witness_size,
74            num_public_inputs,
75        })
76    }
77
78    /// Compute witness from JSON inputs
79    ///
80    /// # Arguments
81    /// * `inputs_json` - JSON string with circuit inputs
82    ///
83    /// # Returns
84    /// * Witness as Little-Endian bytes (32 bytes per field element)
85    #[wasm_bindgen]
86    pub fn compute_witness(&mut self, inputs_json: &str) -> Result<Vec<u8>, JsValue> {
87        use serde_json::Value;
88
89        // Parse JSON inputs
90        let inputs: Value = serde_json::from_str(inputs_json)
91            .map_err(|e| JsValue::from_str(&format!("Invalid JSON: {}", e)))?;
92
93        let inputs_map = inputs
94            .as_object()
95            .ok_or_else(|| JsValue::from_str("Inputs must be a JSON object"))?;
96
97        // Convert to HashMap<String, Vec<BigInt>> by flattening nested structures
98        let mut inputs_hashmap: HashMap<String, Vec<BigInt>> = HashMap::new();
99
100        for (key, value) in inputs_map {
101            flatten_input(key, value, &mut inputs_hashmap)?;
102        }
103
104        // Calculate witness
105        let witness = self
106            .calculator
107            .calculate_witness(&mut self.store, inputs_hashmap, false)
108            .map_err(|e| JsValue::from_str(&format!("Witness calculation failed: {}", e)))?;
109
110        // Convert to Little-Endian bytes
111        Ok(witness_to_bytes(&witness))
112    }
113
114    /// Get the witness size (number of field elements)
115    #[wasm_bindgen(getter)]
116    pub fn witness_size(&self) -> u32 {
117        self.witness_size
118    }
119
120    /// Get the number of public inputs
121    #[wasm_bindgen(getter)]
122    pub fn num_public_inputs(&self) -> u32 {
123        self.num_public_inputs
124    }
125}
126
127/// Convert a BigInt to its field element representation.
128/// Negative numbers are converted to p - |value| where p is the field modulus.
129/// Relevant for ZK proof computation. For on-chain token transfer
130/// we use a I256 passed to the contract.
131fn to_field_element(bi: BigInt) -> BigInt {
132    let modulus =
133        BigInt::parse_bytes(BN254_FIELD_MODULUS.as_bytes(), 10).expect("Invalid field modulus");
134
135    if bi.sign() == Sign::Minus {
136        let abs_value = bi
137            .checked_mul(&BigInt::from(-1))
138            .expect("Overflow in getting the abs value"); // Get absolute value
139
140        // Check absolute value must be less than the field modulus
141        assert!(
142            abs_value < modulus,
143            "Negative value {} exceeds field modulus",
144            bi
145        );
146
147        // For negative n: field_element = p - |n|
148        modulus
149            .checked_sub(&abs_value)
150            .expect("Overflow in field element computation")
151    } else {
152        // Validate: positive value must be less than the field modulus
153        assert!(bi < modulus, "Value {} exceeds field modulus", bi);
154        bi
155    }
156}
157
158/// Check if a JSON value is an array containing only primitives.
159fn is_pure_array(value: &serde_json::Value) -> bool {
160    use serde_json::Value;
161
162    let mut stack: Vec<&Value> = vec![value];
163
164    while let Some(current) = stack.pop() {
165        match current {
166            Value::Number(_) | Value::String(_) | Value::Bool(_) | Value::Null => {}
167            Value::Array(arr) => {
168                for item in arr {
169                    stack.push(item);
170                }
171            }
172            Value::Object(_) => return false,
173        }
174    }
175    true
176}
177
178/// Flatten a JSON value into the inputs hashmap.
179///
180/// For Circom circuits:
181/// - Multi-dimensional arrays of primitives are flattened to a single key in
182///   row-major order
183/// - Arrays containing objects use indexed keys with dot notation for fields
184fn flatten_input(
185    key: &str,
186    value: &serde_json::Value,
187    inputs: &mut HashMap<String, Vec<BigInt>>,
188) -> Result<(), JsValue> {
189    use serde_json::Value;
190
191    // (key, value) pairs to iterate over.
192    let mut stack: Vec<(String, &Value)> = vec![(key.to_string(), value)];
193
194    while let Some((current_key, current_value)) = stack.pop() {
195        match current_value {
196            Value::Number(n) => {
197                let bi = if let Some(i) = n.as_u64() {
198                    BigInt::from(i)
199                } else if let Some(i) = n.as_i64() {
200                    BigInt::from(i)
201                } else {
202                    return Err(JsValue::from_str(&format!(
203                        "Invalid number for {}",
204                        current_key
205                    )));
206                };
207                // Convert to field element (handles negative numbers)
208                inputs
209                    .entry(current_key)
210                    .or_default()
211                    .push(to_field_element(bi));
212            }
213            Value::String(s) => {
214                let bi = if let Some(hex) = s.strip_prefix("0x") {
215                    BigInt::parse_bytes(hex.as_bytes(), 16)
216                } else {
217                    BigInt::parse_bytes(s.as_bytes(), 10)
218                };
219                let bi = bi.ok_or_else(|| {
220                    JsValue::from_str(&format!("Invalid bigint for {}: {}", current_key, s))
221                })?;
222                // Convert to field element (handles negative numbers)
223                inputs
224                    .entry(current_key)
225                    .or_default()
226                    .push(to_field_element(bi));
227            }
228            Value::Array(arr) => {
229                // Pure arrays get flattened to a single key as in
230                if is_pure_array(current_value) {
231                    flatten_pure_array(&current_key, current_value, inputs)?;
232                } else {
233                    //  If the array contains objects, we push indexed items in reverse order
234                    // to maintain the original order when popping
235                    for (idx, item) in arr.iter().enumerate().rev() {
236                        let indexed_key = format!("{}[{}]", current_key, idx);
237                        stack.push((indexed_key, item));
238                    }
239                }
240            }
241            Value::Object(obj) => {
242                // Push object fields
243                for (field, val) in obj {
244                    let nested_key = format!("{}.{}", current_key, field);
245                    stack.push((nested_key, val));
246                }
247            }
248            Value::Bool(b) => {
249                let bi = if *b { BigInt::from(1) } else { BigInt::from(0) };
250                inputs.entry(current_key).or_default().push(bi);
251            }
252            Value::Null => {
253                inputs.entry(current_key).or_default().push(BigInt::from(0));
254            }
255        }
256    }
257    Ok(())
258}
259
260/// Flatten a pure array to a single key in row-major order.
261fn flatten_pure_array(
262    key: &str,
263    value: &serde_json::Value,
264    inputs: &mut HashMap<String, Vec<BigInt>>,
265) -> Result<(), JsValue> {
266    use serde_json::Value;
267
268    // We use indices to maintain row-major order:
269    // each item is (array_ref, next_index_to_process).
270    // For non-array values, we process them immediately.
271    enum WorkItem<'a> {
272        Value(&'a Value),
273        ArrayIter { arr: &'a [Value], idx: usize },
274    }
275
276    let mut stack: Vec<WorkItem<'_>> = vec![WorkItem::Value(value)];
277
278    while let Some(item) = stack.pop() {
279        match item {
280            WorkItem::Value(v) => match v {
281                Value::Number(n) => {
282                    let bi = if let Some(i) = n.as_u64() {
283                        BigInt::from(i)
284                    } else if let Some(i) = n.as_i64() {
285                        BigInt::from(i)
286                    } else {
287                        return Err(JsValue::from_str(&format!("Invalid number for {}", key)));
288                    };
289                    inputs
290                        .entry(key.to_string())
291                        .or_default()
292                        .push(to_field_element(bi));
293                }
294                Value::String(s) => {
295                    let bi = if let Some(hex) = s.strip_prefix("0x") {
296                        BigInt::parse_bytes(hex.as_bytes(), 16)
297                    } else {
298                        BigInt::parse_bytes(s.as_bytes(), 10)
299                    };
300                    let bi = bi.ok_or_else(|| {
301                        JsValue::from_str(&format!("Invalid bigint for {}: {}", key, s))
302                    })?;
303                    inputs
304                        .entry(key.to_string())
305                        .or_default()
306                        .push(to_field_element(bi));
307                }
308                Value::Array(arr) => {
309                    if !arr.is_empty() {
310                        stack.push(WorkItem::ArrayIter { arr, idx: 0 });
311                    }
312                }
313                Value::Bool(b) => {
314                    let bi = if *b { BigInt::from(1) } else { BigInt::from(0) };
315                    inputs.entry(key.to_string()).or_default().push(bi);
316                }
317                Value::Null => {
318                    inputs
319                        .entry(key.to_string())
320                        .or_default()
321                        .push(BigInt::from(0));
322                }
323                Value::Object(_) => {
324                    return Err(JsValue::from_str(&format!(
325                        "Unexpected object in pure array: {}",
326                        key
327                    )));
328                }
329            },
330            WorkItem::ArrayIter { arr, idx } => {
331                // Push continuation for remaining elements first
332                let next_idx = idx.saturating_add(1);
333                if next_idx < arr.len() {
334                    stack.push(WorkItem::ArrayIter { arr, idx: next_idx });
335                }
336                // Then push current element
337                stack.push(WorkItem::Value(&arr[idx]));
338            }
339        }
340    }
341    Ok(())
342}
343
344/// Convert witness to Little-Endian bytes (32 bytes per element)
345fn witness_to_bytes(witness: &[BigInt]) -> Vec<u8> {
346    let mut bytes = Vec::with_capacity(
347        witness
348            .len()
349            .checked_mul(32)
350            .expect("Overflow in witness size"),
351    );
352
353    for bi in witness {
354        // Convert BigInt to 32 LE bytes
355        let (sign, be_bytes) = bi.to_bytes_be();
356
357        // Check it fits in 32 bytes
358        assert!(
359            be_bytes.len() <= 32,
360            "Field element exceeds 32 bytes in witness"
361        );
362
363        // Negative numbers should not occur in witness output since inputs
364        // are converted to field elements. Assert this invariant.
365        assert!(
366            sign != Sign::Minus,
367            "Negative number in witness output - inputs should be field elements"
368        );
369
370        // Pad to 32 bytes (big-endian)
371        let mut padded = vec![0u8; 32];
372        let offset = 32usize.saturating_sub(be_bytes.len());
373        padded[offset..].copy_from_slice(&be_bytes[..be_bytes.len().min(32)]);
374
375        // Convert to little-endian
376        padded.reverse();
377        bytes.extend_from_slice(&padded);
378    }
379
380    bytes
381}