1use alloc::vec::Vec;
7use core::ops::Add;
8use zkhash::{
9 fields::bn256::FpBN256 as Scalar,
10 poseidon2::{poseidon2::Poseidon2, poseidon2_instance_bn256::POSEIDON2_BN256_PARAMS_2},
11};
12
13#[inline]
18pub fn poseidon2_compression(left: Scalar, right: Scalar) -> Scalar {
19 let poseidon2 = Poseidon2::new(&POSEIDON2_BN256_PARAMS_2);
20 let input = [left, right];
21 let perm = poseidon2.permutation(&input);
22 perm[0].add(input[0])
23}
24
25pub fn merkle_root(mut leaves: Vec<Scalar>) -> Scalar {
34 assert!(!leaves.is_empty(), "leaves cannot be empty");
35 assert!(
36 leaves.len().is_power_of_two(),
37 "leaves length must be a power of 2"
38 );
39 while leaves.len() > 1 {
40 let mut next = Vec::with_capacity(leaves.len() / 2);
41 for pair in leaves.chunks_exact(2) {
42 next.push(poseidon2_compression(pair[0], pair[1]));
43 }
44 leaves = next;
45 }
46 leaves[0]
47}
48
49pub fn merkle_proof(leaves: &[Scalar], mut index: usize) -> (Vec<Scalar>, u64, usize) {
63 assert!(!leaves.is_empty() && leaves.len().is_power_of_two());
64 let mut level_nodes = leaves.to_vec();
65 let levels = level_nodes.len().ilog2() as usize;
66
67 let mut path_elems = Vec::with_capacity(levels);
68 let mut path_indices_bits_lsb = Vec::with_capacity(levels);
69
70 for _level in 0..levels {
71 let sib_index = if index.is_multiple_of(2) {
72 index.checked_add(1).expect("sibling index overflow")
73 } else {
74 index.checked_sub(1).expect("sibling index underflow")
75 };
76
77 path_elems.push(level_nodes[sib_index]);
78 path_indices_bits_lsb.push((index & 1) as u64);
79
80 let mut next = Vec::with_capacity(leaves.len() / 2);
81 for pair in level_nodes.chunks_exact(2) {
82 next.push(poseidon2_compression(pair[0], pair[1]));
83 }
84 level_nodes = next;
85 index /= 2;
86 }
87
88 let mut path_indices: u64 = 0;
89 for (i, b) in path_indices_bits_lsb.iter().copied().enumerate() {
90 path_indices |= b << i;
91 }
92
93 (path_elems, path_indices, levels)
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `size` distinct leaves with values `0, 1, ..., size - 1`.
    fn sample_leaves(size: u64) -> Vec<Scalar> {
        (0..size).map(Scalar::from).collect()
    }

    /// Recomputes a root from `leaf` and its proof.
    ///
    /// At each level the sibling is hashed on the left when the corresponding
    /// bit of `indices` is set, and on the right otherwise.
    fn fold_proof(leaf: Scalar, path: &[Scalar], indices: u64, levels: usize) -> Scalar {
        path.iter()
            .take(levels)
            .enumerate()
            .fold(leaf, |current, (level, sibling)| {
                if (indices >> level) & 1 == 1 {
                    poseidon2_compression(*sibling, current)
                } else {
                    poseidon2_compression(current, *sibling)
                }
            })
    }

    #[test]
    fn test_merkle_root_single_leaf() {
        let leaf = Scalar::from(42u64);
        let root = merkle_root(alloc::vec![leaf]);
        assert_eq!(root, leaf);
    }

    #[test]
    fn test_merkle_root_two_leaves() {
        let leaves = alloc::vec![Scalar::from(1u64), Scalar::from(2u64)];
        let root = merkle_root(leaves.clone());
        let expected = poseidon2_compression(leaves[0], leaves[1]);
        assert_eq!(root, expected);
    }

    #[test]
    fn test_merkle_root_four_leaves() {
        let leaves = sample_leaves(4);
        let left = poseidon2_compression(leaves[0], leaves[1]);
        let right = poseidon2_compression(leaves[2], leaves[3]);
        assert_eq!(merkle_root(leaves), poseidon2_compression(left, right));
    }

    #[test]
    #[should_panic(expected = "leaves cannot be empty")]
    fn test_merkle_root_rejects_empty() {
        let _ = merkle_root(Vec::new());
    }

    #[test]
    #[should_panic(expected = "leaves length must be a power of 2")]
    fn test_merkle_root_rejects_non_power_of_two() {
        let _ = merkle_root(sample_leaves(3));
    }

    #[test]
    fn test_merkle_proof_basics() {
        let leaves = sample_leaves(4);
        let (path, indices, levels) = merkle_proof(&leaves, 0);

        assert_eq!(levels, 2);
        assert_eq!(path.len(), 2);
        assert_eq!(indices, 0);
    }

    #[test]
    fn test_merkle_proof_single_leaf() {
        let leaves = sample_leaves(1);
        let (path, indices, levels) = merkle_proof(&leaves, 0);

        assert!(path.is_empty());
        assert_eq!(indices, 0);
        assert_eq!(levels, 0);
    }

    #[test]
    fn test_merkle_proof_verifies() {
        for size in [1u64, 2, 4, 8, 16] {
            let leaves = sample_leaves(size);
            let root = merkle_root(leaves.clone());

            for (idx, &leaf) in leaves.iter().enumerate() {
                let (path, indices, levels) = merkle_proof(&leaves, idx);

                assert_eq!(path.len(), levels);
                assert_eq!(
                    indices, idx as u64,
                    "path bits must spell out the leaf index"
                );
                assert_eq!(
                    fold_proof(leaf, &path, indices, levels),
                    root,
                    "Proof verification failed for size {} index {}",
                    size,
                    idx
                );
            }
        }
    }
}