prover/
sparse_merkle.rs

//! Sparse Merkle Tree utilities for WASM (no_std compatible)
//!
//! Provides sparse merkle tree functionality using BTreeMap for no_std
//! compatibility.
//!
//! Equivalent functionality to `circuits::test::utils::sparse_merkle_tree` in
//! the circuit crate, but without its std dependencies (mostly the BigInt and
//! HashMap dependencies). It exposes an SMT interface.
9
10use alloc::{collections::BTreeMap, vec::Vec};
11
12use wasm_bindgen::prelude::*;
13use zkhash::{ark_ff::PrimeField, fields::bn256::FpBN256 as Scalar};
14
15use crate::{
16    crypto::{poseidon2_compression, poseidon2_hash2_internal},
17    serialization::{bytes_to_scalar, scalar_to_bytes},
18};
19
20/// Poseidon2 hash for leaf nodes: Poseidon2(key, value, domain=1)
21fn poseidon2_hash_leaf(key: Scalar, value: Scalar) -> Scalar {
22    poseidon2_hash2_internal(key, value, Some(Scalar::from(1u64)))
23}
24
25/// Split a scalar into 256 bits (LSB first)
26fn scalar_to_bits(scalar: &Scalar) -> Vec<bool> {
27    let bigint = scalar.into_bigint();
28    let mut bits = Vec::with_capacity(256);
29
30    for limb in bigint.0.iter() {
31        for i in 0..64 {
32            bits.push((limb >> i) & 1 == 1);
33        }
34    }
35
36    bits.truncate(256);
37    bits
38}
39
/// Node type in the sparse merkle tree.
///
/// Nodes are stored in the tree's database keyed by their own hash; the
/// variants therefore hold only what is needed to continue a lookup.
#[derive(Clone, Debug)]
enum Node {
    /// Empty node (represents zero). Returned for the zero hash; it is never
    /// written to the database because `put_node` skips the zero hash.
    Empty,
    /// Leaf node containing (key, value)
    Leaf { key: Scalar, value: Scalar },
    /// Internal node containing (left_child_hash, right_child_hash).
    /// A lookup follows `right` when the key bit for the level is set and
    /// `left` otherwise.
    Internal { left: Scalar, right: Scalar },
}
50
/// Result of SMT find operation.
///
/// Describes where the lookup path for a key ended: at a leaf with the same
/// key (`found`), at a leaf holding a different key (a collision), or at an
/// empty (zero) node.
#[derive(Clone, Debug)]
pub struct FindResult {
    /// Whether the key was found
    pub found: bool,
    /// Sibling hashes along the path. As returned by `SparseMerkleTree::find`
    /// the entry at index `i` is the sibling at tree level `i` (root first).
    pub siblings: Vec<Scalar>,
    /// The found value (if found); zero otherwise
    pub found_value: Scalar,
    /// Zero when the key was found. When the path ended at an empty node this
    /// is the queried key itself; when it ended at a leaf holding a different
    /// key (collision) this is that leaf's key.
    pub not_found_key: Scalar,
    /// The value of the colliding leaf (if not found due to a collision);
    /// zero otherwise
    pub not_found_value: Scalar,
    /// Whether the path ended at zero (an empty node)
    pub is_old0: bool,
}
67
/// Result of SMT operations (insert/update/delete).
///
/// In this file it is produced by `SparseMerkleTree::insert` and
/// `SparseMerkleTree::update`.
#[derive(Clone, Debug)]
pub struct SMTResult {
    /// The old root before the operation
    pub old_root: Scalar,
    /// The new root after the operation
    pub new_root: Scalar,
    /// Sibling hashes along the path of the key, root level first
    pub siblings: Vec<Scalar>,
    /// The old key: for an insert this is the `not_found_key` reported by the
    /// preceding find; for an update it is the updated key itself
    pub old_key: Scalar,
    /// The old value: for an insert this is the `not_found_value` reported by
    /// the preceding find; for an update it is the value being replaced
    pub old_value: Scalar,
    /// The new key
    pub new_key: Scalar,
    /// The new value
    pub new_value: Scalar,
    /// Whether the old value was zero, i.e. the insert landed on an empty
    /// node. Always `false` for an update.
    pub is_old0: bool,
}
88
/// Sparse Merkle Tree using BTreeMap for no_std compatibility.
///
/// Nodes are content-addressed: each node is stored under the 32-byte
/// serialization of its own hash. The zero scalar stands for an empty subtree
/// and is never stored.
pub struct SparseMerkleTree {
    /// Database storing nodes by their hash. Entries are only ever added by
    /// the operations in this file; superseded nodes are not removed.
    db: BTreeMap<[u8; 32], Node>,
    /// Current root hash (zero while the tree is empty)
    root: Scalar,
}
96
97impl Default for SparseMerkleTree {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103impl SparseMerkleTree {
104    /// Create a new empty sparse merkle tree
105    pub fn new() -> Self {
106        SparseMerkleTree {
107            db: BTreeMap::new(),
108            root: Scalar::from(0u64),
109        }
110    }
111
112    /// Get the current root
113    pub fn root(&self) -> Scalar {
114        self.root
115    }
116
117    /// Convert scalar to bytes for use as BTreeMap key
118    fn scalar_to_key(s: &Scalar) -> [u8; 32] {
119        let mut key = [0u8; 32];
120        let bytes = scalar_to_bytes(s);
121        key.copy_from_slice(&bytes);
122        key
123    }
124
125    /// Get a node from the database
126    fn get_node(&self, hash: &Scalar) -> Option<&Node> {
127        if *hash == Scalar::from(0u64) {
128            return Some(&Node::Empty);
129        }
130        self.db.get(&Self::scalar_to_key(hash))
131    }
132
133    /// Store a node in the database
134    fn put_node(&mut self, hash: Scalar, node: Node) {
135        if hash != Scalar::from(0u64) {
136            self.db.insert(Self::scalar_to_key(&hash), node);
137        }
138    }
139
140    /// Find a key in the tree
141    pub fn find(&self, key: &Scalar) -> Result<FindResult, &'static str> {
142        let key_bits = scalar_to_bits(key);
143        let mut result = self.find_internal(key, &key_bits, &self.root, 0)?;
144        result.siblings.reverse();
145        Ok(result)
146    }
147
148    fn find_internal(
149        &self,
150        key: &Scalar,
151        key_bits: &[bool],
152        current_hash: &Scalar,
153        level: usize,
154    ) -> Result<FindResult, &'static str> {
155        if level >= 256 {
156            return Err("Maximum tree depth exceeded");
157        }
158
159        if *current_hash == Scalar::from(0u64) {
160            return Ok(FindResult {
161                found: false,
162                siblings: Vec::new(),
163                found_value: Scalar::from(0u64),
164                not_found_key: *key,
165                not_found_value: Scalar::from(0u64),
166                is_old0: true,
167            });
168        }
169
170        match self.get_node(current_hash) {
171            Some(Node::Leaf {
172                key: leaf_key,
173                value: leaf_value,
174            }) => {
175                if leaf_key == key {
176                    Ok(FindResult {
177                        found: true,
178                        siblings: Vec::new(),
179                        found_value: *leaf_value,
180                        not_found_key: Scalar::from(0u64),
181                        not_found_value: Scalar::from(0u64),
182                        is_old0: false,
183                    })
184                } else {
185                    Ok(FindResult {
186                        found: false,
187                        siblings: Vec::new(),
188                        found_value: Scalar::from(0u64),
189                        not_found_key: *leaf_key,
190                        not_found_value: *leaf_value,
191                        is_old0: false,
192                    })
193                }
194            }
195            Some(Node::Internal { left, right }) => {
196                let (child, sibling) = if key_bits[level] {
197                    (right, left)
198                } else {
199                    (left, right)
200                };
201
202                let next_level = level
203                    .checked_add(1)
204                    .ok_or("Level overflow in find_internal")?;
205                let mut result = self.find_internal(key, key_bits, child, next_level)?;
206                result.siblings.push(*sibling);
207                Ok(result)
208            }
209            Some(Node::Empty) => Ok(FindResult {
210                found: false,
211                siblings: Vec::new(),
212                found_value: Scalar::from(0u64),
213                not_found_key: *key,
214                not_found_value: Scalar::from(0u64),
215                is_old0: true,
216            }),
217            None => Err("Node not found in database"),
218        }
219    }
220
221    /// Insert a key-value pair
222    pub fn insert(&mut self, key: &Scalar, value: &Scalar) -> Result<SMTResult, &'static str> {
223        let find_result = self.find(key)?;
224
225        if find_result.found {
226            return Err("Key already exists");
227        }
228
229        let old_root = self.root;
230        let key_bits = scalar_to_bits(key);
231
232        // Create the new leaf
233        let new_leaf_hash = poseidon2_hash_leaf(*key, *value);
234        self.put_node(
235            new_leaf_hash,
236            Node::Leaf {
237                key: *key,
238                value: *value,
239            },
240        );
241
242        // Build the path from leaf to root
243        let mut current_hash = new_leaf_hash;
244        let mut siblings = find_result.siblings.clone();
245
246        // If there's a collision (not_found_key != 0 and is_old0 == false), we need to
247        // extend the path
248        if !find_result.is_old0 {
249            let old_key_bits = scalar_to_bits(&find_result.not_found_key);
250
251            // Find where the paths diverge
252            let mut diverge_level = siblings.len();
253            while diverge_level < 256 && old_key_bits[diverge_level] == key_bits[diverge_level] {
254                siblings.push(Scalar::from(0u64));
255                diverge_level = diverge_level.saturating_add(1);
256            }
257
258            // Add the old leaf as a sibling at the divergence point
259            let old_leaf_hash =
260                poseidon2_hash_leaf(find_result.not_found_key, find_result.not_found_value);
261            siblings.push(old_leaf_hash);
262        }
263
264        // Build path from bottom to top
265        for (level, sibling) in siblings.iter().enumerate().rev() {
266            let (left, right) = if key_bits[level] {
267                (*sibling, current_hash)
268            } else {
269                (current_hash, *sibling)
270            };
271
272            current_hash = poseidon2_compression(left, right);
273            self.put_node(current_hash, Node::Internal { left, right });
274        }
275
276        self.root = current_hash;
277
278        // Trim trailing zeros from siblings for the result
279        let mut result_siblings = siblings;
280        while result_siblings.last() == Some(&Scalar::from(0u64)) {
281            result_siblings.pop();
282        }
283        // Remove the collision leaf if we added one
284        if !find_result.is_old0 && !result_siblings.is_empty() {
285            result_siblings.pop();
286        }
287
288        Ok(SMTResult {
289            old_root,
290            new_root: self.root,
291            siblings: result_siblings,
292            old_key: find_result.not_found_key,
293            old_value: find_result.not_found_value,
294            new_key: *key,
295            new_value: *value,
296            is_old0: find_result.is_old0,
297        })
298    }
299
300    /// Update a key's value
301    pub fn update(&mut self, key: &Scalar, new_value: &Scalar) -> Result<SMTResult, &'static str> {
302        let find_result = self.find(key)?;
303
304        if !find_result.found {
305            return Err("Key does not exist");
306        }
307
308        let old_root = self.root;
309        let old_value = find_result.found_value;
310        let key_bits = scalar_to_bits(key);
311
312        // Create the new leaf
313        let new_leaf_hash = poseidon2_hash_leaf(*key, *new_value);
314        self.put_node(
315            new_leaf_hash,
316            Node::Leaf {
317                key: *key,
318                value: *new_value,
319            },
320        );
321
322        // Build path from bottom to top
323        let mut current_hash = new_leaf_hash;
324        for (level, sibling) in find_result.siblings.iter().enumerate().rev() {
325            let (left, right) = if key_bits[level] {
326                (*sibling, current_hash)
327            } else {
328                (current_hash, *sibling)
329            };
330
331            current_hash = poseidon2_compression(left, right);
332            self.put_node(current_hash, Node::Internal { left, right });
333        }
334
335        self.root = current_hash;
336
337        Ok(SMTResult {
338            old_root,
339            new_root: self.root,
340            siblings: find_result.siblings,
341            old_key: *key,
342            old_value,
343            new_key: *key,
344            new_value: *new_value,
345            is_old0: false,
346        })
347    }
348}
349
/// WASM-friendly Sparse Merkle Tree wrapper.
///
/// Exposes the tree to JavaScript with keys, values and hashes passed as
/// byte slices instead of field elements.
#[wasm_bindgen]
pub struct WasmSparseMerkleTree {
    /// The wrapped tree that holds all state
    inner: SparseMerkleTree,
}
355
356#[wasm_bindgen]
357impl WasmSparseMerkleTree {
358    /// Create a new empty sparse merkle tree
359    #[wasm_bindgen(constructor)]
360    pub fn new() -> WasmSparseMerkleTree {
361        WasmSparseMerkleTree {
362            inner: SparseMerkleTree::new(),
363        }
364    }
365
366    /// Get the current root as bytes (32 bytes, Little-Endian)
367    #[wasm_bindgen]
368    pub fn root(&self) -> Vec<u8> {
369        scalar_to_bytes(&self.inner.root())
370    }
371
372    /// Insert a key-value pair into the tree
373    ///
374    /// # Arguments
375    /// * `key_bytes` - Key as 32 bytes (Little-Endian)
376    /// * `value_bytes` - Value as 32 bytes (Little-Endian)
377    #[wasm_bindgen]
378    pub fn insert(
379        &mut self,
380        key_bytes: &[u8],
381        value_bytes: &[u8],
382    ) -> Result<WasmSMTResult, JsValue> {
383        let key = bytes_to_scalar(key_bytes)?;
384        let value = bytes_to_scalar(value_bytes)?;
385
386        let result = self.inner.insert(&key, &value).map_err(JsValue::from_str)?;
387
388        Ok(WasmSMTResult::from_result(&result))
389    }
390
391    /// Update a key's value in the tree
392    #[wasm_bindgen]
393    pub fn update(
394        &mut self,
395        key_bytes: &[u8],
396        new_value_bytes: &[u8],
397    ) -> Result<WasmSMTResult, JsValue> {
398        let key = bytes_to_scalar(key_bytes)?;
399        let new_value = bytes_to_scalar(new_value_bytes)?;
400
401        let result = self
402            .inner
403            .update(&key, &new_value)
404            .map_err(JsValue::from_str)?;
405
406        Ok(WasmSMTResult::from_result(&result))
407    }
408
409    /// Find a key in the tree and get a membership/non-membership proof
410    #[wasm_bindgen]
411    pub fn find(&self, key_bytes: &[u8]) -> Result<WasmFindResult, JsValue> {
412        let key = bytes_to_scalar(key_bytes)?;
413
414        let result = self.inner.find(&key).map_err(JsValue::from_str)?;
415
416        Ok(WasmFindResult::from_result(&result, &self.inner.root()))
417    }
418
419    /// Get a proof for a key, padded to max_levels
420    #[wasm_bindgen]
421    pub fn get_proof(&self, key_bytes: &[u8], max_levels: usize) -> Result<WasmSMTProof, JsValue> {
422        let key = bytes_to_scalar(key_bytes)?;
423
424        let find_result = self.inner.find(&key).map_err(JsValue::from_str)?;
425
426        // Pad siblings to max_levels
427        let mut siblings = find_result.siblings.clone();
428        while siblings.len() < max_levels {
429            siblings.push(Scalar::from(0u64));
430        }
431
432        Ok(WasmSMTProof {
433            found: find_result.found,
434            siblings: siblings.iter().flat_map(scalar_to_bytes).collect(),
435            found_value: scalar_to_bytes(&find_result.found_value),
436            not_found_key: scalar_to_bytes(&find_result.not_found_key),
437            not_found_value: scalar_to_bytes(&find_result.not_found_value),
438            is_old0: find_result.is_old0,
439            root: scalar_to_bytes(&self.inner.root()),
440            num_siblings: siblings.len(),
441        })
442    }
443}
444
445impl Default for WasmSparseMerkleTree {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
/// Result of SMT operations (insert/update/delete).
///
/// Byte-serialized counterpart of `SMTResult`; every scalar is stored in the
/// encoding produced by `scalar_to_bytes`.
#[wasm_bindgen]
pub struct WasmSMTResult {
    // Root before the operation.
    old_root: Vec<u8>,
    // Root after the operation.
    new_root: Vec<u8>,
    // All siblings concatenated back to back, one serialized scalar each.
    siblings: Vec<u8>,
    // Key previously associated with the slot.
    old_key: Vec<u8>,
    // Value previously associated with the slot.
    old_value: Vec<u8>,
    // Key that was written.
    new_key: Vec<u8>,
    // Value that was written.
    new_value: Vec<u8>,
    // Whether the old value was zero.
    is_old0: bool,
    // Number of scalars encoded in `siblings`.
    num_siblings: usize,
}
464
465#[wasm_bindgen]
466impl WasmSMTResult {
467    /// Get the old root before the operation
468    #[wasm_bindgen(getter)]
469    pub fn old_root(&self) -> Vec<u8> {
470        self.old_root.clone()
471    }
472
473    /// Get the new root after the operation
474    #[wasm_bindgen(getter)]
475    pub fn new_root(&self) -> Vec<u8> {
476        self.new_root.clone()
477    }
478
479    /// Get siblings as flat bytes
480    #[wasm_bindgen(getter)]
481    pub fn siblings(&self) -> Vec<u8> {
482        self.siblings.clone()
483    }
484
485    /// Get number of siblings
486    #[wasm_bindgen(getter)]
487    pub fn num_siblings(&self) -> usize {
488        self.num_siblings
489    }
490
491    /// Get the old key
492    #[wasm_bindgen(getter)]
493    pub fn old_key(&self) -> Vec<u8> {
494        self.old_key.clone()
495    }
496
497    /// Get the old value
498    #[wasm_bindgen(getter)]
499    pub fn old_value(&self) -> Vec<u8> {
500        self.old_value.clone()
501    }
502
503    /// Get the new key
504    #[wasm_bindgen(getter)]
505    pub fn new_key(&self) -> Vec<u8> {
506        self.new_key.clone()
507    }
508
509    /// Get the new value
510    #[wasm_bindgen(getter)]
511    pub fn new_value(&self) -> Vec<u8> {
512        self.new_value.clone()
513    }
514
515    /// Whether old value was zero
516    #[wasm_bindgen(getter)]
517    pub fn is_old0(&self) -> bool {
518        self.is_old0
519    }
520}
521
522impl WasmSMTResult {
523    fn from_result(r: &SMTResult) -> Self {
524        WasmSMTResult {
525            old_root: scalar_to_bytes(&r.old_root),
526            new_root: scalar_to_bytes(&r.new_root),
527            siblings: r.siblings.iter().flat_map(scalar_to_bytes).collect(),
528            old_key: scalar_to_bytes(&r.old_key),
529            old_value: scalar_to_bytes(&r.old_value),
530            new_key: scalar_to_bytes(&r.new_key),
531            new_value: scalar_to_bytes(&r.new_value),
532            is_old0: r.is_old0,
533            num_siblings: r.siblings.len(),
534        }
535    }
536}
537
/// Result of SMT find operation.
///
/// Byte-serialized counterpart of `FindResult`, extended with the root the
/// lookup was performed against.
#[wasm_bindgen]
pub struct WasmFindResult {
    // Whether the key was found.
    found: bool,
    // All siblings concatenated back to back, one serialized scalar each.
    siblings: Vec<u8>,
    // Value stored under the key (if found).
    found_value: Vec<u8>,
    // Key reported by the lookup when the queried key was not found.
    not_found_key: Vec<u8>,
    // Value at the collision (if not found).
    not_found_value: Vec<u8>,
    // Whether the path ended at zero.
    is_old0: bool,
    // Root of the tree at lookup time.
    root: Vec<u8>,
    // Number of scalars encoded in `siblings`.
    num_siblings: usize,
}
550
551#[wasm_bindgen]
552impl WasmFindResult {
553    /// Whether the key was found
554    #[wasm_bindgen(getter)]
555    pub fn found(&self) -> bool {
556        self.found
557    }
558
559    /// Get siblings as flat bytes
560    #[wasm_bindgen(getter)]
561    pub fn siblings(&self) -> Vec<u8> {
562        self.siblings.clone()
563    }
564
565    /// Get number of siblings
566    #[wasm_bindgen(getter)]
567    pub fn num_siblings(&self) -> usize {
568        self.num_siblings
569    }
570
571    /// Get found value (if found)
572    #[wasm_bindgen(getter)]
573    pub fn found_value(&self) -> Vec<u8> {
574        self.found_value.clone()
575    }
576
577    /// Get the key that was found at collision (if not found)
578    #[wasm_bindgen(getter)]
579    pub fn not_found_key(&self) -> Vec<u8> {
580        self.not_found_key.clone()
581    }
582
583    /// Get the value at collision (if not found)
584    #[wasm_bindgen(getter)]
585    pub fn not_found_value(&self) -> Vec<u8> {
586        self.not_found_value.clone()
587    }
588
589    /// Whether the path ended at zero
590    #[wasm_bindgen(getter)]
591    pub fn is_old0(&self) -> bool {
592        self.is_old0
593    }
594
595    /// Get the current root
596    #[wasm_bindgen(getter)]
597    pub fn root(&self) -> Vec<u8> {
598        self.root.clone()
599    }
600}
601
602impl WasmFindResult {
603    fn from_result(r: &FindResult, root: &Scalar) -> Self {
604        WasmFindResult {
605            found: r.found,
606            siblings: r.siblings.iter().flat_map(scalar_to_bytes).collect(),
607            found_value: scalar_to_bytes(&r.found_value),
608            not_found_key: scalar_to_bytes(&r.not_found_key),
609            not_found_value: scalar_to_bytes(&r.not_found_value),
610            is_old0: r.is_old0,
611            root: scalar_to_bytes(root),
612            num_siblings: r.siblings.len(),
613        }
614    }
615}
616
/// SMT Proof for circuit inputs.
///
/// Built by `WasmSparseMerkleTree::get_proof`; same content as a find
/// result, but with the sibling list padded with zero scalars up to the
/// requested number of levels.
#[wasm_bindgen]
pub struct WasmSMTProof {
    // Whether the key was found.
    found: bool,
    // Padded siblings concatenated back to back, one serialized scalar each.
    siblings: Vec<u8>,
    // Value stored under the key (if found).
    found_value: Vec<u8>,
    // Key reported by the lookup when the queried key was not found.
    not_found_key: Vec<u8>,
    // Value at the collision (if not found).
    not_found_value: Vec<u8>,
    // Whether the path ended at zero.
    is_old0: bool,
    // Root of the tree the proof refers to.
    root: Vec<u8>,
    // Number of scalars encoded in `siblings`, padding included.
    num_siblings: usize,
}
629
630#[wasm_bindgen]
631impl WasmSMTProof {
632    /// Whether the key was found
633    #[wasm_bindgen(getter)]
634    pub fn found(&self) -> bool {
635        self.found
636    }
637
638    /// Get siblings as flat bytes (padded to max_levels)
639    #[wasm_bindgen(getter)]
640    pub fn siblings(&self) -> Vec<u8> {
641        self.siblings.clone()
642    }
643
644    /// Get number of siblings
645    #[wasm_bindgen(getter)]
646    pub fn num_siblings(&self) -> usize {
647        self.num_siblings
648    }
649
650    /// Get found value
651    #[wasm_bindgen(getter)]
652    pub fn found_value(&self) -> Vec<u8> {
653        self.found_value.clone()
654    }
655
656    /// Get not found key
657    #[wasm_bindgen(getter)]
658    pub fn not_found_key(&self) -> Vec<u8> {
659        self.not_found_key.clone()
660    }
661
662    /// Get not found value
663    #[wasm_bindgen(getter)]
664    pub fn not_found_value(&self) -> Vec<u8> {
665        self.not_found_value.clone()
666    }
667
668    /// Whether old value was zero
669    #[wasm_bindgen(getter)]
670    pub fn is_old0(&self) -> bool {
671        self.is_old0
672    }
673
674    /// Get root
675    #[wasm_bindgen(getter)]
676    pub fn root(&self) -> Vec<u8> {
677        self.root.clone()
678    }
679}
680
681/// Compute Poseidon2 compression hash of two field elements
682#[wasm_bindgen]
683pub fn smt_hash_pair(left: &[u8], right: &[u8]) -> Result<Vec<u8>, JsValue> {
684    let l = bytes_to_scalar(left)?;
685    let r = bytes_to_scalar(right)?;
686    let result = poseidon2_compression(l, r);
687    Ok(scalar_to_bytes(&result))
688}
689
690/// Compute Poseidon2 hash for leaf nodes: hash(key, value, 1)
691#[wasm_bindgen]
692pub fn smt_hash_leaf(key: &[u8], value: &[u8]) -> Result<Vec<u8>, JsValue> {
693    let k = bytes_to_scalar(key)?;
694    let v = bytes_to_scalar(value)?;
695    let result = poseidon2_hash_leaf(k, v);
696    Ok(scalar_to_bytes(&result))
697}