// circuits/test/utils/sparse_merkle_tree.rs

1//! Sparse Merkle Tree implementation compatible with circomlibjs/smt.js
2//!
3//! This is a Rust port of the Sparse Merkle Tree implementation from:
4//! - JavaScript: https://github.com/iden3/circomlibjs/blob/main/src/smt.js
5//!
6//! This implementation uses Poseidon2 hash function for compatibility with
7//! circomlib circuits.
8use crate::test::utils::general::{
9    poseidon2_compression as poseidon2_compression_bn256, poseidon2_hash2 as poseidon2_hash2_bn256,
10};
11use anyhow::{Result, anyhow};
12use num_bigint::{BigInt, BigUint};
13use num_integer::Integer;
14use std::{collections::HashMap, ops::Shr};
15use zkhash::{
16    ark_ff::{BigInteger, PrimeField},
17    fields::bn256::FpBN256,
18};
19/// Reduce a num_bigint::BigInt modulo the BN256 field modulus and convert to
20/// FpBN256. Circom circuits operate inside the BN256 scalar field, so every
21/// BigInt we hash must be reduced.
22fn big_int_to_fp(x: &BigInt) -> FpBN256 {
23    // Get the field modulus as a num_bigint::BigInt
24    let modulus_bytes = FpBN256::MODULUS.to_bytes_be();
25    let modulus_bigint = BigInt::from_bytes_be(num_bigint::Sign::Plus, &modulus_bytes);
26
27    // Floor-mod reduce into [0, modulus)
28    let reduced = x.mod_floor(&modulus_bigint);
29
30    // Convert non-negative BigInt to BigUint, then into FpBN256
31    let (_sign, bytes) = reduced.to_bytes_be();
32    let as_biguint = BigUint::from_bytes_be(&bytes);
33
34    FpBN256::from(as_biguint)
35}
36
37/// Poseidon2 hash of two field elements using optimized compression mode
38///
39/// Hash function for BigInt values, used for inner nodes of the sparse Merkle
40/// tree. Converts BigInt inputs to field elements, performs Poseidon2
41/// compression, and converts the result back to BigInt.
42///
43/// # Arguments
44///
45/// * `left` - Left input as BigInt
46/// * `right` - Right input as BigInt
47///
48/// # Returns
49///
50/// Returns the hash result as a BigInt value.
51pub fn poseidon2_compression_sparse(left: &BigInt, right: &BigInt) -> BigInt {
52    let left_fp = big_int_to_fp(left);
53    let right_fp = big_int_to_fp(right);
54
55    let perm = poseidon2_compression_bn256(left_fp, right_fp);
56
57    fp_bn256_to_big_int(&perm)
58}
59
60/// Poseidon2 hash function for leaf nodes (key, value, 1)
61///
62/// Computes the hash for leaf nodes using Poseidon2 with three inputs.
63/// Mirrors circomlibjs "hash1" function so roots generated here match
64/// the JavaScript prover and test tooling.
65///
66/// # Arguments
67///
68/// * `key` - Leaf key as BigInt
69/// * `value` - Leaf value as BigInt
70///
71/// # Returns
72///
73/// Returns the leaf hash as a BigInt value.
74pub fn poseidon2_hash3_sparse(key: &BigInt, value: &BigInt) -> BigInt {
75    let key_fp = big_int_to_fp(key);
76    let value_fp = big_int_to_fp(value);
77    let one_fp = FpBN256::from(1u64);
78
79    let result = poseidon2_hash2_bn256(key_fp, value_fp, Some(one_fp));
80
81    fp_bn256_to_big_int(&result)
82}
83
84/// Convert FpBN256 to BigInt
85fn fp_bn256_to_big_int(fp: &FpBN256) -> BigInt {
86    let bytes = fp.into_bigint().to_bytes_be();
87    BigInt::from_bytes_be(num_bigint::Sign::Plus, &bytes)
88}
89
90/// Database trait for SMT storage
91pub trait SMTDatabase {
92    /// Get a value from the database
93    fn get(&self, key: &BigInt) -> Option<Vec<BigInt>>;
94    /// Set a value in the database
95    fn set(&mut self, key: BigInt, value: Vec<BigInt>);
96    /// Delete a value from the database
97    fn delete(&mut self, key: &BigInt);
98    /// Get the current root
99    fn get_root(&self) -> BigInt;
100    /// Set the current root
101    fn set_root(&mut self, root: BigInt);
102    /// Insert multiple values
103    fn multi_ins(&mut self, inserts: Vec<(BigInt, Vec<BigInt>)>);
104    /// Delete multiple values
105    fn multi_del(&mut self, deletes: Vec<BigInt>);
106}
107
108/// In-memory database implementation
109/// Stores every node (leaves and internal nodes) as raw BigInt vectors,
110/// matching circomlibjs layout.
111pub struct SMTMemDB {
112    data: HashMap<BigInt, Vec<BigInt>>, // key -> [value, sibling1, sibling2, ...]
113    root: BigInt,
114}
115
116impl SMTMemDB {
117    /// Create a new in-memory database
118    pub fn new() -> Self {
119        Self {
120            data: HashMap::new(),
121            root: BigInt::from(0u32),
122        }
123    }
124}
125impl Default for SMTMemDB {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
131impl SMTDatabase for SMTMemDB {
132    fn get(&self, key: &BigInt) -> Option<Vec<BigInt>> {
133        self.data.get(key).cloned()
134    }
135
136    fn set(&mut self, key: BigInt, value: Vec<BigInt>) {
137        self.data.insert(key, value);
138    }
139
140    fn delete(&mut self, key: &BigInt) {
141        self.data.remove(key);
142    }
143
144    fn get_root(&self) -> BigInt {
145        self.root.clone()
146    }
147
148    fn set_root(&mut self, root: BigInt) {
149        self.root = root;
150    }
151
152    fn multi_ins(&mut self, inserts: Vec<(BigInt, Vec<BigInt>)>) {
153        for (key, value) in inserts {
154            self.data.insert(key, value);
155        }
156    }
157
158    fn multi_del(&mut self, deletes: Vec<BigInt>) {
159        for key in deletes {
160            self.data.remove(&key);
161        }
162    }
163}
164
165/// Sparse Merkle Tree implementation matching circomlibjs/smt.js
166/// Provides insert/update/delete/find helpers that operate entirely over
167/// BigInts so test harnesses can generate witnesses identical to the JavaScript
168/// reference implementation.
169pub struct SparseMerkleTree<DB: SMTDatabase> {
170    db: DB,
171    root: BigInt,
172}
173
174/// Result of SMT operations
175#[derive(Debug, Clone)]
176pub struct SMTResult {
177    /// The old root before the operation
178    pub old_root: BigInt,
179    /// The new root after the operation
180    pub new_root: BigInt,
181    /// Sibling hashes along the path
182    pub siblings: Vec<BigInt>,
183    /// The old key
184    pub old_key: BigInt,
185    /// The old value
186    pub old_value: BigInt,
187    /// The new key
188    pub new_key: BigInt,
189    /// The new value
190    pub new_value: BigInt,
191    /// Whether the old value was zero
192    pub is_old0: bool,
193}
194
195/// Find result for internal operations
196#[derive(Debug, Clone)]
197pub struct FindResult {
198    /// Whether the key was found
199    pub found: bool,
200    /// Sibling hashes along the path
201    pub siblings: Vec<BigInt>,
202    /// The found value
203    pub found_value: BigInt,
204    /// The key that was not found (for collision detection)
205    pub not_found_key: BigInt,
206    /// The value that was not found
207    pub not_found_value: BigInt,
208    /// Whether the old value was zero
209    pub is_old0: bool,
210}
211
212impl<DB: SMTDatabase> SparseMerkleTree<DB> {
213    /// Create a new Sparse Merkle Tree
214    ///
215    /// # Arguments
216    ///
217    /// * `db` - Database implementation for storing tree nodes
218    /// * `root` - Initial root value (typically BigInt::from(0) for empty tree)
219    ///
220    /// # Returns
221    ///
222    /// Returns a new `SparseMerkleTree` instance.
223    pub fn new(db: DB, root: BigInt) -> Self {
224        Self { db, root }
225    }
226
227    /// Get the current root of the tree
228    ///
229    /// # Returns
230    ///
231    /// Returns a reference to the current root BigInt value.
232    pub fn root(&self) -> &BigInt {
233        &self.root
234    }
235
236    /// Split key into bits (256 bits total)
237    /// This should match the JavaScript implementation which uses Scalar.bits()
238    /// so we traverse identical paths for a given key.
239    fn split_bits(&self, key: &BigInt) -> Vec<bool> {
240        let mut bits = Vec::with_capacity(256);
241        let mut key = key.clone();
242
243        // Extract bits from LSB to MSB (same as JavaScript Scalar.bits())
244        for _ in 0..256 {
245            bits.push(key.bit(0));
246            key = key.shr(1u32);
247        }
248
249        bits
250    }
251
252    /// Update a key-value pair in the tree
253    ///
254    /// Recomputes all nodes along the path and persists them in the backing
255    /// database. This mirrors circomlibjs' update logic where we first
256    /// delete the old leaf and then rebuild the path with the new value
257    /// while tracking every intermediate node for witnesses.
258    ///
259    /// # Arguments
260    ///
261    /// * `key` - Key to update
262    /// * `new_value` - New value to associate with the key
263    ///
264    /// # Returns
265    ///
266    /// Returns `Ok(SMTResult)` containing the old and new roots, siblings, and
267    /// operation metadata, or an error if the key is not found or database
268    /// operations fail.
269    pub fn update(&mut self, key: &BigInt, new_value: &BigInt) -> Result<SMTResult> {
270        let res_find = self.find(key)?;
271        let mut res = SMTResult {
272            old_root: self.root.clone(),
273            new_root: BigInt::from(0u32),
274            siblings: res_find.siblings.clone(),
275            old_key: key.clone(),
276            old_value: res_find.found_value.clone(),
277            new_key: key.clone(),
278            new_value: new_value.clone(),
279            is_old0: res_find.is_old0,
280        };
281
282        let mut inserts = Vec::new();
283        let mut deletes = Vec::new();
284
285        let rt_old = poseidon2_hash3_sparse(key, &res_find.found_value);
286        let rt_new = poseidon2_hash3_sparse(key, new_value);
287        inserts.push((
288            rt_new.clone(),
289            vec![BigInt::from(1u32), key.clone(), new_value.clone()],
290        ));
291        deletes.push(rt_old.clone());
292
293        let key_bits = self.split_bits(key);
294        let mut current_rt_old = rt_old;
295        let mut current_rt_new = rt_new;
296
297        for level in (0..res_find.siblings.len()).rev() {
298            let sibling = &res_find.siblings[level];
299            // Rebuild nodes from the bottom up; depending on the bit we decide left/right
300            // order.
301            let (old_node, new_node) = if key_bits[level] {
302                (
303                    vec![sibling.clone(), current_rt_old.clone()],
304                    vec![sibling.clone(), current_rt_new.clone()],
305                )
306            } else {
307                (
308                    vec![current_rt_old.clone(), sibling.clone()],
309                    vec![current_rt_new.clone(), sibling.clone()],
310                )
311            };
312
313            current_rt_old = poseidon2_compression_sparse(&old_node[0], &old_node[1]);
314            current_rt_new = poseidon2_compression_sparse(&new_node[0], &new_node[1]);
315            deletes.push(current_rt_old.clone());
316            inserts.push((current_rt_new.clone(), new_node));
317        }
318
319        res.new_root = current_rt_new.clone();
320
321        self.db.multi_del(deletes);
322        self.db.multi_ins(inserts);
323        self.db.set_root(current_rt_new.clone());
324        self.root = current_rt_new;
325
326        Ok(res)
327    }
328
329    /// Delete a key from the tree
330    ///
331    /// Handles both sparse branches (single child) and mixed branches (two
332    /// populated children). The logic follows smt.js closely: collapse
333    /// empty branches while keeping collision nodes.
334    ///
335    /// # Arguments
336    ///
337    /// * `key` - Key to delete from the tree
338    ///
339    /// # Returns
340    ///
341    /// Returns `Ok(SMTResult)` containing the old and new roots, siblings, and
342    /// operation metadata, or an error if the key does not exist or
343    /// database operations fail.
344    pub fn delete(&mut self, key: &BigInt) -> Result<SMTResult> {
345        let res_find = self.find(key)?;
346        if !res_find.found {
347            return Err(anyhow!("Key does not exist"));
348        }
349
350        let mut res = SMTResult {
351            old_root: self.root.clone(),
352            new_root: BigInt::from(0u32),
353            siblings: Vec::new(),
354            old_key: key.clone(),
355            old_value: res_find.found_value.clone(),
356            new_key: key.clone(),
357            new_value: BigInt::from(0u32),
358            is_old0: false,
359        };
360
361        let mut deletes = Vec::new();
362        let mut inserts = Vec::new();
363        let mut rt_old = poseidon2_hash3_sparse(key, &res_find.found_value);
364        let mut rt_new;
365        deletes.push(rt_old.clone());
366
367        let key_bits = self.split_bits(key);
368        let mut mixed = false;
369
370        if let Some(last_sibling) = res_find.siblings.last() {
371            if let Some(record) = self.db.get(last_sibling) {
372                if record.len() == 3 && record[0] == BigInt::from(1u32) {
373                    mixed = false;
374                    res.old_key = record[1].clone();
375                    res.old_value = record[2].clone();
376                    res.is_old0 = false;
377                    rt_new = last_sibling.clone();
378                } else if record.len() == 2 {
379                    mixed = true;
380                    res.old_key = key.clone();
381                    res.old_value = BigInt::from(0u32);
382                    res.is_old0 = true;
383                    rt_new = BigInt::from(0u32);
384                } else {
385                    return Err(anyhow!("Invalid node. Database corrupted"));
386                }
387            } else {
388                return Err(anyhow!("Sibling not found"));
389            }
390        } else {
391            rt_new = BigInt::from(0u32);
392            res.old_key = key.clone();
393            res.is_old0 = true;
394        }
395
396        for level in (0..res_find.siblings.len()).rev() {
397            let mut new_sibling = res_find.siblings[level].clone();
398            if Some(level) == res_find.siblings.len().checked_sub(1) && !res.is_old0 {
399                new_sibling = BigInt::from(0u32);
400            }
401            let old_sibling = res_find.siblings[level].clone();
402
403            // Remove the old branch hash because the leaf is being deleted.
404            if key_bits[level] {
405                rt_old = poseidon2_compression_sparse(&old_sibling, &rt_old);
406            } else {
407                rt_old = poseidon2_compression_sparse(&rt_old, &old_sibling);
408            }
409            deletes.push(rt_old.clone());
410
411            if new_sibling != BigInt::from(0u32) {
412                mixed = true;
413            }
414
415            if mixed {
416                // Once we hit a mixed branch we need to keep rebuilding upwards.
417                res.siblings.insert(0, res_find.siblings[level].clone());
418                let new_node = if key_bits[level] {
419                    vec![new_sibling, rt_new.clone()]
420                } else {
421                    vec![rt_new.clone(), new_sibling]
422                };
423                rt_new = poseidon2_compression_sparse(&new_node[0], &new_node[1]);
424                inserts.push((rt_new.clone(), new_node));
425            }
426        }
427
428        self.db.multi_ins(inserts);
429        self.db.set_root(rt_new.clone());
430        self.root = rt_new.clone();
431        self.db.multi_del(deletes);
432
433        res.new_root = rt_new;
434        res.old_root = rt_old;
435
436        Ok(res)
437    }
438
439    /// Insert a new key-value pair
440    ///
441    /// Builds any missing intermediate nodes so the resulting tree mirrors the
442    /// JS SMT.
443    ///
444    /// # Arguments
445    ///
446    /// * `key` - Key to insert
447    /// * `value` - Value to associate with the key
448    ///
449    /// # Returns
450    ///
451    /// Returns `Ok(SMTResult)` containing the old and new roots, siblings, and
452    /// operation metadata, or an error if the key already exists or
453    /// database operations fail.
454    pub fn insert(&mut self, key: &BigInt, value: &BigInt) -> Result<SMTResult> {
455        let mut res = SMTResult {
456            old_root: self.root.clone(),
457            new_root: BigInt::from(0u32),
458            siblings: Vec::new(),
459            old_key: key.clone(),
460            old_value: BigInt::from(0u32),
461            new_key: key.clone(),
462            new_value: value.clone(),
463            is_old0: false,
464        };
465        res.old_root = self.root.clone();
466        let new_key_bits = self.split_bits(key);
467        let res_find = self.find(key)?;
468
469        if res_find.found {
470            return Err(anyhow!("Key already exists"));
471        }
472
473        res.siblings = res_find.siblings.clone();
474        let mut mixed = false;
475        let mut rt_old = BigInt::from(0u32);
476        let mut added_one = false;
477
478        if !res_find.is_old0 {
479            let old_key_bits = self.split_bits(&res_find.not_found_key);
480            let mut i = res.siblings.len();
481            while i < old_key_bits.len() && old_key_bits[i] == new_key_bits[i] {
482                res.siblings.push(BigInt::from(0u32));
483                i = i.saturating_add(1);
484            }
485            rt_old = poseidon2_hash3_sparse(&res_find.not_found_key, &res_find.not_found_value);
486            res.siblings.push(rt_old.clone());
487            added_one = true;
488            mixed = false;
489        } else if !res.siblings.is_empty() {
490            mixed = true;
491            rt_old = BigInt::from(0u32);
492        }
493
494        let mut inserts = Vec::new();
495        let mut deletes = Vec::new();
496
497        let mut rt = poseidon2_hash3_sparse(key, value);
498        inserts.push((
499            rt.clone(),
500            vec![BigInt::from(1u32), key.clone(), value.clone()],
501        ));
502
503        for (i, sibling) in res.siblings.iter().enumerate().rev() {
504            if i < res.siblings.len().saturating_sub(1) && sibling != &BigInt::from(0u32) {
505                mixed = true;
506            }
507
508            if mixed {
509                let old_sibling = res_find.siblings[i].clone();
510                if new_key_bits[i] {
511                    rt_old = poseidon2_compression_sparse(&old_sibling, &rt_old);
512                } else {
513                    rt_old = poseidon2_compression_sparse(&rt_old, &old_sibling);
514                }
515                deletes.push(rt_old.clone());
516            }
517
518            let new_rt = if new_key_bits[i] {
519                poseidon2_compression_sparse(&res.siblings[i], &rt)
520            } else {
521                poseidon2_compression_sparse(&rt, &res.siblings[i])
522            };
523            let new_node = if new_key_bits[i] {
524                vec![res.siblings[i].clone(), rt.clone()]
525            } else {
526                vec![rt.clone(), res.siblings[i].clone()]
527            };
528            inserts.push((new_rt.clone(), new_node));
529            rt = new_rt;
530        }
531
532        if added_one {
533            res.siblings.pop();
534        }
535        while res
536            .siblings
537            .last()
538            .is_some_and(|s| s == &BigInt::from(0u32))
539        {
540            res.siblings.pop();
541        }
542
543        res.old_key = res_find.not_found_key;
544        res.old_value = res_find.not_found_value;
545        res.new_root = rt.clone();
546        res.is_old0 = res_find.is_old0;
547
548        self.db.multi_ins(inserts);
549        self.db.set_root(rt.clone());
550        self.root = rt;
551        self.db.multi_del(deletes);
552        Ok(res)
553    }
554
555    /// Find a key in the tree
556    ///
557    /// Returns the Merkle siblings required to reconstruct the path in
558    /// circuits/tests. Also surfaces whether the path ended in a leaf
559    /// collision (non-existent key with same path).
560    ///
561    /// # Arguments
562    ///
563    /// * `key` - Key to search for in the tree
564    ///
565    /// # Returns
566    ///
567    /// Returns `Ok(FindResult)` containing whether the key was found, siblings
568    /// along the path, and collision information, or an error if database
569    /// operations fail.
570    pub fn find(&self, key: &BigInt) -> Result<FindResult> {
571        let key_bits = self.split_bits(key);
572        self._find(key, &key_bits, &self.root, 0)
573    }
574
575    /// Internal find method
576    /// Recurses through the DB-stored nodes, replicating smt.js behavior
577    /// exactly. It walks the tree using the bit-decomposed key, returning
578    /// collision data when the search stops early (i.e. we reached a leaf
579    /// whose key differs from the query).
580    fn _find(
581        &self,
582        key: &BigInt,
583        key_bits: &[bool],
584        root: &BigInt,
585        level: usize,
586    ) -> Result<FindResult> {
587        if *root == BigInt::from(0u32) {
588            return Ok(FindResult {
589                found: false,
590                siblings: Vec::new(),
591                found_value: BigInt::from(0u32),
592                not_found_key: key.clone(),
593                not_found_value: BigInt::from(0u32),
594                is_old0: true,
595            });
596        }
597
598        if let Some(record) = self.db.get(root) {
599            if record.len() == 3 && record[0] == BigInt::from(1u32) {
600                if record[1] == *key {
601                    Ok(FindResult {
602                        found: true,
603                        siblings: Vec::new(),
604                        found_value: record[2].clone(),
605                        not_found_key: BigInt::from(0u32),
606                        not_found_value: BigInt::from(0u32),
607                        is_old0: false,
608                    })
609                } else {
610                    Ok(FindResult {
611                        found: false,
612                        siblings: Vec::new(),
613                        found_value: BigInt::from(0u32),
614                        not_found_key: record[1].clone(),
615                        not_found_value: record[2].clone(),
616                        is_old0: false,
617                    })
618                }
619            } else if record.len() == 2 {
620                let next_level = level
621                    .checked_add(1)
622                    .expect("tree level overflow in sparse_merkle_tree::_find");
623                let mut res = if !key_bits[level] {
624                    self._find(key, key_bits, &record[0], next_level)?
625                } else {
626                    self._find(key, key_bits, &record[1], next_level)?
627                };
628                res.siblings.insert(
629                    0,
630                    if !key_bits[level] {
631                        record[1].clone()
632                    } else {
633                        record[0].clone()
634                    },
635                );
636                Ok(res)
637            } else {
638                Err(anyhow!("Invalid record format"))
639            }
640        } else {
641            Err(anyhow!("Node not found in database"))
642        }
643    }
644}
645
646/// Proof data tailored for Circom inputs (BigInt-based).
647#[derive(Clone, Debug)]
648pub struct SMTProof {
649    pub found: bool,
650    pub siblings: Vec<BigInt>,
651    pub found_value: BigInt,
652    pub not_found_key: BigInt,
653    pub not_found_value: BigInt,
654    pub is_old0: bool,
655    pub root: BigInt,
656}
657
658fn finalize_proof(tree: &SparseMerkleTree<SMTMemDB>, key: &BigInt, max_levels: usize) -> SMTProof {
659    let find_result = tree.find(key).expect("Failed to find key");
660
661    // Pad siblings with zeros to reach max_levels
662    let mut siblings = find_result.siblings.clone();
663    while siblings.len() < max_levels {
664        siblings.push(BigInt::from(0u32));
665    }
666
667    SMTProof {
668        found: find_result.found,
669        siblings,
670        found_value: find_result.found_value,
671        not_found_key: find_result.not_found_key,
672        not_found_value: find_result.not_found_value,
673        is_old0: find_result.is_old0,
674        root: tree.root().clone(),
675    }
676}
677
678/// Prepare an SMT proof after pre-populating the tree with values 0..100
679///
680/// Creates a new sparse Merkle tree, inserts values from 0 to 100 (or up to
681/// 2^max_levels), and generates a proof for the specified key. The proof
682/// includes siblings padded to max_levels with zeros.
683///
684/// # Arguments
685///
686/// * `key` - Key to generate a proof for
687/// * `max_levels` - Maximum number of tree levels (siblings will be padded to
688///   this length)
689///
690/// # Returns
691///
692/// Returns an `SMTProof` containing the proof data for the specified key.
693pub fn prepare_smt_proof(key: &BigInt, max_levels: usize) -> SMTProof {
694    let db = SMTMemDB::new();
695    let mut smt = SparseMerkleTree::new(db, BigInt::from(0u32));
696
697    // Tree can address at most 2^max_levels leaves.
698    let max_leaves = 1usize
699        .checked_shl(u32::try_from(max_levels).expect("Failed to cast max_levels to u32"))
700        .unwrap_or(usize::MAX);
701
702    let num_leaves = 100usize.min(max_leaves);
703
704    for i in 0..num_leaves {
705        let bi = BigInt::from(i);
706        smt.insert(&bi, &bi).expect("Failed to insert key");
707    }
708
709    finalize_proof(&smt, key, max_levels)
710}
711
712/// Build a sparse SMT from `overrides` and return a proof for `key`.
713/// `overrides` is (key, value) pairs already reduced modulo field.
714pub fn prepare_smt_proof_with_overrides(
715    key: &BigInt,
716    overrides: &[(BigInt, BigInt)],
717    max_levels: usize,
718) -> SMTProof {
719    let db = SMTMemDB::new();
720    let mut smt = SparseMerkleTree::new(db, BigInt::from(0u32));
721
722    for (k, v) in overrides {
723        smt.insert(k, v).expect("SMT insert failed");
724    }
725
726    finalize_proof(&smt, key, max_levels)
727}
728
729/// Create a new empty SMT with an in-memory database
730pub fn new_mem_empty_trie() -> SparseMerkleTree<SMTMemDB> {
731    let db = SMTMemDB::new();
732    let root = db.get_root();
733    SparseMerkleTree::new(db, root)
734}
735
736#[cfg(test)]
737mod tests {
738    use super::*;
739    use num_bigint::BigInt;
740    use std::str::FromStr;
741
742    #[test]
743    fn test_smt_creation() {
744        let smt = new_mem_empty_trie();
745        assert_eq!(*smt.root(), BigInt::from(0u32));
746    }
747
748    #[test]
749    fn test_smt_insert() {
750        let mut smt = new_mem_empty_trie();
751        let key = BigInt::from(1u32);
752        let value = BigInt::from(42u32);
753
754        let result = smt.insert(&key, &value).expect("Insert method failed");
755        assert_eq!(result.new_key, key);
756        assert_eq!(result.new_value, value);
757        assert!(result.is_old0); // First insert should be old0
758    }
759
760    #[test]
761    fn test_smt_update() {
762        let mut smt = new_mem_empty_trie();
763        let key = BigInt::from(42u32);
764        let value1 = BigInt::from(42u32);
765        let value2 = BigInt::from(100u32);
766
767        smt.insert(&key, &value1).expect("Insert method failed");
768        let result = smt.update(&key, &value2).expect("Update method failed");
769
770        assert_eq!(result.old_value, value1);
771        assert_eq!(result.new_value, value2);
772        assert!(!result.is_old0); // Update should not be old0
773    }
774
775    #[test]
776    fn test_smt_delete() {
777        let mut smt = new_mem_empty_trie();
778        let key = BigInt::from(1u32);
779        let value = BigInt::from(42u32);
780
781        smt.insert(&key, &value).expect("Insert method failed");
782        let result = smt.delete(&key).expect("Delete method failed");
783
784        assert_eq!(result.old_key, key);
785        assert_eq!(result.old_value, value);
786    }
787
788    #[test]
789    fn test_smt_find() {
790        let mut smt = new_mem_empty_trie();
791        let key = BigInt::from(1u32);
792        let value = BigInt::from(42u32);
793
794        smt.insert(&key, &value).expect("Insert method failed");
795        let find_result = smt.find(&key).expect("Find method failed");
796
797        assert!(find_result.found);
798        assert_eq!(find_result.found_value, value);
799    }
800
801    #[test]
802    fn test_smt_multiple_keys() {
803        let mut smt = new_mem_empty_trie();
804        let keys = [
805            BigInt::from(1u32),
806            BigInt::from(2u32),
807            BigInt::from(3u32),
808            BigInt::from(100u32),
809        ];
810
811        for (i, key) in keys.iter().enumerate() {
812            let value =
813                BigInt::from(u32::try_from((i + 1) * 10).expect("Could not convert into u32"));
814            smt.insert(key, &value).expect("Insert method failed");
815        }
816
817        for (i, key) in keys.iter().enumerate() {
818            let find_result = smt.find(key).expect("Find method failed");
819            assert!(find_result.found);
820            assert_eq!(
821                find_result.found_value,
822                BigInt::from(u32::try_from((i + 1) * 10).expect("Could not convert into u32"))
823            );
824        }
825    }
826
827    #[test]
828    fn test_smt_duplicate_insert() {
829        let mut smt = new_mem_empty_trie();
830        let key = BigInt::from(1u32);
831        let value = BigInt::from(42u32);
832
833        smt.insert(&key, &value).expect("Insert method failed");
834        let result = smt.insert(&key, &value);
835
836        assert!(result.is_err());
837        assert!(
838            result
839                .expect_err("Expected error")
840                .to_string()
841                .contains("Key already exists")
842        );
843    }
844
845    #[test]
846    fn test_smt_delete_nonexistent() {
847        let mut smt = new_mem_empty_trie();
848        let key = BigInt::from(1u32);
849
850        let result = smt.delete(&key);
851        assert!(result.is_err());
852        assert!(
853            result
854                .expect_err("Expected error")
855                .to_string()
856                .contains("Key does not exist")
857        );
858    }
859
860    // Test to verify our SMT implementation works correctly
861    // Expected values are extracted from the original JS implementation
862    #[test]
863    fn test_new_tree() {
864        let mut smt = new_mem_empty_trie();
865        assert_eq!(*smt.root(), BigInt::from(0u32));
866
867        let result = smt
868            .insert(&BigInt::from(1u32), &BigInt::from(42u32))
869            .expect("Insert method failed");
870
871        // The root should change after insertion
872        assert_ne!(result.old_root, result.new_root);
873        assert_eq!(result.old_root, BigInt::from(0u32));
874
875        // For the first insertion, the root should be
876        let expected_root = BigInt::from_str(
877            "16367784008464358864143154554494062552082491393210070322357217564588163898018",
878        )
879        .expect("Could not transform expected root into str");
880        assert_eq!(result.new_root, expected_root);
881
882        // Test update
883        let result = smt
884            .update(&BigInt::from(1u32), &BigInt::from(100u32))
885            .expect("Update method failed");
886
887        // Root should change after update
888        assert_ne!(result.old_root, result.new_root);
889        let expected_root = BigInt::from_str(
890            "12569474685065514766800302626776627658362290519786081498087070427717152263146",
891        )
892        .expect("Could not transform expected root into str");
893        assert_eq!(result.new_root, expected_root);
894
895        // Verify we can find the updated value
896        let find_result = smt.find(&BigInt::from(1u32)).expect("Find method failed");
897        assert!(find_result.found);
898        assert_eq!(find_result.found_value, BigInt::from(100u32));
899        assert!(find_result.found);
900        assert_eq!(find_result.found_value, BigInt::from(100u32));
901
902        // Add a new leaf
903        let result = smt
904            .insert(&BigInt::from(2u32), &BigInt::from(324u32))
905            .expect("Insert method failed");
906        let expected_root = BigInt::from_str(
907            "3902199042378325593738217753401508381332249645815458444537710669740236044308",
908        )
909        .expect("Could not transform expected root into str");
910        assert_eq!(result.new_root, expected_root);
911    }
912    // Test to verify our SMT implementation works correctly
913    // Expected values are extracted from the original JS implementation
914    #[test]
915    fn test_tree_proofs() {
916        let mut smt = new_mem_empty_trie();
917        assert_eq!(*smt.root(), BigInt::from(0u32));
918
919        // Add some leaves
920        smt.insert(&BigInt::from(1u32), &BigInt::from(1u32))
921            .expect("Insert method failed");
922
923        let find_result = smt.find(&BigInt::from(1u32)).expect("Find method failed");
924        assert!(find_result.found);
925        assert_eq!(find_result.found_value, BigInt::from(1u32));
926        assert_eq!(find_result.siblings.len(), 0);
927        assert!(!find_result.is_old0);
928
929        // Let's try to find a non-existent key
930        let find_result = smt.find(&BigInt::from(999u32)).expect("Find method failed");
931        assert!(!find_result.found);
932        assert_eq!(find_result.found_value, BigInt::from(0u32));
933        assert_eq!(find_result.siblings.len(), 0);
934        assert!(!find_result.is_old0);
935
936        // Add more keys
937        for i in 2u32..100 {
938            smt.insert(&BigInt::from(i), &BigInt::from(i))
939                .expect("Insert method failed");
940        }
941
942        // Check that we can find some of the keys
943        let find_result = smt.find(&BigInt::from(77u32)).expect("Find method failed");
944        assert!(find_result.found);
945        assert_eq!(find_result.found_value, BigInt::from(77u32));
946        assert_eq!(find_result.siblings.len(), 7);
947        assert_eq!(
948            find_result.siblings,
949            vec![
950                BigInt::from_str(
951                    "13574531720454277968647792690830483941675832953896828594235298772144774821296"
952                )
953                .expect("Could not transform sibling into str"),
954                BigInt::from_str(
955                    "21822809487696252201955801325867744685997250399099680635153759270255930459663"
956                )
957                .expect("Could not transform sibling into str"),
958                BigInt::from_str(
959                    "2754153135680204810467520704946512020375848021263220175499310526007694622282"
960                )
961                .expect("Could not transform sibling into str"),
962                BigInt::from_str(
963                    "10988861352769866873810486166013377894828418574939430507195536235545006158559"
964                )
965                .expect("Could not transform sibling into str"),
966                BigInt::from_str(
967                    "8745716775239175067716679510281198940457427271514031231047764147465936999003"
968                )
969                .expect("Could not transform sibling into str"),
970                BigInt::from_str(
971                    "10575429519408550180427558328500068421272775679345567502048077733404168359774"
972                )
973                .expect("Could not transform sibling into str"),
974                BigInt::from_str(
975                    "2497489782201357981070733885197437403126039517543044119147834407389467335082"
976                )
977                .expect("Could not transform sibling into str"),
978            ]
979        );
980        assert!(!find_result.is_old0);
981
982        // Look for a non-existing key
983        let find_result = smt.find(&BigInt::from(127u32)).expect("Find method failed");
984        assert!(!find_result.found);
985        assert_eq!(find_result.found_value, BigInt::from(0u32));
986        assert_eq!(find_result.not_found_key, BigInt::from(63u32));
987        assert_eq!(find_result.siblings.len(), 6);
988        assert_eq!(
989            find_result.siblings,
990            vec![
991                BigInt::from_str(
992                    "13574531720454277968647792690830483941675832953896828594235298772144774821296"
993                )
994                .expect("Could not transform sibling into str"),
995                BigInt::from_str(
996                    "1861627833931474771540567070469758409892599524239975114190647783254280704182"
997                )
998                .expect("Could not transform sibling into str"),
999                BigInt::from_str(
1000                    "6337427217730761905851800753670222511821931828056363511575004194996678792977"
1001                )
1002                .expect("Could not transform sibling into str"),
1003                BigInt::from_str(
1004                    "142387899434338503423141257579632358202650467916673674727273804791475103923"
1005                )
1006                .expect("Could not transform sibling into str"),
1007                BigInt::from_str(
1008                    "6499651114777582205199364701529028639517158867351868744143839420261663269505"
1009                )
1010                .expect("Could not transform sibling into str"),
1011                BigInt::from_str(
1012                    "4733877433413380505912252732407068279835546218946596975085447307151515063172"
1013                )
1014                .expect("Could not transform sibling into str"),
1015            ]
1016        );
1017        assert!(!find_result.is_old0);
1018    }
1019
1020    #[test]
1021    fn test_hash_direct() {
1022        use zkhash::{
1023            fields::bn256::FpBN256,
1024            poseidon2::{
1025                poseidon2::Poseidon2,
1026                poseidon2_instance_bn256::{POSEIDON2_BN256_PARAMS_2, POSEIDON2_BN256_PARAMS_3},
1027            },
1028        };
1029        let hash_result = poseidon2_hash3_sparse(&BigInt::from(0u32), &BigInt::from(1u32));
1030        let hash_result2 = poseidon2_compression_sparse(&BigInt::from(0u32), &BigInt::from(1u32));
1031
1032        type Scalar = FpBN256;
1033        // T = 2
1034        let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS_2);
1035        let input: Vec<Scalar> = vec![Scalar::from(0u64), Scalar::from(1u64)];
1036        let perm = poseidon2.permutation(&input);
1037
1038        assert_eq!(perm[0].to_string(), hash_result2.to_string());
1039
1040        // T = 3
1041        let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS_3);
1042        let input: Vec<Scalar> = vec![Scalar::from(0u64), Scalar::from(1u64), Scalar::from(1u64)];
1043        let perm = poseidon2.permutation(&input);
1044        assert_eq!(perm[0].to_string(), hash_result.to_string());
1045    }
1046}