1use super::{
2 circom_tester::prove_and_verify,
3 general::scalar_to_bigint,
4 keypair::{derive_public_key, sign},
5 merkle_tree::{merkle_proof, merkle_root},
6 transaction::{commitment, nullifier},
7};
8use crate::test::utils::circom_tester::Inputs;
9use anyhow::{Result, ensure};
10use num_bigint::BigInt;
11use std::{
12 panic::{self, AssertUnwindSafe},
13 path::PathBuf,
14};
15use zkhash::fields::bn256::FpBN256 as Scalar;
16
/// A note being spent by a transaction test case.
///
/// Its commitment is computed from `amount`, the public key derived from `priv_key`, and
/// `blinding`, then written into the Merkle leaf layer at `leaf_index`.
#[derive(Clone, Debug)]
pub struct InputNote {
    /// Position in the leaf vector where this note's commitment is placed.
    pub leaf_index: usize,
    /// Private key: the note's public key is derived from it, and it signs the nullifier input.
    pub priv_key: Scalar,
    /// Blinding factor mixed into the note's commitment.
    pub blinding: Scalar,
    /// Value committed to by the note.
    pub amount: Scalar,
}
26
/// A note created by a transaction test case.
///
/// Its commitment is computed directly from `amount`, `pub_key` and `blinding`.
#[derive(Clone, Debug)]
pub struct OutputNote {
    /// Public key the output commitment is bound to.
    pub pub_key: Scalar,
    /// Blinding factor mixed into the output commitment.
    pub blinding: Scalar,
    /// Value committed to by the output note.
    pub amount: Scalar,
}
34
35#[derive(Clone, Debug)]
36pub struct TxCase {
40 pub inputs: Vec<InputNote>,
41 pub outputs: Vec<OutputNote>,
42}
43
44impl TxCase {
45 pub fn new(inputs: Vec<InputNote>, outputs: Vec<OutputNote>) -> Self {
46 Self { inputs, outputs }
47 }
48}
49
50pub struct TransactionWitness {
51 pub root: Scalar,
52 pub public_keys: Vec<Scalar>,
53 pub nullifiers: Vec<Scalar>,
54 pub path_indices: Vec<Scalar>,
55 pub path_elements_flat: Vec<BigInt>,
56}
57
58pub fn prepare_transaction_witness(
76 case: &TxCase,
77 mut leaves: Vec<Scalar>,
78 expected_levels: usize,
79) -> Result<TransactionWitness> {
80 let mut commitments = Vec::with_capacity(case.inputs.len());
81 let mut public_keys = Vec::with_capacity(case.inputs.len());
82
83 for note in &case.inputs {
84 let pk = derive_public_key(note.priv_key);
85 let cm = commitment(note.amount, pk, note.blinding);
86 public_keys.push(pk);
87 commitments.push(cm);
88 leaves[note.leaf_index] = cm;
89 }
90
91 let root = merkle_root(leaves.clone());
92 let mut path_indices = Vec::with_capacity(case.inputs.len());
93 let mut path_elements_flat =
94 Vec::with_capacity(expected_levels.saturating_mul(case.inputs.len()));
95 let mut nullifiers = Vec::with_capacity(case.inputs.len());
96
97 for (i, note) in case.inputs.iter().enumerate() {
98 let (siblings, path_idx_u64, depth) = merkle_proof(&leaves, note.leaf_index);
99 ensure!(
100 depth == expected_levels,
101 "unexpected depth for input {i}, expected {expected_levels}, got {depth}"
102 );
103
104 path_elements_flat.extend(siblings.into_iter().map(scalar_to_bigint));
106
107 let path_idx = Scalar::from(path_idx_u64);
108 path_indices.push(path_idx);
109
110 let sig = sign(note.priv_key, commitments[i], path_idx);
111 let nul = nullifier(commitments[i], path_idx, sig);
112 nullifiers.push(nul);
113 }
114
115 Ok(TransactionWitness {
116 root,
117 public_keys,
118 nullifiers,
119 path_indices,
120 path_elements_flat,
121 })
122}
123
124pub fn build_base_inputs(
140 case: &TxCase,
141 witness: &TransactionWitness,
142 public_amount: Scalar,
143) -> Inputs {
144 let mut inputs = Inputs::new();
145
146 inputs.set("root", scalar_to_bigint(witness.root));
147 inputs.set("publicAmount", scalar_to_bigint(public_amount));
148 inputs.set("extDataHash", BigInt::from(0u32));
149
150 inputs.set("inputNullifier", witness.nullifiers.clone());
151 inputs.set(
152 "inAmount",
153 case.inputs
154 .iter()
155 .map(|n| n.amount)
156 .collect::<Vec<Scalar>>(),
157 );
158 inputs.set(
159 "inPrivateKey",
160 case.inputs
161 .iter()
162 .map(|n| n.priv_key)
163 .collect::<Vec<Scalar>>(),
164 );
165 inputs.set(
166 "inBlinding",
167 case.inputs
168 .iter()
169 .map(|n| n.blinding)
170 .collect::<Vec<Scalar>>(),
171 );
172 inputs.set("inPathIndices", witness.path_indices.clone());
173 inputs.set("inPathElements", witness.path_elements_flat.clone());
174
175 let output_commitments: Vec<BigInt> = case
176 .outputs
177 .iter()
178 .map(|out| scalar_to_bigint(commitment(out.amount, out.pub_key, out.blinding)))
179 .collect();
180 inputs.set("outputCommitment", output_commitments);
181
182 inputs.set(
183 "outAmount",
184 case.outputs
185 .iter()
186 .map(|n| n.amount)
187 .collect::<Vec<Scalar>>(),
188 );
189 inputs.set(
190 "outPubkey",
191 case.outputs
192 .iter()
193 .map(|n| n.pub_key)
194 .collect::<Vec<Scalar>>(),
195 );
196 inputs.set(
197 "outBlinding",
198 case.outputs
199 .iter()
200 .map(|n| n.blinding)
201 .collect::<Vec<Scalar>>(),
202 );
203
204 inputs
205}
206
207pub fn prove_transaction_case(
226 wasm: &PathBuf,
227 r1cs: &PathBuf,
228 case: &TxCase,
229 leaves: Vec<Scalar>,
230 public_amount: Scalar,
231 expected_levels: usize,
232) -> Result<()> {
233 let witness = prepare_transaction_witness(case, leaves, expected_levels)?;
234 let inputs = build_base_inputs(case, &witness, public_amount);
235
236 let prove_result =
237 panic::catch_unwind(AssertUnwindSafe(|| prove_and_verify(wasm, r1cs, &inputs)));
238
239 match prove_result {
240 Ok(Ok(res)) if res.verified => Ok(()),
241 Ok(Ok(_)) => Err(anyhow::anyhow!(
242 "Proof failed to verify (res.verified=false)"
243 )),
244 Ok(Err(e)) => Err(anyhow::anyhow!("Prover error: {e:?}")),
245 Err(panic_info) => {
246 let msg = if let Some(s) = panic_info.downcast_ref::<&str>() {
248 s.to_string()
249 } else if let Some(s) = panic_info.downcast_ref::<String>() {
250 s.clone()
251 } else {
252 "Unknown panic".to_string()
253 };
254 Err(anyhow::anyhow!(
255 "Prover panicked (expected on invalid proof): {msg}"
256 ))
257 }
258 }
259}