1use ark_bn254::Fr;
7use ark_circom::{WitnessCalculator as ArkWitnessCalculator, circom::R1CSFile};
8use num_bigint::{BigInt, Sign};
9use std::{collections::HashMap, io::Cursor, string::String, vec::Vec};
11use wasm_bindgen::prelude::*;
12use wasmer::{Module, Store};
13
/// Decimal representation of the BN254 scalar-field modulus.
///
/// Kept as a string and parsed into a `BigInt` (base 10) where input values
/// are range-checked and negative values are wrapped into the field.
const BN254_FIELD_MODULUS: &str =
    "21888242871839275222246405745257275088548364400416034343698204186575808495617";
17
/// Module start function (registered via `#[wasm_bindgen(start)]`).
///
/// Installs `console_error_panic_hook` so Rust panics are reported through
/// the hook rather than surfacing only as an opaque WASM trap.
#[wasm_bindgen(start)]
pub fn init() {
    // `set_once` installs the hook a single time, however often `init` runs.
    console_error_panic_hook::set_once();
}
23
24#[wasm_bindgen]
26pub fn version() -> String {
27 String::from(env!("CARGO_PKG_VERSION"))
28}
29
/// Witness calculator exported to JavaScript through `wasm_bindgen`.
///
/// Bundles the wasmer store, the ark-circom calculator bound to it, and two
/// counters copied from the R1CS header at construction time.
#[wasm_bindgen]
pub struct WitnessCalculator {
    /// Wasmer store the circuit module lives in; handed to every calculator call.
    store: Store,
    /// ark-circom witness calculator created from the circuit WASM module.
    calculator: ArkWitnessCalculator,
    /// `n_wires` from the R1CS header.
    witness_size: u32,
    /// `n_pub_in` from the R1CS header.
    num_public_inputs: u32,
}
43
44#[wasm_bindgen]
45impl WitnessCalculator {
46 #[wasm_bindgen(constructor)]
52 pub fn new(circuit_wasm: &[u8], r1cs_bytes: &[u8]) -> Result<WitnessCalculator, JsValue> {
53 let cursor = Cursor::new(r1cs_bytes);
55 let r1cs_file: R1CSFile<Fr> = R1CSFile::new(cursor)
56 .map_err(|e| JsValue::from_str(&format!("Failed to parse R1CS: {}", e)))?;
57
58 let witness_size = r1cs_file.header.n_wires;
59 let num_public_inputs = r1cs_file.header.n_pub_in;
60
61 let mut store = Store::default();
63 let module = Module::new(&store, circuit_wasm)
64 .map_err(|e| JsValue::from_str(&format!("Failed to load circuit WASM: {}", e)))?;
65
66 let calculator = ArkWitnessCalculator::from_module(&mut store, module)
68 .map_err(|e| JsValue::from_str(&format!("Failed to init witness calc: {}", e)))?;
69
70 Ok(WitnessCalculator {
71 store,
72 calculator,
73 witness_size,
74 num_public_inputs,
75 })
76 }
77
78 #[wasm_bindgen]
86 pub fn compute_witness(&mut self, inputs_json: &str) -> Result<Vec<u8>, JsValue> {
87 use serde_json::Value;
88
89 let inputs: Value = serde_json::from_str(inputs_json)
91 .map_err(|e| JsValue::from_str(&format!("Invalid JSON: {}", e)))?;
92
93 let inputs_map = inputs
94 .as_object()
95 .ok_or_else(|| JsValue::from_str("Inputs must be a JSON object"))?;
96
97 let mut inputs_hashmap: HashMap<String, Vec<BigInt>> = HashMap::new();
99
100 for (key, value) in inputs_map {
101 flatten_input(key, value, &mut inputs_hashmap)?;
102 }
103
104 let witness = self
106 .calculator
107 .calculate_witness(&mut self.store, inputs_hashmap, false)
108 .map_err(|e| JsValue::from_str(&format!("Witness calculation failed: {}", e)))?;
109
110 Ok(witness_to_bytes(&witness))
112 }
113
114 #[wasm_bindgen(getter)]
116 pub fn witness_size(&self) -> u32 {
117 self.witness_size
118 }
119
120 #[wasm_bindgen(getter)]
122 pub fn num_public_inputs(&self) -> u32 {
123 self.num_public_inputs
124 }
125}
126
127fn to_field_element(bi: BigInt) -> BigInt {
132 let modulus =
133 BigInt::parse_bytes(BN254_FIELD_MODULUS.as_bytes(), 10).expect("Invalid field modulus");
134
135 if bi.sign() == Sign::Minus {
136 let abs_value = bi
137 .checked_mul(&BigInt::from(-1))
138 .expect("Overflow in getting the abs value"); assert!(
142 abs_value < modulus,
143 "Negative value {} exceeds field modulus",
144 bi
145 );
146
147 modulus
149 .checked_sub(&abs_value)
150 .expect("Overflow in field element computation")
151 } else {
152 assert!(bi < modulus, "Value {} exceeds field modulus", bi);
154 bi
155 }
156}
157
158fn is_pure_array(value: &serde_json::Value) -> bool {
160 use serde_json::Value;
161
162 let mut stack: Vec<&Value> = vec![value];
163
164 while let Some(current) = stack.pop() {
165 match current {
166 Value::Number(_) | Value::String(_) | Value::Bool(_) | Value::Null => {}
167 Value::Array(arr) => {
168 for item in arr {
169 stack.push(item);
170 }
171 }
172 Value::Object(_) => return false,
173 }
174 }
175 true
176}
177
178fn flatten_input(
185 key: &str,
186 value: &serde_json::Value,
187 inputs: &mut HashMap<String, Vec<BigInt>>,
188) -> Result<(), JsValue> {
189 use serde_json::Value;
190
191 let mut stack: Vec<(String, &Value)> = vec![(key.to_string(), value)];
193
194 while let Some((current_key, current_value)) = stack.pop() {
195 match current_value {
196 Value::Number(n) => {
197 let bi = if let Some(i) = n.as_u64() {
198 BigInt::from(i)
199 } else if let Some(i) = n.as_i64() {
200 BigInt::from(i)
201 } else {
202 return Err(JsValue::from_str(&format!(
203 "Invalid number for {}",
204 current_key
205 )));
206 };
207 inputs
209 .entry(current_key)
210 .or_default()
211 .push(to_field_element(bi));
212 }
213 Value::String(s) => {
214 let bi = if let Some(hex) = s.strip_prefix("0x") {
215 BigInt::parse_bytes(hex.as_bytes(), 16)
216 } else {
217 BigInt::parse_bytes(s.as_bytes(), 10)
218 };
219 let bi = bi.ok_or_else(|| {
220 JsValue::from_str(&format!("Invalid bigint for {}: {}", current_key, s))
221 })?;
222 inputs
224 .entry(current_key)
225 .or_default()
226 .push(to_field_element(bi));
227 }
228 Value::Array(arr) => {
229 if is_pure_array(current_value) {
231 flatten_pure_array(¤t_key, current_value, inputs)?;
232 } else {
233 for (idx, item) in arr.iter().enumerate().rev() {
236 let indexed_key = format!("{}[{}]", current_key, idx);
237 stack.push((indexed_key, item));
238 }
239 }
240 }
241 Value::Object(obj) => {
242 for (field, val) in obj {
244 let nested_key = format!("{}.{}", current_key, field);
245 stack.push((nested_key, val));
246 }
247 }
248 Value::Bool(b) => {
249 let bi = if *b { BigInt::from(1) } else { BigInt::from(0) };
250 inputs.entry(current_key).or_default().push(bi);
251 }
252 Value::Null => {
253 inputs.entry(current_key).or_default().push(BigInt::from(0));
254 }
255 }
256 }
257 Ok(())
258}
259
260fn flatten_pure_array(
262 key: &str,
263 value: &serde_json::Value,
264 inputs: &mut HashMap<String, Vec<BigInt>>,
265) -> Result<(), JsValue> {
266 use serde_json::Value;
267
268 enum WorkItem<'a> {
272 Value(&'a Value),
273 ArrayIter { arr: &'a [Value], idx: usize },
274 }
275
276 let mut stack: Vec<WorkItem<'_>> = vec![WorkItem::Value(value)];
277
278 while let Some(item) = stack.pop() {
279 match item {
280 WorkItem::Value(v) => match v {
281 Value::Number(n) => {
282 let bi = if let Some(i) = n.as_u64() {
283 BigInt::from(i)
284 } else if let Some(i) = n.as_i64() {
285 BigInt::from(i)
286 } else {
287 return Err(JsValue::from_str(&format!("Invalid number for {}", key)));
288 };
289 inputs
290 .entry(key.to_string())
291 .or_default()
292 .push(to_field_element(bi));
293 }
294 Value::String(s) => {
295 let bi = if let Some(hex) = s.strip_prefix("0x") {
296 BigInt::parse_bytes(hex.as_bytes(), 16)
297 } else {
298 BigInt::parse_bytes(s.as_bytes(), 10)
299 };
300 let bi = bi.ok_or_else(|| {
301 JsValue::from_str(&format!("Invalid bigint for {}: {}", key, s))
302 })?;
303 inputs
304 .entry(key.to_string())
305 .or_default()
306 .push(to_field_element(bi));
307 }
308 Value::Array(arr) => {
309 if !arr.is_empty() {
310 stack.push(WorkItem::ArrayIter { arr, idx: 0 });
311 }
312 }
313 Value::Bool(b) => {
314 let bi = if *b { BigInt::from(1) } else { BigInt::from(0) };
315 inputs.entry(key.to_string()).or_default().push(bi);
316 }
317 Value::Null => {
318 inputs
319 .entry(key.to_string())
320 .or_default()
321 .push(BigInt::from(0));
322 }
323 Value::Object(_) => {
324 return Err(JsValue::from_str(&format!(
325 "Unexpected object in pure array: {}",
326 key
327 )));
328 }
329 },
330 WorkItem::ArrayIter { arr, idx } => {
331 let next_idx = idx.saturating_add(1);
333 if next_idx < arr.len() {
334 stack.push(WorkItem::ArrayIter { arr, idx: next_idx });
335 }
336 stack.push(WorkItem::Value(&arr[idx]));
338 }
339 }
340 }
341 Ok(())
342}
343
344fn witness_to_bytes(witness: &[BigInt]) -> Vec<u8> {
346 let mut bytes = Vec::with_capacity(
347 witness
348 .len()
349 .checked_mul(32)
350 .expect("Overflow in witness size"),
351 );
352
353 for bi in witness {
354 let (sign, be_bytes) = bi.to_bytes_be();
356
357 assert!(
359 be_bytes.len() <= 32,
360 "Field element exceeds 32 bytes in witness"
361 );
362
363 assert!(
366 sign != Sign::Minus,
367 "Negative number in witness output - inputs should be field elements"
368 );
369
370 let mut padded = vec![0u8; 32];
372 let offset = 32usize.saturating_sub(be_bytes.len());
373 padded[offset..].copy_from_slice(&be_bytes[..be_bytes.len().min(32)]);
374
375 padded.reverse();
377 bytes.extend_from_slice(&padded);
378 }
379
380 bytes
381}