1use alloc::{format, vec, vec::Vec};
7
8use wasm_bindgen::prelude::*;
9use zkhash::fields::bn256::FpBN256 as Scalar;
10
11use crate::{
12 serialization::{bytes_to_scalar, scalar_to_bytes},
13 types::FIELD_SIZE,
14};
15
16pub use circuits::core::merkle::{
18 merkle_proof as merkle_proof_internal, merkle_root, poseidon2_compression,
19};
20
/// A Merkle inclusion proof in the serialized form handed to JavaScript.
///
/// Produced by [`MerkleTree::get_proof`]; all hash values are carried as
/// raw bytes (`FIELD_SIZE` bytes per scalar).
#[wasm_bindgen]
pub struct MerkleProof {
    // Concatenated serialized sibling hashes, one FIELD_SIZE-byte scalar per
    // level, ordered from the leaf level up to (but excluding) the root.
    path_elements: Vec<u8>,
    // Serialized scalar whose bit `i` is set when the path node at level `i`
    // is a right child (see the bit packing in `MerkleTree::get_proof`).
    path_indices: Vec<u8>,
    // Serialized root hash this proof commits to.
    root: Vec<u8>,
    // Number of levels in the path; equals the tree depth.
    levels: usize,
}
33
34#[wasm_bindgen]
35impl MerkleProof {
36 #[wasm_bindgen(getter)]
38 pub fn path_elements(&self) -> Vec<u8> {
39 self.path_elements.clone()
40 }
41
42 #[wasm_bindgen(getter)]
44 pub fn path_indices(&self) -> Vec<u8> {
45 self.path_indices.clone()
46 }
47
48 #[wasm_bindgen(getter)]
50 pub fn root(&self) -> Vec<u8> {
51 self.root.clone()
52 }
53
54 #[wasm_bindgen(getter)]
56 pub fn levels(&self) -> usize {
57 self.levels
58 }
59}
60
/// A fixed-depth, fully materialized Merkle tree using Poseidon2 compression.
///
/// Every level is stored in full, so lookups and proof generation never
/// recompute hashes; `insert` updates exactly one path from leaf to root.
#[wasm_bindgen]
pub struct MerkleTree {
    // levels_data[0] is the leaf level (2^depth entries);
    // levels_data[depth] has a single entry: the root.
    levels_data: Vec<Vec<Scalar>>,
    // Tree depth; `build_tree` enforces 1..=32.
    depth: usize,
    // Index the next call to `insert` will write to (also the leaf count).
    next_index: u64,
}
71
72#[wasm_bindgen]
75impl MerkleTree {
76 #[wasm_bindgen(constructor)]
78 pub fn new(depth: usize) -> Result<MerkleTree, JsValue> {
79 Self::build_tree(depth, Scalar::from(0u64))
80 }
81
82 #[wasm_bindgen]
90 pub fn new_with_zero_leaf(depth: usize, zero_leaf_bytes: &[u8]) -> Result<MerkleTree, JsValue> {
91 let zero = bytes_to_scalar(zero_leaf_bytes)?;
92 Self::build_tree(depth, zero)
93 }
94
95 fn build_tree(depth: usize, zero: Scalar) -> Result<MerkleTree, JsValue> {
97 if depth == 0 || depth > 32 {
98 return Err(JsValue::from_str("Depth must be between 1 and 32"));
99 }
100
101 let depth_u32 = u32::try_from(depth).expect("Depth didn't fit in u32");
103 let num_leaves = 1usize.checked_shl(depth_u32).ok_or_else(|| {
104 JsValue::from_str("Depth too large for this platform, would overflow")
105 })?;
106
107 let capacity = depth
109 .checked_add(1)
110 .ok_or_else(|| JsValue::from_str("Depth overflow"))?;
111 let mut levels_data = Vec::with_capacity(capacity);
112
113 levels_data.push(vec![zero; num_leaves]);
115
116 let mut current_level_size = num_leaves;
118 let mut prev_hash = zero;
119
120 for _ in 0..depth {
121 current_level_size /= 2;
122 prev_hash = poseidon2_compression(prev_hash, prev_hash);
123 levels_data.push(vec![prev_hash; current_level_size]);
124 }
125
126 Ok(MerkleTree {
127 levels_data,
128 depth,
129 next_index: 0,
130 })
131 }
132
133 #[wasm_bindgen]
135 pub fn insert(&mut self, leaf_bytes: &[u8]) -> Result<u32, JsValue> {
136 let leaf = bytes_to_scalar(leaf_bytes)?;
137 let index = self.next_index;
138
139 let max_leaves = 1u64 << self.depth;
140 if index >= max_leaves {
141 return Err(JsValue::from_str("Merkle tree is full"));
142 }
143
144 let index_usize =
145 usize::try_from(index).map_err(|_| JsValue::from_str("Index too large"))?;
146
147 self.levels_data[0][index_usize] = leaf;
149
150 let mut current_index = index_usize;
152 let mut current_hash = leaf;
153
154 for level in 0..self.depth {
155 let sibling_index = current_index ^ 1; let sibling = self.levels_data[level][sibling_index];
157
158 let (left, right) = if current_index.is_multiple_of(2) {
160 (current_hash, sibling)
161 } else {
162 (sibling, current_hash)
163 };
164
165 current_hash = poseidon2_compression(left, right);
166 current_index /= 2;
167
168 let parent_level = level
170 .checked_add(1)
171 .ok_or_else(|| JsValue::from_str("Level overflow"))?;
172 self.levels_data[parent_level][current_index] = current_hash;
173 }
174
175 self.next_index = self
176 .next_index
177 .checked_add(1)
178 .ok_or_else(|| JsValue::from_str("Index overflow"))?;
179
180 u32::try_from(index).map_err(|_| JsValue::from_str("Index too large for u32"))
182 }
183
184 #[wasm_bindgen]
186 pub fn root(&self) -> Vec<u8> {
187 let root = self.levels_data[self.depth][0];
188 scalar_to_bytes(&root)
189 }
190
191 #[wasm_bindgen]
193 pub fn get_proof(&self, index: u32) -> Result<MerkleProof, JsValue> {
194 let index = usize::try_from(index).map_err(|_| JsValue::from_str("Index too large"))?;
195 let max_leaves = 1usize << self.depth;
196
197 if index >= max_leaves {
198 return Err(JsValue::from_str("Index out of bounds"));
199 }
200
201 let capacity = self
202 .depth
203 .checked_mul(FIELD_SIZE)
204 .ok_or_else(|| JsValue::from_str("Overflow calculating path capacity"))?;
205 let mut path_elements = Vec::with_capacity(capacity);
206 let mut path_indices_bits: u64 = 0;
207 let mut current_index = index;
208
209 for level in 0..self.depth {
210 let sibling_index = current_index ^ 1;
211 let sibling = self.levels_data[level][sibling_index];
212
213 path_elements.extend_from_slice(&scalar_to_bytes(&sibling));
215
216 if !current_index.is_multiple_of(2) {
218 path_indices_bits |= 1u64 << level;
219 }
220
221 current_index /= 2;
222 }
223
224 let path_indices = scalar_to_bytes(&Scalar::from(path_indices_bits));
225 let root = scalar_to_bytes(&self.levels_data[self.depth][0]);
226
227 Ok(MerkleProof {
228 path_elements,
229 path_indices,
230 root,
231 levels: self.depth,
232 })
233 }
234
235 #[wasm_bindgen(getter)]
237 pub fn next_index(&self) -> u64 {
238 self.next_index
239 }
240
241 #[wasm_bindgen(getter)]
243 pub fn depth(&self) -> usize {
244 self.depth
245 }
246}
247
248#[wasm_bindgen]
250pub fn compute_merkle_root(leaves_bytes: &[u8], depth: usize) -> Result<Vec<u8>, JsValue> {
251 if !leaves_bytes.len().is_multiple_of(FIELD_SIZE) {
252 return Err(JsValue::from_str("Leaves bytes must be multiple of 32"));
253 }
254
255 let num_leaves = leaves_bytes.len() / FIELD_SIZE;
256 let expected_leaves = 1usize << depth;
257
258 if num_leaves != expected_leaves {
259 return Err(JsValue::from_str(&format!(
260 "Expected {} leaves for depth {}, got {}",
261 expected_leaves, depth, num_leaves
262 )));
263 }
264
265 let mut current_level: Vec<Scalar> = Vec::with_capacity(num_leaves);
267 for i in 0..num_leaves {
268 let start = i
269 .checked_mul(FIELD_SIZE)
270 .ok_or_else(|| JsValue::from_str("Index overflow"))?;
271 let end = i
272 .checked_add(1)
273 .and_then(|n| n.checked_mul(FIELD_SIZE))
274 .ok_or_else(|| JsValue::from_str("Index overflow"))?;
275 let chunk = &leaves_bytes[start..end];
276 current_level.push(bytes_to_scalar(chunk)?);
277 }
278
279 for _ in 0..depth {
281 let mut next_level = Vec::with_capacity(current_level.len() / 2);
282 for pair in current_level.chunks(2) {
283 next_level.push(poseidon2_compression(pair[0], pair[1]));
284 }
285 current_level = next_level;
286 }
287
288 Ok(scalar_to_bytes(¤t_level[0]))
289}