Skip to main content

zip_plus/
merkle.rs

1use blake3::hazmat;
2use itertools::Itertools;
3use std::{
4    fmt,
5    fmt::{Display, Formatter},
6    ops::Deref,
7};
8use thiserror::Error;
9use zinc_transcript::traits::{ConstTranscribable, GenTranscribable};
10use zinc_utils::{add, cfg_into_iter, cfg_iter, sub};
11
12#[cfg(feature = "parallel")]
13use rayon::prelude::*;
14
/// Length in bytes of every [`MtHash`]; taken directly from BLAKE3's default
/// output length so leaf hashes and merged node values share one size.
pub const HASH_OUT_LEN: usize = blake3::OUT_LEN;
16
/// A `HASH_OUT_LEN`-byte BLAKE3 output stored in a [`MerkleTree`]: leaf
/// (column) hashes, interior merged values, and the root all use this type.
#[derive(Clone, Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct MtHash(pub(crate) [u8; HASH_OUT_LEN]);
20
21impl Default for MtHash {
22    fn default() -> Self {
23        MtHash([0; HASH_OUT_LEN])
24    }
25}
26
27impl Deref for MtHash {
28    type Target = [u8];
29
30    fn deref(&self) -> &Self::Target {
31        &self.0
32    }
33}
34
35impl Display for MtHash {
36    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
37        let blake3_hash: blake3::Hash = self.0.into();
38        <blake3::Hash as Display>::fmt(&blake3_hash, f)
39    }
40}
41
42impl GenTranscribable for MtHash {
43    fn read_transcription_bytes_exact(buf: &[u8]) -> Self {
44        assert_eq!(buf.len(), HASH_OUT_LEN);
45        MtHash(buf.try_into().expect("Invalid buffer length for MtHash"))
46    }
47
48    fn write_transcription_bytes_exact(&self, buf: &mut [u8]) {
49        assert_eq!(buf.len(), HASH_OUT_LEN);
50        buf.copy_from_slice(&self.0);
51    }
52}
53
/// A hash always occupies exactly `HASH_OUT_LEN` transcript bytes, which is
/// the buffer length the `GenTranscribable` impl for `MtHash` asserts on.
impl ConstTranscribable for MtHash {
    const NUM_BYTES: usize = HASH_OUT_LEN;
}
57
58impl<B> From<B> for MtHash
59where
60    B: Into<[u8; HASH_OUT_LEN]>,
61{
62    fn from(b: B) -> Self {
63        MtHash(b.into())
64    }
65}
66
/// A binary Merkle tree over column hashes, stored layer by layer.
///
/// Built with [`MerkleTree::new`]. A `MerkleTree::default()` has no layers at
/// all (not even a root).
#[derive(Debug, Clone, Default)]
pub struct MerkleTree {
    /// First vector is leaves, last vector is root (a single hash); each layer
    /// above the leaves holds one merged hash per adjacent pair of the layer
    /// below it.
    layers: Vec<Vec<MtHash>>,
}
72
73impl MerkleTree {
74    pub fn new<S>(rows: &[&[S]]) -> Self
75    where
76        S: ConstTranscribable + Clone + Send + Sync,
77    {
78        assert!(!rows.is_empty());
79        let row_width = rows[0].len();
80        assert!(row_width > 0);
81        assert!(
82            rows.iter().all(|row| row.len() == row_width),
83            "All rows must have the same width"
84        );
85        assert!(row_width.is_power_of_two());
86
87        let leaves = hash_leaves(rows, row_width);
88        build_merkle_tree_from_leaves(leaves)
89    }
90
91    pub fn height(&self) -> usize {
92        self.layers.len()
93    }
94
95    pub fn root(&self) -> MtHash {
96        self.layers
97            .last()
98            .expect("Merkle tree must have at least one layer")
99            .first()
100            .cloned()
101            .expect("Merkle tree must have a root")
102    }
103
104    /// Generates a Merkle proof for the element at the given index.
105    pub fn prove(&self, leaf_index: usize) -> Result<MerkleProof, MerkleError> {
106        let leaf_count = self.layers[0].len();
107
108        if leaf_index >= leaf_count || leaf_count == 0 {
109            return Err(MerkleError::InvalidLeafIndex(leaf_index));
110        }
111
112        // Calculate the sibling path using layer values.
113        let siblings = build_sibling_path(leaf_index, &self.layers);
114
115        Ok(MerkleProof {
116            leaf_index,
117            leaf_count,
118            siblings,
119        })
120    }
121}
122
123/// Serialize all elements of `values` into a single contiguous byte buffer
124/// and hash them with Blake3 in one `update` call.  This lets Blake3 process
125/// full 1 KiB chunks with SIMD, which is significantly faster than the
126/// per-element `update` approach.
127#[allow(clippy::arithmetic_side_effects)]
128fn hash_column<S: ConstTranscribable>(values: &[S]) -> MtHash {
129    let elem_bytes = S::NUM_BYTES;
130    let mut buf = vec![0_u8; values.len() * elem_bytes];
131    for (i, v) in values.iter().enumerate() {
132        let start = i * elem_bytes;
133        v.write_transcription_bytes_exact(&mut buf[start..start + elem_bytes]);
134    }
135    let mut hasher = blake3::Hasher::new();
136    hasher.update(&buf);
137    hasher.finalize().into()
138}
139
140/// Construct the leaves of the Merkle tree by hashing each column across all
141/// rows.
142///
143/// For each column, serializes all row elements into a single contiguous byte
144/// buffer and feeds it to Blake3 in one `update` call.  This lets Blake3
145/// process full 1 KiB chunks with SIMD, which is significantly faster than
146/// the per-element `update` approach when columns are tall (many rows).
147#[allow(clippy::arithmetic_side_effects)]
148fn hash_leaves<S>(rows: &[&[S]], m_cols: usize) -> Vec<MtHash>
149where
150    S: ConstTranscribable + Send + Sync,
151{
152    let num_rows = rows.len();
153    let elem_bytes = S::NUM_BYTES;
154    let col_bytes = num_rows * elem_bytes;
155
156    cfg_into_iter!(0..m_cols)
157        .map(|i| {
158            let mut buf = vec![0_u8; col_bytes];
159            for (r, row) in rows.iter().enumerate() {
160                let start = r * elem_bytes;
161                row[i].write_transcription_bytes_exact(&mut buf[start..start + elem_bytes]);
162            }
163            let mut hasher = blake3::Hasher::new();
164            hasher.update(&buf);
165            hasher.finalize().into()
166        })
167        .collect()
168}
169
170/// Builds a Merkle tree from the given leaves, abusing blake3::hazmat module
171/// for subtree merging.
172fn build_merkle_tree_from_leaves(leaves: Vec<MtHash>) -> MerkleTree {
173    let n = leaves.len();
174
175    if n == 0 {
176        return MerkleTree {
177            layers: vec![vec![blake3::hash(&[]).into()]],
178        };
179    }
180    assert!(
181        n.is_power_of_two(),
182        "Number of leaves must be a power of two"
183    );
184
185    if n == 1 {
186        return MerkleTree {
187            layers: vec![leaves],
188        };
189    }
190
191    // Build all layers from bottom (leaves) to top (root)
192    // layers[i] contains all contiguous subtree roots of size 2^i
193    let root_layer_idx = n.trailing_zeros() as usize; // log2(n)
194    let num_layers = add!(root_layer_idx, 1);
195    let mut layers: Vec<Vec<MtHash>> = Vec::with_capacity(num_layers);
196
197    // Layer 0: individual leaves
198    layers.push(leaves);
199
200    // Build each subsequent layer
201    for layer_idx in 1..num_layers {
202        let is_root_layer = layer_idx == root_layer_idx;
203
204        let prev_layer = &layers[sub!(layer_idx, 1)];
205        let (prev_layer_chunks, _) = prev_layer.as_chunks::<2>();
206
207        let current_layer = cfg_iter!(prev_layer_chunks)
208            .map(|[left, right]| {
209                if is_root_layer {
210                    hazmat::merge_subtrees_root(&left.0, &right.0, hazmat::Mode::Hash).into()
211                } else {
212                    hazmat::merge_subtrees_non_root(&left.0, &right.0, hazmat::Mode::Hash).into()
213                }
214            })
215            .collect();
216
217        layers.push(current_layer);
218    }
219
220    MerkleTree { layers }
221}
222
223#[allow(clippy::arithmetic_side_effects)] // Using intentionally, overflow isn't possible
224fn build_sibling_path(target_index: usize, layers: &[Vec<MtHash>]) -> Vec<MtHash> {
225    let mut siblings = Vec::new();
226    let mut layer_idx = 0;
227    let mut current_layer = &layers[layer_idx];
228    let mut current_index = target_index;
229
230    loop {
231        // Determine if current node is left (even) or right (odd) child
232        let is_left_child = current_index.is_multiple_of(2);
233
234        if is_left_child {
235            // Left child, sibling is on the right
236            let sibling_index = current_index + 1;
237            if sibling_index < current_layer.len() {
238                siblings.push(current_layer[sibling_index].clone());
239            } else {
240                // We've reached the root
241                debug_assert_eq!(layer_idx, layers.len() - 1);
242                debug_assert_eq!(current_layer.len(), 1);
243                break;
244            }
245        } else {
246            // Right child, sibling is on the left
247            let sibling_index = current_index - 1;
248            siblings.push(current_layer[sibling_index].clone());
249        }
250
251        current_index /= 2;
252        layer_idx += 1;
253        current_layer = &layers[layer_idx];
254    }
255
256    siblings
257}
258
/// An opening of one leaf (column) of a [`MerkleTree`], produced by
/// [`MerkleTree::prove`] and checked with [`MerkleProof::verify`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MerkleProof {
    /// Index of the leaf being proven
    pub leaf_index: usize,
    /// Total number of leaves in the tree
    pub leaf_count: usize,
    /// The path of sibling chaining values (bottom-up order). Empty for a
    /// single-leaf tree, whose root is the leaf hash itself.
    pub siblings: Vec<MtHash>,
}
268
269impl MerkleProof {
270    pub fn new(leaf_index: usize, leaf_count: usize, siblings: Vec<MtHash>) -> Self {
271        assert!(!siblings.is_empty(), "Merkle proof path cannot be empty");
272        assert!(leaf_index < leaf_count, "Leaf index out of bounds");
273        Self {
274            leaf_index,
275            leaf_count,
276            siblings,
277        }
278    }
279
280    /// Verifies the proof against a known root hash and the claimed element
281    /// data.
282    pub fn verify<S>(
283        &self,
284        root: &MtHash,
285        column_values: &[S],
286        leaf_index: usize,
287    ) -> Result<(), MerkleError>
288    where
289        S: ConstTranscribable,
290    {
291        if leaf_index != self.leaf_index {
292            return Err(MerkleError::InvalidLeafIndex(leaf_index));
293        }
294
295        let mut current_cv: MtHash = hash_column(column_values);
296
297        if self.leaf_count == 1 {
298            if self.leaf_index == 0 && self.siblings.is_empty() {
299                // The root is just the hash of the single element.
300                if &current_cv != root {
301                    return Err(MerkleError::InvalidRootHash);
302                }
303                return Ok(());
304            } else {
305                return Err(MerkleError::InvalidMerkleProof(
306                    "Single element Merkle proof is invalid".to_owned(),
307                ));
308            }
309        }
310
311        let directions = get_path_directions(self.leaf_count, self.leaf_index);
312
313        if directions.len() != self.siblings.len() {
314            return Err(MerkleError::InvalidMerklePathLength {
315                expected: self.siblings.len(),
316                actual: directions.len(),
317            });
318        }
319
320        //  Walk up the tree
321        let mut path_iter = self.siblings.iter().zip(directions.iter());
322
323        // Pop the last element for the root merge.
324        let Some((last_sibling, last_direction)) = path_iter.next_back() else {
325            unreachable!("There should always be at least one sibling in the proof");
326        };
327
328        // Iterate over intermediate merges (non-root).
329        for (sibling_cv, direction) in path_iter {
330            let is_left = matches!(direction, PathDirection::Left);
331            if is_left {
332                current_cv = hazmat::merge_subtrees_non_root(
333                    &current_cv.0,
334                    &sibling_cv.0,
335                    hazmat::Mode::Hash,
336                )
337                .into();
338            } else {
339                current_cv = hazmat::merge_subtrees_non_root(
340                    &sibling_cv.0,
341                    &current_cv.0,
342                    hazmat::Mode::Hash,
343                )
344                .into();
345            }
346        }
347
348        // Final root merge.
349        let final_hash: MtHash = if matches!(last_direction, PathDirection::Left) {
350            hazmat::merge_subtrees_root(&current_cv.0, &last_sibling.0, hazmat::Mode::Hash).into()
351        } else {
352            hazmat::merge_subtrees_root(&last_sibling.0, &current_cv.0, hazmat::Mode::Hash).into()
353        };
354
355        if &final_hash != root {
356            return Err(MerkleError::InvalidRootHash);
357        }
358        Ok(())
359    }
360
361    /// Estimate the number of bytes that would be written to [[PcsTranscript]]
362    /// when an instance of this type is transcribed.
363    #[allow(clippy::arithmetic_side_effects)] // Overflow isn't possible
364    pub fn estimate_transcribed_size(merkle_tree_height: usize) -> usize {
365        // Note the proof does not include leaf layer, so we subtract 1.
366        3 * u64::NUM_BYTES + (merkle_tree_height - 1) * MtHash::NUM_BYTES
367    }
368}
369
370impl Display for MerkleProof {
371    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
372        writeln!(f, "Merkle Path: {}", self.siblings.iter().join(", "))?;
373        Ok(())
374    }
375}
376
/// Which side of its parent a node sits on along a Merkle path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PathDirection {
    /// The node is the left child; its sibling is merged in on the right.
    Left,
    /// The node is the right child; its sibling is merged in on the left.
    Right,
}
382
383/// Helper to determine the path directions (leaf to root).
384#[allow(clippy::arithmetic_side_effects)] // Intentional, no side effects possible.
385fn get_path_directions(total_chunks: usize, target_index: usize) -> Vec<PathDirection> {
386    let mut path = Vec::new();
387    let mut current_size = total_chunks;
388    let mut current_index = target_index;
389
390    // Iterate top-down (Root to Leaf) to determine the path based on BLAKE3 rules.
391    while current_size > 1 {
392        // BLAKE3 split rule: largest power of two less than N
393        // (or N/2 if N is power of 2).
394        let split_len = current_size.next_power_of_two() / 2;
395
396        if current_index < split_len {
397            path.push(PathDirection::Left);
398            current_size = split_len;
399        } else {
400            // Went right.
401            path.push(PathDirection::Right);
402            current_size -= split_len;
403            current_index -= split_len;
404        }
405    }
406    // Reverse the path so it is ordered from leaf to root (bottom-up) for
407    // verification.
408    path.reverse();
409    path
410}
411
/// Errors produced while proving or verifying Merkle openings.
#[derive(Error, Debug)]
pub enum MerkleError {
    /// A PCS opening failed, with a description. Not constructed in this
    /// file; presumably raised by the PCS code that uses these proofs.
    #[error("Invalid PCS opening: {0}")]
    InvalidPcsOpen(String),

    /// The proof is structurally invalid (e.g. a single-leaf proof that
    /// carries siblings).
    #[error("Invalid Merkle proof: {0}")]
    InvalidMerkleProof(String),

    /// The number of siblings in the proof does not match the path length
    /// implied by its leaf count.
    #[error("Invalid Merkle path length: expected {expected}, got {actual}")]
    InvalidMerklePathLength { expected: usize, actual: usize },

    /// A leaf index is out of range for the tree, or does not match the
    /// index recorded in the proof.
    #[error("Invalid leaf index: {0}")]
    InvalidLeafIndex(usize),

    /// The root recomputed from the proof differs from the expected root.
    #[error("Invalid root hash")]
    InvalidRootHash,

    /// Not constructed in this file; presumably used when reading a proof
    /// from a transcript (confirm at call sites).
    #[error("Failed to read merkle proof")]
    FailedMerkleProofReading,

    /// Not constructed in this file; presumably used when writing a proof to
    /// a transcript (confirm at call sites).
    #[error("Failed to write merkle proof")]
    FailedMerkleProofWriting,
}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crypto_bigint::Random;
    use crypto_primitives::crypto_bigint_int::Int;
    use rand::rng;

    const N: usize = 3;

    /// Every leaf of a single-row tree must verify against the root.
    #[test]
    fn test_merkle_proof() {
        let leaves_len = 1024;
        let mut rng = rng();
        let leaves_data = (0..leaves_len)
            .map(|_| Int::random(&mut rng))
            .collect::<Vec<Int<N>>>();

        let merkle_tree = MerkleTree::new(&[leaves_data.as_slice()]);
        let root = merkle_tree.root();

        // Prove and verify every leaf.
        for (i, leaf) in leaves_data.iter().enumerate() {
            let proof = merkle_tree.prove(i).expect("Merkle proof creation failed");

            let result = proof.verify(&root, &[*leaf], i);
            assert!(
                result.is_ok(),
                "Merkle proof verification failed for leaf index {i}: {}",
                result.err().unwrap()
            );
        }
    }

    /// A proof must reject a column other than the committed one, and a
    /// leaf index other than the one it was issued for.
    #[test]
    fn test_merkle_proof_rejects_tampering() {
        let mut rng = rng();
        let leaves_data = (0..16)
            .map(|_| Int::random(&mut rng))
            .collect::<Vec<Int<N>>>();
        let merkle_tree = MerkleTree::new(&[leaves_data.as_slice()]);
        let root = merkle_tree.root();

        for i in 0..leaves_data.len() {
            let proof = merkle_tree.prove(i).expect("Merkle proof creation failed");
            // Values are random, so the neighbouring leaf differs with
            // overwhelming probability.
            let wrong = leaves_data[i ^ 1];
            assert!(proof.verify(&root, &[wrong], i).is_err());
            assert!(proof.verify(&root, &[leaves_data[i]], i ^ 1).is_err());
        }
    }

    /// Leaves of a multi-row tree commit to whole columns.
    #[test]
    fn test_merkle_proof_multiple_rows() {
        const WIDTH: usize = 8;
        let mut rng = rng();
        let mut rows: Vec<Vec<Int<N>>> = Vec::new();
        for _ in 0..3 {
            rows.push((0..WIDTH).map(|_| Int::random(&mut rng)).collect());
        }
        let row_refs: Vec<&[Int<N>]> = rows.iter().map(Vec::as_slice).collect();

        let tree = MerkleTree::new(row_refs.as_slice());
        let root = tree.root();
        // 8 leaves -> layers of 8, 4, 2 and 1 hashes.
        assert_eq!(tree.height(), 4);

        for col in 0..WIDTH {
            let column: Vec<Int<N>> = rows.iter().map(|row| row[col]).collect();
            let proof = tree.prove(col).expect("Merkle proof creation failed");
            assert_eq!(proof.siblings.len(), 3);
            assert!(proof.verify(&root, &column, col).is_ok());
        }
        assert!(tree.prove(WIDTH).is_err());
    }

    /// A one-column tree is its own root and has an empty sibling path.
    #[test]
    fn test_merkle_proof_single_leaf() {
        let mut rng = rng();
        let value: Int<N> = Int::random(&mut rng);

        let tree = MerkleTree::new(&[std::slice::from_ref(&value)]);
        let root = tree.root();
        assert_eq!(tree.height(), 1);

        let proof = tree.prove(0).expect("Merkle proof creation failed");
        assert!(proof.siblings.is_empty());
        assert!(proof.verify(&root, &[value], 0).is_ok());
        assert!(tree.prove(1).is_err());
    }
}