1use blake3::hazmat;
2use itertools::Itertools;
3use std::{
4 fmt,
5 fmt::{Display, Formatter},
6 ops::Deref,
7};
8use thiserror::Error;
9use zinc_transcript::traits::{ConstTranscribable, GenTranscribable};
10use zinc_utils::{add, cfg_into_iter, cfg_iter, sub};
11
12#[cfg(feature = "parallel")]
13use rayon::prelude::*;
14
15pub const HASH_OUT_LEN: usize = blake3::OUT_LEN;
16
17#[derive(Clone, Debug, PartialEq, Eq)]
18#[repr(transparent)]
19pub struct MtHash(pub(crate) [u8; HASH_OUT_LEN]);
20
21impl Default for MtHash {
22 fn default() -> Self {
23 MtHash([0; HASH_OUT_LEN])
24 }
25}
26
27impl Deref for MtHash {
28 type Target = [u8];
29
30 fn deref(&self) -> &Self::Target {
31 &self.0
32 }
33}
34
35impl Display for MtHash {
36 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37 let blake3_hash: blake3::Hash = self.0.into();
38 <blake3::Hash as Display>::fmt(&blake3_hash, f)
39 }
40}
41
42impl GenTranscribable for MtHash {
43 fn read_transcription_bytes_exact(buf: &[u8]) -> Self {
44 assert_eq!(buf.len(), HASH_OUT_LEN);
45 MtHash(buf.try_into().expect("Invalid buffer length for MtHash"))
46 }
47
48 fn write_transcription_bytes_exact(&self, buf: &mut [u8]) {
49 assert_eq!(buf.len(), HASH_OUT_LEN);
50 buf.copy_from_slice(&self.0);
51 }
52}
53
54impl ConstTranscribable for MtHash {
55 const NUM_BYTES: usize = HASH_OUT_LEN;
56}
57
58impl<B> From<B> for MtHash
59where
60 B: Into<[u8; HASH_OUT_LEN]>,
61{
62 fn from(b: B) -> Self {
63 MtHash(b.into())
64 }
65}
66
67#[derive(Debug, Clone, Default)]
68pub struct MerkleTree {
69 layers: Vec<Vec<MtHash>>,
71}
72
73impl MerkleTree {
74 pub fn new<S>(rows: &[&[S]]) -> Self
75 where
76 S: ConstTranscribable + Clone + Send + Sync,
77 {
78 assert!(!rows.is_empty());
79 let row_width = rows[0].len();
80 assert!(row_width > 0);
81 assert!(
82 rows.iter().all(|row| row.len() == row_width),
83 "All rows must have the same width"
84 );
85 assert!(row_width.is_power_of_two());
86
87 let leaves = hash_leaves(rows, row_width);
88 build_merkle_tree_from_leaves(leaves)
89 }
90
91 pub fn height(&self) -> usize {
92 self.layers.len()
93 }
94
95 pub fn root(&self) -> MtHash {
96 self.layers
97 .last()
98 .expect("Merkle tree must have at least one layer")
99 .first()
100 .cloned()
101 .expect("Merkle tree must have a root")
102 }
103
104 pub fn prove(&self, leaf_index: usize) -> Result<MerkleProof, MerkleError> {
106 let leaf_count = self.layers[0].len();
107
108 if leaf_index >= leaf_count || leaf_count == 0 {
109 return Err(MerkleError::InvalidLeafIndex(leaf_index));
110 }
111
112 let siblings = build_sibling_path(leaf_index, &self.layers);
114
115 Ok(MerkleProof {
116 leaf_index,
117 leaf_count,
118 siblings,
119 })
120 }
121}
122
123#[allow(clippy::arithmetic_side_effects)]
128fn hash_column<S: ConstTranscribable>(values: &[S]) -> MtHash {
129 let elem_bytes = S::NUM_BYTES;
130 let mut buf = vec![0_u8; values.len() * elem_bytes];
131 for (i, v) in values.iter().enumerate() {
132 let start = i * elem_bytes;
133 v.write_transcription_bytes_exact(&mut buf[start..start + elem_bytes]);
134 }
135 let mut hasher = blake3::Hasher::new();
136 hasher.update(&buf);
137 hasher.finalize().into()
138}
139
140#[allow(clippy::arithmetic_side_effects)]
148fn hash_leaves<S>(rows: &[&[S]], m_cols: usize) -> Vec<MtHash>
149where
150 S: ConstTranscribable + Send + Sync,
151{
152 let num_rows = rows.len();
153 let elem_bytes = S::NUM_BYTES;
154 let col_bytes = num_rows * elem_bytes;
155
156 cfg_into_iter!(0..m_cols)
157 .map(|i| {
158 let mut buf = vec![0_u8; col_bytes];
159 for (r, row) in rows.iter().enumerate() {
160 let start = r * elem_bytes;
161 row[i].write_transcription_bytes_exact(&mut buf[start..start + elem_bytes]);
162 }
163 let mut hasher = blake3::Hasher::new();
164 hasher.update(&buf);
165 hasher.finalize().into()
166 })
167 .collect()
168}
169
170fn build_merkle_tree_from_leaves(leaves: Vec<MtHash>) -> MerkleTree {
173 let n = leaves.len();
174
175 if n == 0 {
176 return MerkleTree {
177 layers: vec![vec![blake3::hash(&[]).into()]],
178 };
179 }
180 assert!(
181 n.is_power_of_two(),
182 "Number of leaves must be a power of two"
183 );
184
185 if n == 1 {
186 return MerkleTree {
187 layers: vec![leaves],
188 };
189 }
190
191 let root_layer_idx = n.trailing_zeros() as usize; let num_layers = add!(root_layer_idx, 1);
195 let mut layers: Vec<Vec<MtHash>> = Vec::with_capacity(num_layers);
196
197 layers.push(leaves);
199
200 for layer_idx in 1..num_layers {
202 let is_root_layer = layer_idx == root_layer_idx;
203
204 let prev_layer = &layers[sub!(layer_idx, 1)];
205 let (prev_layer_chunks, _) = prev_layer.as_chunks::<2>();
206
207 let current_layer = cfg_iter!(prev_layer_chunks)
208 .map(|[left, right]| {
209 if is_root_layer {
210 hazmat::merge_subtrees_root(&left.0, &right.0, hazmat::Mode::Hash).into()
211 } else {
212 hazmat::merge_subtrees_non_root(&left.0, &right.0, hazmat::Mode::Hash).into()
213 }
214 })
215 .collect();
216
217 layers.push(current_layer);
218 }
219
220 MerkleTree { layers }
221}
222
223#[allow(clippy::arithmetic_side_effects)] fn build_sibling_path(target_index: usize, layers: &[Vec<MtHash>]) -> Vec<MtHash> {
225 let mut siblings = Vec::new();
226 let mut layer_idx = 0;
227 let mut current_layer = &layers[layer_idx];
228 let mut current_index = target_index;
229
230 loop {
231 let is_left_child = current_index.is_multiple_of(2);
233
234 if is_left_child {
235 let sibling_index = current_index + 1;
237 if sibling_index < current_layer.len() {
238 siblings.push(current_layer[sibling_index].clone());
239 } else {
240 debug_assert_eq!(layer_idx, layers.len() - 1);
242 debug_assert_eq!(current_layer.len(), 1);
243 break;
244 }
245 } else {
246 let sibling_index = current_index - 1;
248 siblings.push(current_layer[sibling_index].clone());
249 }
250
251 current_index /= 2;
252 layer_idx += 1;
253 current_layer = &layers[layer_idx];
254 }
255
256 siblings
257}
258
259#[derive(Clone, Debug, PartialEq, Eq)]
260pub struct MerkleProof {
261 pub leaf_index: usize,
263 pub leaf_count: usize,
265 pub siblings: Vec<MtHash>,
267}
268
269impl MerkleProof {
270 pub fn new(leaf_index: usize, leaf_count: usize, siblings: Vec<MtHash>) -> Self {
271 assert!(!siblings.is_empty(), "Merkle proof path cannot be empty");
272 assert!(leaf_index < leaf_count, "Leaf index out of bounds");
273 Self {
274 leaf_index,
275 leaf_count,
276 siblings,
277 }
278 }
279
280 pub fn verify<S>(
283 &self,
284 root: &MtHash,
285 column_values: &[S],
286 leaf_index: usize,
287 ) -> Result<(), MerkleError>
288 where
289 S: ConstTranscribable,
290 {
291 if leaf_index != self.leaf_index {
292 return Err(MerkleError::InvalidLeafIndex(leaf_index));
293 }
294
295 let mut current_cv: MtHash = hash_column(column_values);
296
297 if self.leaf_count == 1 {
298 if self.leaf_index == 0 && self.siblings.is_empty() {
299 if ¤t_cv != root {
301 return Err(MerkleError::InvalidRootHash);
302 }
303 return Ok(());
304 } else {
305 return Err(MerkleError::InvalidMerkleProof(
306 "Single element Merkle proof is invalid".to_owned(),
307 ));
308 }
309 }
310
311 let directions = get_path_directions(self.leaf_count, self.leaf_index);
312
313 if directions.len() != self.siblings.len() {
314 return Err(MerkleError::InvalidMerklePathLength {
315 expected: self.siblings.len(),
316 actual: directions.len(),
317 });
318 }
319
320 let mut path_iter = self.siblings.iter().zip(directions.iter());
322
323 let Some((last_sibling, last_direction)) = path_iter.next_back() else {
325 unreachable!("There should always be at least one sibling in the proof");
326 };
327
328 for (sibling_cv, direction) in path_iter {
330 let is_left = matches!(direction, PathDirection::Left);
331 if is_left {
332 current_cv = hazmat::merge_subtrees_non_root(
333 ¤t_cv.0,
334 &sibling_cv.0,
335 hazmat::Mode::Hash,
336 )
337 .into();
338 } else {
339 current_cv = hazmat::merge_subtrees_non_root(
340 &sibling_cv.0,
341 ¤t_cv.0,
342 hazmat::Mode::Hash,
343 )
344 .into();
345 }
346 }
347
348 let final_hash: MtHash = if matches!(last_direction, PathDirection::Left) {
350 hazmat::merge_subtrees_root(¤t_cv.0, &last_sibling.0, hazmat::Mode::Hash).into()
351 } else {
352 hazmat::merge_subtrees_root(&last_sibling.0, ¤t_cv.0, hazmat::Mode::Hash).into()
353 };
354
355 if &final_hash != root {
356 return Err(MerkleError::InvalidRootHash);
357 }
358 Ok(())
359 }
360
361 #[allow(clippy::arithmetic_side_effects)] pub fn estimate_transcribed_size(merkle_tree_height: usize) -> usize {
365 3 * u64::NUM_BYTES + (merkle_tree_height - 1) * MtHash::NUM_BYTES
367 }
368}
369
370impl Display for MerkleProof {
371 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
372 writeln!(f, "Merkle Path: {}", self.siblings.iter().join(", "))?;
373 Ok(())
374 }
375}
376
/// Which side of its parent a node sits on along a proof path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PathDirection {
    /// The node is the left child; its sibling is merged in on the right.
    Left,
    /// The node is the right child; its sibling is merged in on the left.
    Right,
}

/// Computes, leaf level first, whether leaf `target_index` of a `total_chunks`-leaf tree is
/// a left or right child at each level.
///
/// Trees are split left-balanced: the left subtree of a node covering `n > 1` leaves holds
/// the largest power of two strictly smaller than `n`. For power-of-two sizes this is plain
/// halving. One direction is produced per level, so `total_chunks <= 1` yields an empty path.
#[allow(clippy::arithmetic_side_effects)]
fn get_path_directions(total_chunks: usize, target_index: usize) -> Vec<PathDirection> {
    let mut path = Vec::new();
    let mut current_size = total_chunks;
    let mut current_index = target_index;

    while current_size > 1 {
        // Largest power of two strictly below `current_size`. This equals
        // `current_size.next_power_of_two() / 2`, but cannot overflow. That form panics
        // (debug) or returns 0 (release, making this loop infinite) once `current_size`
        // exceeds `usize::MAX / 2 + 1`. `total_chunks` comes from the public,
        // unvalidated `MerkleProof::leaf_count`.
        let split_len = 1_usize << (current_size - 1).ilog2();

        if current_index < split_len {
            path.push(PathDirection::Left);
            current_size = split_len;
        } else {
            path.push(PathDirection::Right);
            current_size -= split_len;
            current_index -= split_len;
        }
    }

    // Directions were discovered root-first, but proofs store siblings leaf-first.
    path.reverse();
    path
}
411
412#[derive(Error, Debug)]
413pub enum MerkleError {
414 #[error("Invalid PCS opening: {0}")]
415 InvalidPcsOpen(String),
416
417 #[error("Invalid Merkle proof: {0}")]
418 InvalidMerkleProof(String),
419
420 #[error("Invalid Merkle path length: expected {expected}, got {actual}")]
421 InvalidMerklePathLength { expected: usize, actual: usize },
422
423 #[error("Invalid leaf index: {0}")]
424 InvalidLeafIndex(usize),
425
426 #[error("Invalid root hash")]
427 InvalidRootHash,
428
429 #[error("Failed to read merkle proof")]
430 FailedMerkleProofReading,
431
432 #[error("Failed to write merkle proof")]
433 FailedMerkleProofWriting,
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::Random;
    use crypto_primitives::crypto_bigint_int::Int;
    use rand::rng;

    /// Every leaf of a 1024-leaf, single-row tree must open correctly against the root.
    #[test]
    fn test_merkle_proof() {
        const LIMBS: usize = 3;
        const LEAF_COUNT: usize = 1024;

        let mut rng = rng();
        let row: Vec<Int<LIMBS>> = (0..LEAF_COUNT).map(|_| Int::random(&mut rng)).collect();

        let tree = MerkleTree::new(&[row.as_slice()]);
        let root = tree.root();

        for (i, value) in row.iter().enumerate() {
            let proof = tree.prove(i).expect("Merkle proof creation failed");
            if let Err(err) = proof.verify(&root, &[*value], i) {
                panic!("Merkle proof verification failed for leaf index {i}: {err}");
            }
        }
    }
}