1use alloc::{collections::BTreeMap, vec::Vec};
11
12use wasm_bindgen::prelude::*;
13use zkhash::{ark_ff::PrimeField, fields::bn256::FpBN256 as Scalar};
14
15use crate::{
16 crypto::{poseidon2_compression, poseidon2_hash2_internal},
17 serialization::{bytes_to_scalar, scalar_to_bytes},
18};
19
20fn poseidon2_hash_leaf(key: Scalar, value: Scalar) -> Scalar {
22 poseidon2_hash2_internal(key, value, Some(Scalar::from(1u64)))
23}
24
25fn scalar_to_bits(scalar: &Scalar) -> Vec<bool> {
27 let bigint = scalar.into_bigint();
28 let mut bits = Vec::with_capacity(256);
29
30 for limb in bigint.0.iter() {
31 for i in 0..64 {
32 bits.push((limb >> i) & 1 == 1);
33 }
34 }
35
36 bits.truncate(256);
37 bits
38}
39
40#[derive(Clone, Debug)]
42enum Node {
43 Empty,
45 Leaf { key: Scalar, value: Scalar },
47 Internal { left: Scalar, right: Scalar },
49}
50
51#[derive(Clone, Debug)]
53pub struct FindResult {
54 pub found: bool,
56 pub siblings: Vec<Scalar>,
58 pub found_value: Scalar,
60 pub not_found_key: Scalar,
62 pub not_found_value: Scalar,
64 pub is_old0: bool,
66}
67
68#[derive(Clone, Debug)]
70pub struct SMTResult {
71 pub old_root: Scalar,
73 pub new_root: Scalar,
75 pub siblings: Vec<Scalar>,
77 pub old_key: Scalar,
79 pub old_value: Scalar,
81 pub new_key: Scalar,
83 pub new_value: Scalar,
85 pub is_old0: bool,
87}
88
89pub struct SparseMerkleTree {
91 db: BTreeMap<[u8; 32], Node>,
93 root: Scalar,
95}
96
97impl Default for SparseMerkleTree {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103impl SparseMerkleTree {
104 pub fn new() -> Self {
106 SparseMerkleTree {
107 db: BTreeMap::new(),
108 root: Scalar::from(0u64),
109 }
110 }
111
112 pub fn root(&self) -> Scalar {
114 self.root
115 }
116
117 fn scalar_to_key(s: &Scalar) -> [u8; 32] {
119 let mut key = [0u8; 32];
120 let bytes = scalar_to_bytes(s);
121 key.copy_from_slice(&bytes);
122 key
123 }
124
125 fn get_node(&self, hash: &Scalar) -> Option<&Node> {
127 if *hash == Scalar::from(0u64) {
128 return Some(&Node::Empty);
129 }
130 self.db.get(&Self::scalar_to_key(hash))
131 }
132
133 fn put_node(&mut self, hash: Scalar, node: Node) {
135 if hash != Scalar::from(0u64) {
136 self.db.insert(Self::scalar_to_key(&hash), node);
137 }
138 }
139
140 pub fn find(&self, key: &Scalar) -> Result<FindResult, &'static str> {
142 let key_bits = scalar_to_bits(key);
143 let mut result = self.find_internal(key, &key_bits, &self.root, 0)?;
144 result.siblings.reverse();
145 Ok(result)
146 }
147
148 fn find_internal(
149 &self,
150 key: &Scalar,
151 key_bits: &[bool],
152 current_hash: &Scalar,
153 level: usize,
154 ) -> Result<FindResult, &'static str> {
155 if level >= 256 {
156 return Err("Maximum tree depth exceeded");
157 }
158
159 if *current_hash == Scalar::from(0u64) {
160 return Ok(FindResult {
161 found: false,
162 siblings: Vec::new(),
163 found_value: Scalar::from(0u64),
164 not_found_key: *key,
165 not_found_value: Scalar::from(0u64),
166 is_old0: true,
167 });
168 }
169
170 match self.get_node(current_hash) {
171 Some(Node::Leaf {
172 key: leaf_key,
173 value: leaf_value,
174 }) => {
175 if leaf_key == key {
176 Ok(FindResult {
177 found: true,
178 siblings: Vec::new(),
179 found_value: *leaf_value,
180 not_found_key: Scalar::from(0u64),
181 not_found_value: Scalar::from(0u64),
182 is_old0: false,
183 })
184 } else {
185 Ok(FindResult {
186 found: false,
187 siblings: Vec::new(),
188 found_value: Scalar::from(0u64),
189 not_found_key: *leaf_key,
190 not_found_value: *leaf_value,
191 is_old0: false,
192 })
193 }
194 }
195 Some(Node::Internal { left, right }) => {
196 let (child, sibling) = if key_bits[level] {
197 (right, left)
198 } else {
199 (left, right)
200 };
201
202 let next_level = level
203 .checked_add(1)
204 .ok_or("Level overflow in find_internal")?;
205 let mut result = self.find_internal(key, key_bits, child, next_level)?;
206 result.siblings.push(*sibling);
207 Ok(result)
208 }
209 Some(Node::Empty) => Ok(FindResult {
210 found: false,
211 siblings: Vec::new(),
212 found_value: Scalar::from(0u64),
213 not_found_key: *key,
214 not_found_value: Scalar::from(0u64),
215 is_old0: true,
216 }),
217 None => Err("Node not found in database"),
218 }
219 }
220
221 pub fn insert(&mut self, key: &Scalar, value: &Scalar) -> Result<SMTResult, &'static str> {
223 let find_result = self.find(key)?;
224
225 if find_result.found {
226 return Err("Key already exists");
227 }
228
229 let old_root = self.root;
230 let key_bits = scalar_to_bits(key);
231
232 let new_leaf_hash = poseidon2_hash_leaf(*key, *value);
234 self.put_node(
235 new_leaf_hash,
236 Node::Leaf {
237 key: *key,
238 value: *value,
239 },
240 );
241
242 let mut current_hash = new_leaf_hash;
244 let mut siblings = find_result.siblings.clone();
245
246 if !find_result.is_old0 {
249 let old_key_bits = scalar_to_bits(&find_result.not_found_key);
250
251 let mut diverge_level = siblings.len();
253 while diverge_level < 256 && old_key_bits[diverge_level] == key_bits[diverge_level] {
254 siblings.push(Scalar::from(0u64));
255 diverge_level = diverge_level.saturating_add(1);
256 }
257
258 let old_leaf_hash =
260 poseidon2_hash_leaf(find_result.not_found_key, find_result.not_found_value);
261 siblings.push(old_leaf_hash);
262 }
263
264 for (level, sibling) in siblings.iter().enumerate().rev() {
266 let (left, right) = if key_bits[level] {
267 (*sibling, current_hash)
268 } else {
269 (current_hash, *sibling)
270 };
271
272 current_hash = poseidon2_compression(left, right);
273 self.put_node(current_hash, Node::Internal { left, right });
274 }
275
276 self.root = current_hash;
277
278 let mut result_siblings = siblings;
280 while result_siblings.last() == Some(&Scalar::from(0u64)) {
281 result_siblings.pop();
282 }
283 if !find_result.is_old0 && !result_siblings.is_empty() {
285 result_siblings.pop();
286 }
287
288 Ok(SMTResult {
289 old_root,
290 new_root: self.root,
291 siblings: result_siblings,
292 old_key: find_result.not_found_key,
293 old_value: find_result.not_found_value,
294 new_key: *key,
295 new_value: *value,
296 is_old0: find_result.is_old0,
297 })
298 }
299
300 pub fn update(&mut self, key: &Scalar, new_value: &Scalar) -> Result<SMTResult, &'static str> {
302 let find_result = self.find(key)?;
303
304 if !find_result.found {
305 return Err("Key does not exist");
306 }
307
308 let old_root = self.root;
309 let old_value = find_result.found_value;
310 let key_bits = scalar_to_bits(key);
311
312 let new_leaf_hash = poseidon2_hash_leaf(*key, *new_value);
314 self.put_node(
315 new_leaf_hash,
316 Node::Leaf {
317 key: *key,
318 value: *new_value,
319 },
320 );
321
322 let mut current_hash = new_leaf_hash;
324 for (level, sibling) in find_result.siblings.iter().enumerate().rev() {
325 let (left, right) = if key_bits[level] {
326 (*sibling, current_hash)
327 } else {
328 (current_hash, *sibling)
329 };
330
331 current_hash = poseidon2_compression(left, right);
332 self.put_node(current_hash, Node::Internal { left, right });
333 }
334
335 self.root = current_hash;
336
337 Ok(SMTResult {
338 old_root,
339 new_root: self.root,
340 siblings: find_result.siblings,
341 old_key: *key,
342 old_value,
343 new_key: *key,
344 new_value: *new_value,
345 is_old0: false,
346 })
347 }
348}
349
350#[wasm_bindgen]
352pub struct WasmSparseMerkleTree {
353 inner: SparseMerkleTree,
354}
355
356#[wasm_bindgen]
357impl WasmSparseMerkleTree {
358 #[wasm_bindgen(constructor)]
360 pub fn new() -> WasmSparseMerkleTree {
361 WasmSparseMerkleTree {
362 inner: SparseMerkleTree::new(),
363 }
364 }
365
366 #[wasm_bindgen]
368 pub fn root(&self) -> Vec<u8> {
369 scalar_to_bytes(&self.inner.root())
370 }
371
372 #[wasm_bindgen]
378 pub fn insert(
379 &mut self,
380 key_bytes: &[u8],
381 value_bytes: &[u8],
382 ) -> Result<WasmSMTResult, JsValue> {
383 let key = bytes_to_scalar(key_bytes)?;
384 let value = bytes_to_scalar(value_bytes)?;
385
386 let result = self.inner.insert(&key, &value).map_err(JsValue::from_str)?;
387
388 Ok(WasmSMTResult::from_result(&result))
389 }
390
391 #[wasm_bindgen]
393 pub fn update(
394 &mut self,
395 key_bytes: &[u8],
396 new_value_bytes: &[u8],
397 ) -> Result<WasmSMTResult, JsValue> {
398 let key = bytes_to_scalar(key_bytes)?;
399 let new_value = bytes_to_scalar(new_value_bytes)?;
400
401 let result = self
402 .inner
403 .update(&key, &new_value)
404 .map_err(JsValue::from_str)?;
405
406 Ok(WasmSMTResult::from_result(&result))
407 }
408
409 #[wasm_bindgen]
411 pub fn find(&self, key_bytes: &[u8]) -> Result<WasmFindResult, JsValue> {
412 let key = bytes_to_scalar(key_bytes)?;
413
414 let result = self.inner.find(&key).map_err(JsValue::from_str)?;
415
416 Ok(WasmFindResult::from_result(&result, &self.inner.root()))
417 }
418
419 #[wasm_bindgen]
421 pub fn get_proof(&self, key_bytes: &[u8], max_levels: usize) -> Result<WasmSMTProof, JsValue> {
422 let key = bytes_to_scalar(key_bytes)?;
423
424 let find_result = self.inner.find(&key).map_err(JsValue::from_str)?;
425
426 let mut siblings = find_result.siblings.clone();
428 while siblings.len() < max_levels {
429 siblings.push(Scalar::from(0u64));
430 }
431
432 Ok(WasmSMTProof {
433 found: find_result.found,
434 siblings: siblings.iter().flat_map(scalar_to_bytes).collect(),
435 found_value: scalar_to_bytes(&find_result.found_value),
436 not_found_key: scalar_to_bytes(&find_result.not_found_key),
437 not_found_value: scalar_to_bytes(&find_result.not_found_value),
438 is_old0: find_result.is_old0,
439 root: scalar_to_bytes(&self.inner.root()),
440 num_siblings: siblings.len(),
441 })
442 }
443}
444
445impl Default for WasmSparseMerkleTree {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
451#[wasm_bindgen]
453pub struct WasmSMTResult {
454 old_root: Vec<u8>,
455 new_root: Vec<u8>,
456 siblings: Vec<u8>,
457 old_key: Vec<u8>,
458 old_value: Vec<u8>,
459 new_key: Vec<u8>,
460 new_value: Vec<u8>,
461 is_old0: bool,
462 num_siblings: usize,
463}
464
465#[wasm_bindgen]
466impl WasmSMTResult {
467 #[wasm_bindgen(getter)]
469 pub fn old_root(&self) -> Vec<u8> {
470 self.old_root.clone()
471 }
472
473 #[wasm_bindgen(getter)]
475 pub fn new_root(&self) -> Vec<u8> {
476 self.new_root.clone()
477 }
478
479 #[wasm_bindgen(getter)]
481 pub fn siblings(&self) -> Vec<u8> {
482 self.siblings.clone()
483 }
484
485 #[wasm_bindgen(getter)]
487 pub fn num_siblings(&self) -> usize {
488 self.num_siblings
489 }
490
491 #[wasm_bindgen(getter)]
493 pub fn old_key(&self) -> Vec<u8> {
494 self.old_key.clone()
495 }
496
497 #[wasm_bindgen(getter)]
499 pub fn old_value(&self) -> Vec<u8> {
500 self.old_value.clone()
501 }
502
503 #[wasm_bindgen(getter)]
505 pub fn new_key(&self) -> Vec<u8> {
506 self.new_key.clone()
507 }
508
509 #[wasm_bindgen(getter)]
511 pub fn new_value(&self) -> Vec<u8> {
512 self.new_value.clone()
513 }
514
515 #[wasm_bindgen(getter)]
517 pub fn is_old0(&self) -> bool {
518 self.is_old0
519 }
520}
521
522impl WasmSMTResult {
523 fn from_result(r: &SMTResult) -> Self {
524 WasmSMTResult {
525 old_root: scalar_to_bytes(&r.old_root),
526 new_root: scalar_to_bytes(&r.new_root),
527 siblings: r.siblings.iter().flat_map(scalar_to_bytes).collect(),
528 old_key: scalar_to_bytes(&r.old_key),
529 old_value: scalar_to_bytes(&r.old_value),
530 new_key: scalar_to_bytes(&r.new_key),
531 new_value: scalar_to_bytes(&r.new_value),
532 is_old0: r.is_old0,
533 num_siblings: r.siblings.len(),
534 }
535 }
536}
537
538#[wasm_bindgen]
540pub struct WasmFindResult {
541 found: bool,
542 siblings: Vec<u8>,
543 found_value: Vec<u8>,
544 not_found_key: Vec<u8>,
545 not_found_value: Vec<u8>,
546 is_old0: bool,
547 root: Vec<u8>,
548 num_siblings: usize,
549}
550
551#[wasm_bindgen]
552impl WasmFindResult {
553 #[wasm_bindgen(getter)]
555 pub fn found(&self) -> bool {
556 self.found
557 }
558
559 #[wasm_bindgen(getter)]
561 pub fn siblings(&self) -> Vec<u8> {
562 self.siblings.clone()
563 }
564
565 #[wasm_bindgen(getter)]
567 pub fn num_siblings(&self) -> usize {
568 self.num_siblings
569 }
570
571 #[wasm_bindgen(getter)]
573 pub fn found_value(&self) -> Vec<u8> {
574 self.found_value.clone()
575 }
576
577 #[wasm_bindgen(getter)]
579 pub fn not_found_key(&self) -> Vec<u8> {
580 self.not_found_key.clone()
581 }
582
583 #[wasm_bindgen(getter)]
585 pub fn not_found_value(&self) -> Vec<u8> {
586 self.not_found_value.clone()
587 }
588
589 #[wasm_bindgen(getter)]
591 pub fn is_old0(&self) -> bool {
592 self.is_old0
593 }
594
595 #[wasm_bindgen(getter)]
597 pub fn root(&self) -> Vec<u8> {
598 self.root.clone()
599 }
600}
601
602impl WasmFindResult {
603 fn from_result(r: &FindResult, root: &Scalar) -> Self {
604 WasmFindResult {
605 found: r.found,
606 siblings: r.siblings.iter().flat_map(scalar_to_bytes).collect(),
607 found_value: scalar_to_bytes(&r.found_value),
608 not_found_key: scalar_to_bytes(&r.not_found_key),
609 not_found_value: scalar_to_bytes(&r.not_found_value),
610 is_old0: r.is_old0,
611 root: scalar_to_bytes(root),
612 num_siblings: r.siblings.len(),
613 }
614 }
615}
616
617#[wasm_bindgen]
619pub struct WasmSMTProof {
620 found: bool,
621 siblings: Vec<u8>,
622 found_value: Vec<u8>,
623 not_found_key: Vec<u8>,
624 not_found_value: Vec<u8>,
625 is_old0: bool,
626 root: Vec<u8>,
627 num_siblings: usize,
628}
629
630#[wasm_bindgen]
631impl WasmSMTProof {
632 #[wasm_bindgen(getter)]
634 pub fn found(&self) -> bool {
635 self.found
636 }
637
638 #[wasm_bindgen(getter)]
640 pub fn siblings(&self) -> Vec<u8> {
641 self.siblings.clone()
642 }
643
644 #[wasm_bindgen(getter)]
646 pub fn num_siblings(&self) -> usize {
647 self.num_siblings
648 }
649
650 #[wasm_bindgen(getter)]
652 pub fn found_value(&self) -> Vec<u8> {
653 self.found_value.clone()
654 }
655
656 #[wasm_bindgen(getter)]
658 pub fn not_found_key(&self) -> Vec<u8> {
659 self.not_found_key.clone()
660 }
661
662 #[wasm_bindgen(getter)]
664 pub fn not_found_value(&self) -> Vec<u8> {
665 self.not_found_value.clone()
666 }
667
668 #[wasm_bindgen(getter)]
670 pub fn is_old0(&self) -> bool {
671 self.is_old0
672 }
673
674 #[wasm_bindgen(getter)]
676 pub fn root(&self) -> Vec<u8> {
677 self.root.clone()
678 }
679}
680
681#[wasm_bindgen]
683pub fn smt_hash_pair(left: &[u8], right: &[u8]) -> Result<Vec<u8>, JsValue> {
684 let l = bytes_to_scalar(left)?;
685 let r = bytes_to_scalar(right)?;
686 let result = poseidon2_compression(l, r);
687 Ok(scalar_to_bytes(&result))
688}
689
690#[wasm_bindgen]
692pub fn smt_hash_leaf(key: &[u8], value: &[u8]) -> Result<Vec<u8>, JsValue> {
693 let k = bytes_to_scalar(key)?;
694 let v = bytes_to_scalar(value)?;
695 let result = poseidon2_hash_leaf(k, v);
696 Ok(scalar_to_bytes(&result))
697}