Skip to main content

zip_plus/
pcs_transcript.rs

1use crate::{ZipError, merkle::MerkleProof, pcs::structs::ZipPlusCommitment};
2use crypto_primitives::PrimeField;
3use itertools::Itertools;
4use std::io::{Cursor, ErrorKind, Read, Write};
5use zinc_transcript::{
6    Blake3Transcript,
7    traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript},
8};
9use zinc_utils::{add, mul, rem};
10
/// Converts `$value` into the integer type `$to` using `TryFrom`.
///
/// On failure the conversion error is discarded and replaced by
/// `ZipError::Transcript(ErrorKind::Unsupported, ..)` whose message names the
/// source and target types; `$from` is used only to build that message.
/// Expands to a `Result<$to, ZipError>`.
macro_rules! safe_cast {
    ($value:expr, $from:ident, $to:ident) => {
        $to::try_from($value).map_err(|_| {
            ZipError::Transcript(
                ErrorKind::Unsupported,
                concat!(
                    "Failed to convert ",
                    stringify!($from),
                    " to ",
                    stringify!($to)
                )
                .to_string(),
            )
        })
    };
}
25
/// Expands to the methods that the prover and verifier transcripts share
/// verbatim. The expansion reads `self.fs_transcript`, so it may only be used
/// inside an `impl` of a type that has such a field.
macro_rules! common_methods {
    () => {
        /// Squeezes a challenge index in the range `0..cap` from the
        /// Fiat-Shamir transcript.
        ///
        /// A `u32` challenge is drawn from `fs_transcript` and reduced modulo
        /// `cap`. A prover and a verifier whose transcripts absorbed the same
        /// data therefore derive the same index. The reduction is exactly
        /// uniform only when `cap` divides 2^32.
        ///
        /// `cap` must be non-zero; that case is delegated to `rem!` with the
        /// message "Challenge cap is zero".
        #[allow(clippy::unwrap_used)]
        pub fn squeeze_challenge_idx(&mut self, cap: usize) -> usize {
            let challenge: u32 = self.fs_transcript.get_challenge::<u32>();
            let widened = safe_cast!(challenge, u32, usize)
                .expect("Conversion from u32 to usize should never fail");
            rem!(widened, cap, "Challenge cap is zero")
        }
    };
}
39
40/// A transcript for Polynomial Commitment Scheme (PCS) operations.
41/// Manages both Fiat-Shamir transformations and serialization/deserialization
42/// of proof data.
43#[derive(Debug, Clone)]
44pub struct PcsProverTranscript {
45    /// Handles Fiat-Shamir transformations for non-interactive zero-knowledge
46    /// proofs. Used to absorb field elements and generate cryptographic
47    /// challenges.
48    pub fs_transcript: Blake3Transcript,
49
50    /// Manages serialization and deserialization of proof data as a byte
51    /// stream.
52    pub stream: Cursor<Vec<u8>>,
53}
54
55// TODO(alex): Review this vs Transcribable, there is some overlap that needs to
56//             be resolved
57impl PcsProverTranscript {
58    pub fn new_from_commitment(comm: &ZipPlusCommitment) -> Self {
59        Self::new_from_commitments(std::slice::from_ref(comm).iter())
60    }
61
62    pub fn new_from_commitments<'a>(comms: impl Iterator<Item = &'a ZipPlusCommitment>) -> Self {
63        let mut result = Self {
64            fs_transcript: Blake3Transcript::default(),
65            stream: Cursor::default(),
66        };
67
68        for comm in comms {
69            result.fs_transcript.absorb_slice(&comm.root);
70        }
71
72        result
73    }
74
75    pub fn reserve_capacity(&mut self, additional_capacity: usize) {
76        self.stream.get_mut().reserve(additional_capacity)
77    }
78
79    /// Transform the prover transcript into a verifier transcript by resetting
80    /// the stream. Note that the commitment must be absorbed again into the
81    /// verifier transcript. This would normally be done by the verifier, but
82    /// this allows us more flexibility in how we use the transcript.
83    pub fn into_verification_transcript(self) -> PcsVerifierTranscript {
84        let mut result = PcsVerifierTranscript {
85            fs_transcript: Blake3Transcript::default(),
86            stream: self.stream,
87        };
88        result.stream.set_position(0);
89
90        result
91    }
92
93    common_methods!();
94
95    // Note: Currently this only works for fields whose modulus and inner element
96    // have the same byte length
97    //
98    // TODO if we change this to an iterator we may be able to save some memory
99    pub fn write_field_elements<F>(&mut self, elems: &[F]) -> Result<(), ZipError>
100    where
101        F: PrimeField,
102        F::Inner: Transcribable,
103        F::Modulus: Transcribable,
104    {
105        if !elems.is_empty() {
106            debug_assert_eq!(F::Inner::LENGTH_NUM_BYTES, F::Modulus::LENGTH_NUM_BYTES);
107            let num_bytes = F::Inner::get_num_bytes(elems[0].inner());
108            debug_assert_eq!(num_bytes, F::Modulus::get_num_bytes(&elems[0].modulus()));
109            let num_bytes_arr = num_bytes
110                .to_le_bytes()
111                .into_iter()
112                .take(F::Inner::LENGTH_NUM_BYTES)
113                .collect_vec();
114            self.stream.write_all(&num_bytes_arr)?;
115
116            let mut buf = vec![0; num_bytes];
117            for elem in elems {
118                self.write_field_element_no_length(elem, &mut buf)?;
119            }
120        }
121
122        Ok(())
123    }
124
125    /// Writes a field element to the proof stream and absorbs it into the
126    /// transcript. Used during proof generation to store field elements for
127    /// later verification.
128    ///
129    /// Field element length must've been written before calling this method.
130    fn write_field_element_no_length<F>(&mut self, fe: &F, buf: &mut [u8]) -> Result<(), ZipError>
131    where
132        F: PrimeField,
133        F::Inner: Transcribable,
134        F::Modulus: Transcribable,
135    {
136        self.fs_transcript.absorb_random_field(fe, buf);
137        fe.modulus().write_transcription_bytes_exact(buf);
138        self.stream.write_all(buf)?;
139        fe.inner().write_transcription_bytes_exact(buf);
140        self.stream.write_all(buf)?;
141        Ok(())
142    }
143
144    pub fn write<T: Transcribable>(&mut self, v: &T) -> Result<(), ZipError> {
145        let data_len = v.get_num_bytes();
146
147        // Write the length prefix when it is not known at compile time.
148        if T::LENGTH_NUM_BYTES > 0 {
149            let len_bytes = data_len
150                .to_le_bytes()
151                .into_iter()
152                .take(T::LENGTH_NUM_BYTES)
153                .collect_vec();
154            self.stream.write_all(&len_bytes)?;
155        }
156
157        let prev_pos = safe_cast!(self.stream.position(), u64, usize)?;
158        let next_pos = add!(prev_pos, data_len);
159
160        let inner = self.stream.get_mut();
161        if inner.len() < next_pos {
162            inner.resize(next_pos, 0_u8);
163        }
164
165        v.write_transcription_bytes_exact(&mut inner[prev_pos..next_pos]);
166
167        self.stream.set_position(safe_cast!(next_pos, usize, u64)?);
168        Ok(())
169    }
170
171    // Note(alex):
172    // Parallelizing this greatly degrades performance rather than improving it.
173    // Maybe we should think of breakpoints for parallelization later.
174    pub fn write_const_many<T: ConstTranscribable>(&mut self, vs: &[T]) -> Result<(), ZipError> {
175        self.write_const_many_iter(vs.iter(), vs.len())
176    }
177
178    // Note(alex):
179    // Parallelizing this greatly degrades performance rather than improving it.
180    // Maybe we should think of breakpoints for parallelization later.
181    pub fn write_const_many_iter<'a, T, I>(&mut self, vs: I, vs_len: usize) -> Result<(), ZipError>
182    where
183        T: ConstTranscribable + 'a,
184        I: IntoIterator<Item = &'a T>,
185    {
186        let prev_pos = safe_cast!(self.stream.position(), u64, usize)?;
187        let data_len = mul!(vs_len, T::NUM_BYTES);
188        let next_pos = add!(prev_pos, data_len);
189
190        let inner = self.stream.get_mut();
191        // Enlarge the inner buffer if needed
192        if inner.len() < next_pos {
193            inner.resize(next_pos, 0_u8);
194        }
195
196        inner[prev_pos..next_pos]
197            .chunks_mut(T::NUM_BYTES)
198            .zip(vs)
199            .for_each(|(chunk, v)| v.write_transcription_bytes_exact(chunk));
200
201        self.stream.set_position(next_pos as u64);
202        Ok(())
203    }
204
205    fn write_usize(&mut self, value: usize) -> Result<(), ZipError> {
206        let value_u64 = safe_cast!(value, usize, u64)?;
207        self.write(&value_u64)
208    }
209
210    pub fn write_merkle_proof(&mut self, proof: &MerkleProof) -> Result<(), ZipError> {
211        // Write the dimensions of matrix used to construct the Merkle tree
212        self.write_usize(proof.leaf_index)?;
213        self.write_usize(proof.leaf_count)?;
214
215        // Write the length of the merkle path first
216        self.write_usize(proof.siblings.len())?;
217
218        // Write each element of the merkle path
219        self.write_const_many(&proof.siblings)?;
220        Ok(())
221    }
222}
223
224/// Version of [[PcsProverTranscript]] used for proof verification.
225#[derive(Debug, Clone)]
226pub struct PcsVerifierTranscript {
227    /// Handles Fiat-Shamir transformations for non-interactive zero-knowledge
228    /// proofs. Used to absorb field elements and generate cryptographic
229    /// challenges.
230    pub fs_transcript: Blake3Transcript,
231
232    /// Manages serialization and deserialization of proof data as a byte
233    /// stream.
234    pub stream: Cursor<Vec<u8>>,
235}
236
237impl PcsVerifierTranscript {
238    common_methods!();
239
240    // Note: Currently this only works for fields whose modulus and inner element
241    // have the same byte length
242    pub fn read_field_elements<F>(&mut self, n: usize) -> Result<Vec<F>, ZipError>
243    where
244        F: PrimeField,
245        F::Inner: Transcribable,
246        F::Modulus: Transcribable,
247    {
248        if n > 0 {
249            debug_assert_eq!(F::Inner::LENGTH_NUM_BYTES, F::Modulus::LENGTH_NUM_BYTES);
250            let mut buf = vec![0; F::Inner::LENGTH_NUM_BYTES];
251            self.stream.read_exact(&mut buf)?;
252            let num_bytes = F::Inner::read_num_bytes(&buf);
253            debug_assert_eq!(num_bytes, F::Modulus::read_num_bytes(&buf));
254
255            let mut buf = vec![0; num_bytes];
256            (0..n)
257                .map(|_| self.read_field_element_no_length(&mut buf))
258                .collect::<Result<Vec<_>, _>>()
259        } else {
260            Ok(vec![])
261        }
262    }
263
264    /// Reads a field element from the proof stream and absorbs it into the
265    /// transcript. Used during proof verification to retrieve and process
266    /// field elements.
267    ///
268    /// Provided buffer must be of exact size of the field element.
269    fn read_field_element_no_length<F>(&mut self, buf: &mut [u8]) -> Result<F, ZipError>
270    where
271        F: PrimeField,
272        F::Inner: Transcribable,
273        F::Modulus: Transcribable,
274    {
275        self.stream.read_exact(buf)?;
276        let modulus = F::Modulus::read_transcription_bytes_exact(buf);
277        self.stream.read_exact(buf)?;
278        let inner = F::Inner::read_transcription_bytes_exact(buf);
279        let field_cfg = F::make_cfg(&modulus)?;
280        let fe = F::new_unchecked_with_cfg(inner, &field_cfg);
281        self.fs_transcript.absorb_random_field(&fe, buf);
282        Ok(fe)
283    }
284
285    pub fn read<T: Transcribable>(&mut self) -> Result<T, ZipError> {
286        let data_len = if T::LENGTH_NUM_BYTES > 0 {
287            let mut len_buf = vec![0u8; T::LENGTH_NUM_BYTES];
288            self.stream.read_exact(&mut len_buf)?;
289            T::read_num_bytes(&len_buf)
290        } else {
291            // LENGTH_NUM_BYTES == 0 means size is known at compile time via
292            // the ConstTranscribable blanket impl; read_num_bytes accepts an
293            // empty slice in that case.
294            T::read_num_bytes(&[])
295        };
296
297        read_stream_slice(&mut self.stream, data_len, |slice| {
298            Ok(T::read_transcription_bytes_exact(slice))
299        })
300    }
301
302    pub fn read_const_many<T: ConstTranscribable>(&mut self, n: usize) -> Result<Vec<T>, ZipError> {
303        read_stream_slice(&mut self.stream, mul!(n, T::NUM_BYTES), |slice| {
304            Ok(slice
305                .chunks(T::NUM_BYTES)
306                .map(T::read_transcription_bytes_exact)
307                .collect_vec())
308        })
309    }
310
311    fn read_usize(&mut self) -> Result<usize, ZipError> {
312        let value = self.read::<u64>()?;
313        safe_cast!(value, u64, usize)
314    }
315
316    pub fn read_merkle_proof(&mut self) -> Result<MerkleProof, ZipError> {
317        // Read the dimensions of matrix used to construct the Merkle tree
318        let leaf_index = self.read_usize()?;
319        let leaf_count = self.read_usize()?;
320
321        // Read the length of the merkle path first
322        let path_length = self.read_usize()?;
323
324        // Read each element of the merkle path
325        let merkle_path = self.read_const_many(path_length)?;
326
327        Ok(MerkleProof::new(leaf_index, leaf_count, merkle_path))
328    }
329}
330
331/// Perform a bounds-checked read from the stream for a length, and
332/// execute an action on the resulting slice. After the action is executed,
333/// advance the stream position by the length.
334#[inline]
335fn read_stream_slice<T>(
336    stream: &mut Cursor<Vec<u8>>,
337    length: usize,
338    action: impl Fn(&[u8]) -> Result<T, ZipError>,
339) -> Result<T, ZipError> {
340    let prev_pos = safe_cast!(stream.position(), u64, usize)?;
341    let next_pos = add!(prev_pos, length);
342
343    let stream_vec = stream.get_ref();
344    if next_pos > stream_vec.len() {
345        return Err(ZipError::Transcript(
346            ErrorKind::UnexpectedEof,
347            format!(
348                "Attempted to read beyond the end of the stream: {} + {} exceeds stream length {}",
349                prev_pos,
350                length,
351                stream_vec.len()
352            ),
353        ));
354    }
355    let res = action(&stream_vec[prev_pos..next_pos])?;
356    stream.set_position(safe_cast!(next_pos, usize, u64)?);
357    Ok(res)
358}
359
360// Do not expose this outside
361impl From<std::io::Error> for ZipError {
362    fn from(err: std::io::Error) -> Self {
363        ZipError::Transcript(err.kind(), err.to_string())
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::MtHash;

    /// Seeds a prover transcript with a default commitment and lets `write`
    /// populate its proof stream. Then converts it into a verifier transcript
    /// that has re-absorbed the same commitment root, as a real verifier would.
    fn round_trip(
        write: impl FnOnce(&mut PcsProverTranscript) -> Result<(), ZipError>,
    ) -> PcsVerifierTranscript {
        let comm = ZipPlusCommitment::default();
        let mut prover = PcsProverTranscript::new_from_commitment(&comm);
        write(&mut prover).expect("writing to the prover transcript failed");
        let mut verifier = prover.into_verification_transcript();
        verifier.fs_transcript.absorb_slice(&comm.root);
        verifier
    }

    #[test]
    fn test_pcs_transcript_read_write() {
        // Single value via `write` / `read`.
        let original_hash = MtHash::default();
        let mut verifier = round_trip(|t| t.write(&original_hash));
        let read_hash: MtHash = verifier.read().expect("Failed to read hash");
        assert_eq!(original_hash, read_hash, "hash read does not match original");

        // Many fixed-size values via `write_const_many` / `read_const_many`.
        let original_hashes = vec![MtHash::default(); 1024];
        let mut verifier = round_trip(|t| t.write_const_many(&original_hashes));
        let read_hashes: Vec<MtHash> = verifier
            .read_const_many(original_hashes.len())
            .expect("Failed to read hashes vector");
        assert_eq!(
            original_hashes, read_hashes,
            "hashes vector read does not match original"
        );
    }

    #[test]
    fn usize_round_trip() {
        let mut verifier = round_trip(|t| {
            t.write_usize(0)?;
            t.write_usize(usize::MAX)
        });
        assert_eq!(verifier.read_usize().expect("Failed to read 0"), 0);
        assert_eq!(
            verifier.read_usize().expect("Failed to read usize::MAX"),
            usize::MAX
        );
    }

    #[test]
    fn merkle_proof_round_trip() {
        let proof = MerkleProof::new(3, 8, vec![MtHash::default(); 3]);
        let mut verifier = round_trip(|t| t.write_merkle_proof(&proof));
        let read = verifier
            .read_merkle_proof()
            .expect("Failed to read merkle proof");
        assert_eq!(read.leaf_index, proof.leaf_index);
        assert_eq!(read.leaf_count, proof.leaf_count);
        assert_eq!(read.siblings, proof.siblings);
    }

    #[test]
    fn prover_and_verifier_challenges_agree() {
        let comm = ZipPlusCommitment::default();
        let prover = PcsProverTranscript::new_from_commitment(&comm);
        let mut prover_view = prover.clone();
        let mut verifier = prover.into_verification_transcript();
        verifier.fs_transcript.absorb_slice(&comm.root);

        for cap in [1_usize, 7, 1024] {
            let p = prover_view.squeeze_challenge_idx(cap);
            let v = verifier.squeeze_challenge_idx(cap);
            assert_eq!(p, v, "challenge mismatch for cap {}", cap);
            assert!(p < cap, "challenge {} out of range for cap {}", p, cap);
        }
    }

    #[test]
    fn reading_past_end_of_stream_fails() {
        let mut verifier = round_trip(|t| t.write_const_many(&[MtHash::default()]));
        let result = verifier.read_const_many::<MtHash>(2);
        assert!(
            matches!(result, Err(ZipError::Transcript(ErrorKind::UnexpectedEof, _))),
            "reading two hashes from a one-hash stream must fail with UnexpectedEof"
        );
    }
}