prover/
merkle.rs

1//! Merkle tree utilities for proof generation
2//!
3//! Provides merkle tree operations matching the Circom circuit implementations.
4//! Core merkle functions are re-exported from `circuits::core::merkle`.
5
6use alloc::{format, vec, vec::Vec};
7
8use wasm_bindgen::prelude::*;
9use zkhash::fields::bn256::FpBN256 as Scalar;
10
11use crate::{
12    serialization::{bytes_to_scalar, scalar_to_bytes},
13    types::FIELD_SIZE,
14};
15
16// Re-export core merkle functions from circuits
17pub use circuits::core::merkle::{
18    merkle_proof as merkle_proof_internal, merkle_root, poseidon2_compression,
19};
20
21/// Merkle proof data returned to JavaScript
22#[wasm_bindgen]
23pub struct MerkleProof {
24    /// Path elements
25    path_elements: Vec<u8>,
26    /// Path indices as a single scalar
27    path_indices: Vec<u8>,
28    /// Computed root
29    root: Vec<u8>,
30    /// Number of levels
31    levels: usize,
32}
33
34#[wasm_bindgen]
35impl MerkleProof {
36    /// Get path elements as flat bytes (levels * 32 bytes)
37    #[wasm_bindgen(getter)]
38    pub fn path_elements(&self) -> Vec<u8> {
39        self.path_elements.clone()
40    }
41
42    /// Get path indices as bytes (32 bytes)
43    #[wasm_bindgen(getter)]
44    pub fn path_indices(&self) -> Vec<u8> {
45        self.path_indices.clone()
46    }
47
48    /// Get computed root as bytes (32 bytes)
49    #[wasm_bindgen(getter)]
50    pub fn root(&self) -> Vec<u8> {
51        self.root.clone()
52    }
53
54    /// Get number of levels
55    #[wasm_bindgen(getter)]
56    pub fn levels(&self) -> usize {
57        self.levels
58    }
59}
60
61/// Simple Merkle tree for proof generation
62#[wasm_bindgen]
63pub struct MerkleTree {
64    /// Tree levels
65    levels_data: Vec<Vec<Scalar>>,
66    /// Number of levels
67    depth: usize,
68    /// Next leaf index to insert
69    next_index: u64,
70}
71
72// TODO: For now we implement a full merkle tree. We should study if a partial
73// merkle tree is enough. To minimize storage on user side
74#[wasm_bindgen]
75impl MerkleTree {
76    /// Create a new Merkle tree with given depth and default zero leaf (0)
77    #[wasm_bindgen(constructor)]
78    pub fn new(depth: usize) -> Result<MerkleTree, JsValue> {
79        Self::build_tree(depth, Scalar::from(0u64))
80    }
81
82    /// Create a new Merkle tree with a custom zero leaf value.
83    /// This allows matching contract implementations that use non-zero empty
84    /// leaves (e.g., poseidon2("XLM") as the zero value).
85    ///
86    /// # Arguments
87    /// * `depth` - Tree depth (1-32)
88    /// * `zero_leaf_bytes` - Custom zero leaf value as 32 bytes (Little-Endian)
89    #[wasm_bindgen]
90    pub fn new_with_zero_leaf(depth: usize, zero_leaf_bytes: &[u8]) -> Result<MerkleTree, JsValue> {
91        let zero = bytes_to_scalar(zero_leaf_bytes)?;
92        Self::build_tree(depth, zero)
93    }
94
95    /// Internal helper to build the tree with a given zero value
96    fn build_tree(depth: usize, zero: Scalar) -> Result<MerkleTree, JsValue> {
97        if depth == 0 || depth > 32 {
98            return Err(JsValue::from_str("Depth must be between 1 and 32"));
99        }
100
101        // Use checked shift to avoid overflow
102        let depth_u32 = u32::try_from(depth).expect("Depth didn't fit in u32");
103        let num_leaves = 1usize.checked_shl(depth_u32).ok_or_else(|| {
104            JsValue::from_str("Depth too large for this platform, would overflow")
105        })?;
106
107        // Initialize all levels with zeros
108        let capacity = depth
109            .checked_add(1)
110            .ok_or_else(|| JsValue::from_str("Depth overflow"))?;
111        let mut levels_data = Vec::with_capacity(capacity);
112
113        // Leaves at level 0
114        levels_data.push(vec![zero; num_leaves]);
115
116        // Build empty tree by hashing up
117        let mut current_level_size = num_leaves;
118        let mut prev_hash = zero;
119
120        for _ in 0..depth {
121            current_level_size /= 2;
122            prev_hash = poseidon2_compression(prev_hash, prev_hash);
123            levels_data.push(vec![prev_hash; current_level_size]);
124        }
125
126        Ok(MerkleTree {
127            levels_data,
128            depth,
129            next_index: 0,
130        })
131    }
132
133    /// Insert a leaf and return its index
134    #[wasm_bindgen]
135    pub fn insert(&mut self, leaf_bytes: &[u8]) -> Result<u32, JsValue> {
136        let leaf = bytes_to_scalar(leaf_bytes)?;
137        let index = self.next_index;
138
139        let max_leaves = 1u64 << self.depth;
140        if index >= max_leaves {
141            return Err(JsValue::from_str("Merkle tree is full"));
142        }
143
144        let index_usize =
145            usize::try_from(index).map_err(|_| JsValue::from_str("Index too large"))?;
146
147        // Insert leaf
148        self.levels_data[0][index_usize] = leaf;
149
150        // Update path to root
151        let mut current_index = index_usize;
152        let mut current_hash = leaf;
153
154        for level in 0..self.depth {
155            let sibling_index = current_index ^ 1; // Toggle last bit to get sibling
156            let sibling = self.levels_data[level][sibling_index];
157
158            // Compute parent hash
159            let (left, right) = if current_index.is_multiple_of(2) {
160                (current_hash, sibling)
161            } else {
162                (sibling, current_hash)
163            };
164
165            current_hash = poseidon2_compression(left, right);
166            current_index /= 2;
167
168            // Update parent level
169            let parent_level = level
170                .checked_add(1)
171                .ok_or_else(|| JsValue::from_str("Level overflow"))?;
172            self.levels_data[parent_level][current_index] = current_hash;
173        }
174
175        self.next_index = self
176            .next_index
177            .checked_add(1)
178            .ok_or_else(|| JsValue::from_str("Index overflow"))?;
179
180        // index is bounded by max_leaves (1 << depth where depth <= 32)
181        u32::try_from(index).map_err(|_| JsValue::from_str("Index too large for u32"))
182    }
183
184    /// Get the current root
185    #[wasm_bindgen]
186    pub fn root(&self) -> Vec<u8> {
187        let root = self.levels_data[self.depth][0];
188        scalar_to_bytes(&root)
189    }
190
191    /// Get merkle proof for a leaf at given index
192    #[wasm_bindgen]
193    pub fn get_proof(&self, index: u32) -> Result<MerkleProof, JsValue> {
194        let index = usize::try_from(index).map_err(|_| JsValue::from_str("Index too large"))?;
195        let max_leaves = 1usize << self.depth;
196
197        if index >= max_leaves {
198            return Err(JsValue::from_str("Index out of bounds"));
199        }
200
201        let capacity = self
202            .depth
203            .checked_mul(FIELD_SIZE)
204            .ok_or_else(|| JsValue::from_str("Overflow calculating path capacity"))?;
205        let mut path_elements = Vec::with_capacity(capacity);
206        let mut path_indices_bits: u64 = 0;
207        let mut current_index = index;
208
209        for level in 0..self.depth {
210            let sibling_index = current_index ^ 1;
211            let sibling = self.levels_data[level][sibling_index];
212
213            // Add sibling to path
214            path_elements.extend_from_slice(&scalar_to_bytes(&sibling));
215
216            // Record direction (0 = left, 1 = right)
217            if !current_index.is_multiple_of(2) {
218                path_indices_bits |= 1u64 << level;
219            }
220
221            current_index /= 2;
222        }
223
224        let path_indices = scalar_to_bytes(&Scalar::from(path_indices_bits));
225        let root = scalar_to_bytes(&self.levels_data[self.depth][0]);
226
227        Ok(MerkleProof {
228            path_elements,
229            path_indices,
230            root,
231            levels: self.depth,
232        })
233    }
234
235    /// Get the next available leaf index
236    #[wasm_bindgen(getter)]
237    pub fn next_index(&self) -> u64 {
238        self.next_index
239    }
240
241    /// Get tree depth
242    #[wasm_bindgen(getter)]
243    pub fn depth(&self) -> usize {
244        self.depth
245    }
246}
247
248/// Compute merkle root from leaves
249#[wasm_bindgen]
250pub fn compute_merkle_root(leaves_bytes: &[u8], depth: usize) -> Result<Vec<u8>, JsValue> {
251    if !leaves_bytes.len().is_multiple_of(FIELD_SIZE) {
252        return Err(JsValue::from_str("Leaves bytes must be multiple of 32"));
253    }
254
255    let num_leaves = leaves_bytes.len() / FIELD_SIZE;
256    let expected_leaves = 1usize << depth;
257
258    if num_leaves != expected_leaves {
259        return Err(JsValue::from_str(&format!(
260            "Expected {} leaves for depth {}, got {}",
261            expected_leaves, depth, num_leaves
262        )));
263    }
264
265    // Parse leaves
266    let mut current_level: Vec<Scalar> = Vec::with_capacity(num_leaves);
267    for i in 0..num_leaves {
268        let start = i
269            .checked_mul(FIELD_SIZE)
270            .ok_or_else(|| JsValue::from_str("Index overflow"))?;
271        let end = i
272            .checked_add(1)
273            .and_then(|n| n.checked_mul(FIELD_SIZE))
274            .ok_or_else(|| JsValue::from_str("Index overflow"))?;
275        let chunk = &leaves_bytes[start..end];
276        current_level.push(bytes_to_scalar(chunk)?);
277    }
278
279    // Hash up the tree
280    for _ in 0..depth {
281        let mut next_level = Vec::with_capacity(current_level.len() / 2);
282        for pair in current_level.chunks(2) {
283            next_level.push(poseidon2_compression(pair[0], pair[1]));
284        }
285        current_level = next_level;
286    }
287
288    Ok(scalar_to_bytes(&current_level[0]))
289}