1use crate::{ZipError, merkle::MerkleProof, pcs::structs::ZipPlusCommitment};
2use crypto_primitives::PrimeField;
3use itertools::Itertools;
4use std::io::{Cursor, ErrorKind, Read, Write};
5use zinc_transcript::{
6 Blake3Transcript,
7 traits::{ConstTranscribable, GenTranscribable, Transcribable, Transcript},
8};
9use zinc_utils::{add, mul, rem};
10
/// Fallibly converts `$value` into the integer type `$to` using `TryFrom`.
///
/// On failure the error is mapped to `ZipError::Transcript(ErrorKind::Unsupported, _)`
/// whose message names both types. `$from` is used only for that message; the actual
/// source type is whatever `$value` evaluates to. The macro evaluates to
/// `Result<$to, ZipError>`, so call sites normally append `?` (or `.expect(..)`).
macro_rules! safe_cast {
    ($value:expr, $from:ident, $to:ident) => {
        <$to>::try_from($value).map_err(|_| {
            let message = concat!(
                "Failed to convert ",
                stringify!($from),
                " to ",
                stringify!($to)
            );
            ZipError::Transcript(ErrorKind::Unsupported, String::from(message))
        })
    };
}
25
/// Methods shared verbatim by the prover and verifier transcripts.
///
/// The `impl` block expanding this macro must belong to a type with an `fs_transcript`
/// field that provides `get_challenge::<u32>()`.
macro_rules! common_methods {
    () => {
        /// Squeezes a `u32` challenge from the Fiat-Shamir transcript and reduces it with
        /// `rem!` by `cap`, yielding an index intended to lie in `0..cap`.
        ///
        /// # Panics
        /// `rem!` is given the message "Challenge cap is zero", so a zero `cap` presumably
        /// panics. Also panics if a `u32` does not fit in `usize` on the target.
        ///
        /// NOTE(review): a plain remainder is slightly biased toward small indices when
        /// `cap` does not divide 2^32 — confirm this is acceptable for the protocol.
        #[allow(clippy::unwrap_used)]
        pub fn squeeze_challenge_idx(&mut self, cap: usize) -> usize {
            let challenge = self.fs_transcript.get_challenge::<u32>();
            let widened = safe_cast!(challenge, u32, usize)
                .expect("Conversion from u32 to usize should never fail");
            rem!(widened, cap, "Challenge cap is zero")
        }
    };
}
39
/// Prover side of the PCS transcript.
///
/// Holds a Fiat-Shamir transcript (`fs_transcript`) from which challenges are squeezed,
/// and the serialized proof bytes (`stream`). Field elements written through
/// `write_field_elements` are absorbed into `fs_transcript` as well as serialized; the
/// other `write*` methods only append bytes to `stream`.
#[derive(Debug, Clone)]
pub struct PcsProverTranscript {
    /// Fiat-Shamir transcript. `new_from_commitments` absorbs the commitment roots into
    /// it, and every field element written afterwards is absorbed as well.
    pub fs_transcript: Blake3Transcript,

    /// Serialized proof bytes; the cursor position marks where the next write lands.
    pub stream: Cursor<Vec<u8>>,
}
54
55impl PcsProverTranscript {
58 pub fn new_from_commitment(comm: &ZipPlusCommitment) -> Self {
59 Self::new_from_commitments(std::slice::from_ref(comm).iter())
60 }
61
62 pub fn new_from_commitments<'a>(comms: impl Iterator<Item = &'a ZipPlusCommitment>) -> Self {
63 let mut result = Self {
64 fs_transcript: Blake3Transcript::default(),
65 stream: Cursor::default(),
66 };
67
68 for comm in comms {
69 result.fs_transcript.absorb_slice(&comm.root);
70 }
71
72 result
73 }
74
75 pub fn reserve_capacity(&mut self, additional_capacity: usize) {
76 self.stream.get_mut().reserve(additional_capacity)
77 }
78
79 pub fn into_verification_transcript(self) -> PcsVerifierTranscript {
84 let mut result = PcsVerifierTranscript {
85 fs_transcript: Blake3Transcript::default(),
86 stream: self.stream,
87 };
88 result.stream.set_position(0);
89
90 result
91 }
92
93 common_methods!();
94
95 pub fn write_field_elements<F>(&mut self, elems: &[F]) -> Result<(), ZipError>
100 where
101 F: PrimeField,
102 F::Inner: Transcribable,
103 F::Modulus: Transcribable,
104 {
105 if !elems.is_empty() {
106 debug_assert_eq!(F::Inner::LENGTH_NUM_BYTES, F::Modulus::LENGTH_NUM_BYTES);
107 let num_bytes = F::Inner::get_num_bytes(elems[0].inner());
108 debug_assert_eq!(num_bytes, F::Modulus::get_num_bytes(&elems[0].modulus()));
109 let num_bytes_arr = num_bytes
110 .to_le_bytes()
111 .into_iter()
112 .take(F::Inner::LENGTH_NUM_BYTES)
113 .collect_vec();
114 self.stream.write_all(&num_bytes_arr)?;
115
116 let mut buf = vec![0; num_bytes];
117 for elem in elems {
118 self.write_field_element_no_length(elem, &mut buf)?;
119 }
120 }
121
122 Ok(())
123 }
124
125 fn write_field_element_no_length<F>(&mut self, fe: &F, buf: &mut [u8]) -> Result<(), ZipError>
131 where
132 F: PrimeField,
133 F::Inner: Transcribable,
134 F::Modulus: Transcribable,
135 {
136 self.fs_transcript.absorb_random_field(fe, buf);
137 fe.modulus().write_transcription_bytes_exact(buf);
138 self.stream.write_all(buf)?;
139 fe.inner().write_transcription_bytes_exact(buf);
140 self.stream.write_all(buf)?;
141 Ok(())
142 }
143
144 pub fn write<T: Transcribable>(&mut self, v: &T) -> Result<(), ZipError> {
145 let data_len = v.get_num_bytes();
146
147 if T::LENGTH_NUM_BYTES > 0 {
149 let len_bytes = data_len
150 .to_le_bytes()
151 .into_iter()
152 .take(T::LENGTH_NUM_BYTES)
153 .collect_vec();
154 self.stream.write_all(&len_bytes)?;
155 }
156
157 let prev_pos = safe_cast!(self.stream.position(), u64, usize)?;
158 let next_pos = add!(prev_pos, data_len);
159
160 let inner = self.stream.get_mut();
161 if inner.len() < next_pos {
162 inner.resize(next_pos, 0_u8);
163 }
164
165 v.write_transcription_bytes_exact(&mut inner[prev_pos..next_pos]);
166
167 self.stream.set_position(safe_cast!(next_pos, usize, u64)?);
168 Ok(())
169 }
170
171 pub fn write_const_many<T: ConstTranscribable>(&mut self, vs: &[T]) -> Result<(), ZipError> {
175 self.write_const_many_iter(vs.iter(), vs.len())
176 }
177
178 pub fn write_const_many_iter<'a, T, I>(&mut self, vs: I, vs_len: usize) -> Result<(), ZipError>
182 where
183 T: ConstTranscribable + 'a,
184 I: IntoIterator<Item = &'a T>,
185 {
186 let prev_pos = safe_cast!(self.stream.position(), u64, usize)?;
187 let data_len = mul!(vs_len, T::NUM_BYTES);
188 let next_pos = add!(prev_pos, data_len);
189
190 let inner = self.stream.get_mut();
191 if inner.len() < next_pos {
193 inner.resize(next_pos, 0_u8);
194 }
195
196 inner[prev_pos..next_pos]
197 .chunks_mut(T::NUM_BYTES)
198 .zip(vs)
199 .for_each(|(chunk, v)| v.write_transcription_bytes_exact(chunk));
200
201 self.stream.set_position(next_pos as u64);
202 Ok(())
203 }
204
205 fn write_usize(&mut self, value: usize) -> Result<(), ZipError> {
206 let value_u64 = safe_cast!(value, usize, u64)?;
207 self.write(&value_u64)
208 }
209
210 pub fn write_merkle_proof(&mut self, proof: &MerkleProof) -> Result<(), ZipError> {
211 self.write_usize(proof.leaf_index)?;
213 self.write_usize(proof.leaf_count)?;
214
215 self.write_usize(proof.siblings.len())?;
217
218 self.write_const_many(&proof.siblings)?;
220 Ok(())
221 }
222}
223
/// Verifier side of the PCS transcript.
///
/// Reads back the bytes produced by a `PcsProverTranscript` from `stream`. Field elements
/// are absorbed into `fs_transcript` in the same order the prover absorbed them, so that
/// challenges squeezed on both sides can agree.
#[derive(Debug, Clone)]
pub struct PcsVerifierTranscript {
    /// Fiat-Shamir transcript. When obtained from
    /// `PcsProverTranscript::into_verification_transcript` it starts empty, so the caller
    /// must absorb the commitment roots into it.
    pub fs_transcript: Blake3Transcript,

    /// Proof bytes being read; the cursor position marks the next unread byte.
    pub stream: Cursor<Vec<u8>>,
}
236
237impl PcsVerifierTranscript {
238 common_methods!();
239
240 pub fn read_field_elements<F>(&mut self, n: usize) -> Result<Vec<F>, ZipError>
243 where
244 F: PrimeField,
245 F::Inner: Transcribable,
246 F::Modulus: Transcribable,
247 {
248 if n > 0 {
249 debug_assert_eq!(F::Inner::LENGTH_NUM_BYTES, F::Modulus::LENGTH_NUM_BYTES);
250 let mut buf = vec![0; F::Inner::LENGTH_NUM_BYTES];
251 self.stream.read_exact(&mut buf)?;
252 let num_bytes = F::Inner::read_num_bytes(&buf);
253 debug_assert_eq!(num_bytes, F::Modulus::read_num_bytes(&buf));
254
255 let mut buf = vec![0; num_bytes];
256 (0..n)
257 .map(|_| self.read_field_element_no_length(&mut buf))
258 .collect::<Result<Vec<_>, _>>()
259 } else {
260 Ok(vec![])
261 }
262 }
263
264 fn read_field_element_no_length<F>(&mut self, buf: &mut [u8]) -> Result<F, ZipError>
270 where
271 F: PrimeField,
272 F::Inner: Transcribable,
273 F::Modulus: Transcribable,
274 {
275 self.stream.read_exact(buf)?;
276 let modulus = F::Modulus::read_transcription_bytes_exact(buf);
277 self.stream.read_exact(buf)?;
278 let inner = F::Inner::read_transcription_bytes_exact(buf);
279 let field_cfg = F::make_cfg(&modulus)?;
280 let fe = F::new_unchecked_with_cfg(inner, &field_cfg);
281 self.fs_transcript.absorb_random_field(&fe, buf);
282 Ok(fe)
283 }
284
285 pub fn read<T: Transcribable>(&mut self) -> Result<T, ZipError> {
286 let data_len = if T::LENGTH_NUM_BYTES > 0 {
287 let mut len_buf = vec![0u8; T::LENGTH_NUM_BYTES];
288 self.stream.read_exact(&mut len_buf)?;
289 T::read_num_bytes(&len_buf)
290 } else {
291 T::read_num_bytes(&[])
295 };
296
297 read_stream_slice(&mut self.stream, data_len, |slice| {
298 Ok(T::read_transcription_bytes_exact(slice))
299 })
300 }
301
302 pub fn read_const_many<T: ConstTranscribable>(&mut self, n: usize) -> Result<Vec<T>, ZipError> {
303 read_stream_slice(&mut self.stream, mul!(n, T::NUM_BYTES), |slice| {
304 Ok(slice
305 .chunks(T::NUM_BYTES)
306 .map(T::read_transcription_bytes_exact)
307 .collect_vec())
308 })
309 }
310
311 fn read_usize(&mut self) -> Result<usize, ZipError> {
312 let value = self.read::<u64>()?;
313 safe_cast!(value, u64, usize)
314 }
315
316 pub fn read_merkle_proof(&mut self) -> Result<MerkleProof, ZipError> {
317 let leaf_index = self.read_usize()?;
319 let leaf_count = self.read_usize()?;
320
321 let path_length = self.read_usize()?;
323
324 let merkle_path = self.read_const_many(path_length)?;
326
327 Ok(MerkleProof::new(leaf_index, leaf_count, merkle_path))
328 }
329}
330
331#[inline]
335fn read_stream_slice<T>(
336 stream: &mut Cursor<Vec<u8>>,
337 length: usize,
338 action: impl Fn(&[u8]) -> Result<T, ZipError>,
339) -> Result<T, ZipError> {
340 let prev_pos = safe_cast!(stream.position(), u64, usize)?;
341 let next_pos = add!(prev_pos, length);
342
343 let stream_vec = stream.get_ref();
344 if next_pos > stream_vec.len() {
345 return Err(ZipError::Transcript(
346 ErrorKind::UnexpectedEof,
347 format!(
348 "Attempted to read beyond the end of the stream: {} + {} exceeds stream length {}",
349 prev_pos,
350 length,
351 stream_vec.len()
352 ),
353 ));
354 }
355 let res = action(&stream_vec[prev_pos..next_pos])?;
356 stream.set_position(safe_cast!(next_pos, usize, u64)?);
357 Ok(res)
358}
359
360impl From<std::io::Error> for ZipError {
362 fn from(err: std::io::Error) -> Self {
363 ZipError::Transcript(err.kind(), err.to_string())
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::MtHash;

    /// Round-trips a single value:
    /// 1. writes it with `$write_fn` on a prover transcript;
    /// 2. converts to a verifier transcript, re-absorbing the commitment root as a real
    ///    verifier would;
    /// 3. reads it back with `$read_fn` and asserts equality.
    macro_rules! test_read_write {
        ($write_fn:ident, $read_fn:ident, $original_value:expr, $assert_msg:expr) => {{
            let comm = ZipPlusCommitment::default();
            let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
            transcript
                .$write_fn(&$original_value)
                .unwrap_or_else(|err| panic!("Failed to write {}: {:?}", $assert_msg, err));
            let mut transcript: PcsVerifierTranscript = transcript.into_verification_transcript();
            transcript.fs_transcript.absorb_slice(&comm.root);
            let read_value = transcript
                .$read_fn()
                .unwrap_or_else(|err| panic!("Failed to read {}: {:?}", $assert_msg, err));
            assert_eq!(
                $original_value, read_value,
                "{} read does not match original",
                $assert_msg
            );
        }};
    }

    /// Like `test_read_write!`, but for a collection: `$read_fn` is passed the number of
    /// elements to read back.
    macro_rules! test_read_write_vec {
        ($write_fn:ident, $read_fn:ident, $original_values:expr, $assert_msg:expr) => {{
            let comm = ZipPlusCommitment::default();
            let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
            transcript
                .$write_fn(&$original_values)
                .unwrap_or_else(|err| panic!("Failed to write {}: {:?}", $assert_msg, err));
            let mut transcript: PcsVerifierTranscript = transcript.into_verification_transcript();
            transcript.fs_transcript.absorb_slice(&comm.root);
            let read_values = transcript
                .$read_fn($original_values.len())
                .unwrap_or_else(|err| panic!("Failed to read {}: {:?}", $assert_msg, err));
            assert_eq!(
                $original_values, read_values,
                "{} read does not match original",
                $assert_msg
            );
        }};
    }

    #[test]
    fn test_pcs_transcript_read_write() {
        let original_hash = MtHash::default();
        test_read_write!(write, read, original_hash, "hash");

        let original_hashes = vec![MtHash::default(); 1024];
        test_read_write_vec!(
            write_const_many,
            read_const_many,
            original_hashes,
            "hashes vector"
        );
    }

    #[test]
    fn test_pcs_transcript_merkle_proof_round_trip() {
        // 8 leaves -> 3 siblings on the authentication path.
        let proof = MerkleProof::new(5, 8, vec![MtHash::default(); 3]);

        let comm = ZipPlusCommitment::default();
        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        transcript
            .write_merkle_proof(&proof)
            .expect("Failed to write merkle proof");

        let mut transcript = transcript.into_verification_transcript();
        let read_proof = transcript
            .read_merkle_proof()
            .expect("Failed to read merkle proof");

        assert_eq!(read_proof.leaf_index, proof.leaf_index);
        assert_eq!(read_proof.leaf_count, proof.leaf_count);
        assert_eq!(read_proof.siblings, proof.siblings);
    }

    #[test]
    fn test_pcs_transcript_read_past_end_fails() {
        let comm = ZipPlusCommitment::default();

        // Reading from an empty stream must report an error, not panic.
        let mut empty = PcsProverTranscript::new_from_commitment(&comm).into_verification_transcript();
        assert!(empty.read::<u64>().is_err());

        // Asking for more fixed-size elements than were written must also fail.
        let hashes = vec![MtHash::default(); 2];
        let mut transcript = PcsProverTranscript::new_from_commitment(&comm);
        transcript
            .write_const_many(&hashes)
            .expect("Failed to write hashes");
        let mut transcript = transcript.into_verification_transcript();
        assert!(transcript.read_const_many::<MtHash>(3).is_err());
    }
}