circuits/core/
merkle.rs

1//! Merkle tree utilities using Poseidon2 hash
2//!
3//! Provides merkle tree operations for use in ZK circuits. These functions
4//! match the Circom circuit implementations and produce identical roots/proofs.
5
6use alloc::vec::Vec;
7use core::ops::Add;
8use zkhash::{
9    fields::bn256::FpBN256 as Scalar,
10    poseidon2::{poseidon2::Poseidon2, poseidon2_instance_bn256::POSEIDON2_BN256_PARAMS_2},
11};
12
13/// Poseidon2 compression for merkle tree nodes
14///
15/// Computes `P(left, right)[0] + left` where P is the Poseidon2 permutation.
16/// This matches the feed-forward compression used in Circom circuits.
17#[inline]
18pub fn poseidon2_compression(left: Scalar, right: Scalar) -> Scalar {
19    let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS_2);
20    let input = [left, right];
21    let perm = poseidon2.permutation(&input);
22    perm[0].add(input[0])
23}
24
25/// Build a Merkle root from a full list of leaves
26///
27/// Computes the Merkle root by repeatedly hashing pairs of nodes until
28/// a single root remains.
29///
30/// # Panics
31///
32/// Panics if `leaves` is empty.
33pub fn merkle_root(mut leaves: Vec<Scalar>) -> Scalar {
34    assert!(!leaves.is_empty(), "leaves cannot be empty");
35    assert!(
36        leaves.len().is_power_of_two(),
37        "leaves length must be a power of 2"
38    );
39    while leaves.len() > 1 {
40        let mut next = Vec::with_capacity(leaves.len() / 2);
41        for pair in leaves.chunks_exact(2) {
42            next.push(poseidon2_compression(pair[0], pair[1]));
43        }
44        leaves = next;
45    }
46    leaves[0]
47}
48
49/// Compute the Merkle path and path index bits for a given leaf
50/// index
51///
52/// Generates the Merkle proof for a leaf at the given index, including all
53/// sibling nodes along the path to the root and the path indices encoded as
54/// a bit pattern.
55///
56/// # Returns
57///
58/// Returns a tuple containing:
59/// - `path_elements`: Vector of sibling scalar values along the path
60/// - `path_indices`: Path indices encoded as a u64 bit pattern
61/// - `levels`: Number of levels in the tree
62pub fn merkle_proof(leaves: &[Scalar], mut index: usize) -> (Vec<Scalar>, u64, usize) {
63    assert!(!leaves.is_empty() && leaves.len().is_power_of_two());
64    let mut level_nodes = leaves.to_vec();
65    let levels = level_nodes.len().ilog2() as usize;
66
67    let mut path_elems = Vec::with_capacity(levels);
68    let mut path_indices_bits_lsb = Vec::with_capacity(levels);
69
70    for _level in 0..levels {
71        let sib_index = if index.is_multiple_of(2) {
72            index.checked_add(1).expect("sibling index overflow")
73        } else {
74            index.checked_sub(1).expect("sibling index underflow")
75        };
76
77        path_elems.push(level_nodes[sib_index]);
78        path_indices_bits_lsb.push((index & 1) as u64);
79
80        let mut next = Vec::with_capacity(leaves.len() / 2);
81        for pair in level_nodes.chunks_exact(2) {
82            next.push(poseidon2_compression(pair[0], pair[1]));
83        }
84        level_nodes = next;
85        index /= 2;
86    }
87
88    let mut path_indices: u64 = 0;
89    for (i, b) in path_indices_bits_lsb.iter().copied().enumerate() {
90        path_indices |= b << i;
91    }
92
93    (path_elems, path_indices, levels)
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Recompute a root from `leaf` and its authentication path. It uses the
    /// bit convention `merkle_proof` emits: bit `level` set means the sibling
    /// sits on the left at that level.
    fn root_from_path(leaf: Scalar, path: &[Scalar], indices: u64, levels: usize) -> Scalar {
        path.iter()
            .take(levels)
            .enumerate()
            .fold(leaf, |node, (level, &sibling)| {
                if (indices >> level) & 1 == 1 {
                    poseidon2_compression(sibling, node)
                } else {
                    poseidon2_compression(node, sibling)
                }
            })
    }

    #[test]
    fn test_merkle_root_single_leaf() {
        let only = Scalar::from(42u64);
        assert_eq!(merkle_root(alloc::vec![only]), only);
    }

    #[test]
    fn test_merkle_root_two_leaves() {
        let (a, b) = (Scalar::from(1u64), Scalar::from(2u64));
        assert_eq!(merkle_root(alloc::vec![a, b]), poseidon2_compression(a, b));
    }

    #[test]
    fn test_merkle_proof_basics() {
        let leaves: Vec<Scalar> = (0..4).map(Scalar::from).collect();
        let (path, indices, levels) = merkle_proof(&leaves, 0);

        assert_eq!(levels, 2);
        assert_eq!(path.len(), 2);
        assert_eq!(indices, 0);
    }

    #[test]
    fn test_merkle_proof_verifies() {
        let leaves: Vec<Scalar> = (0..4).map(Scalar::from).collect();
        let root = merkle_root(leaves.clone());

        for (idx, &leaf) in leaves.iter().enumerate() {
            let (path, indices, levels) = merkle_proof(&leaves, idx);
            assert_eq!(
                root_from_path(leaf, &path, indices, levels),
                root,
                "Proof verification failed for index {}",
                idx
            );
        }
    }
}