1use crate::FileKey;
2use crate::FileValidationError;
3use crate::FileValidator;
4use cbc::cipher::KeyIvInit;
5use cbc::cipher::StreamCipher;
6use pin_project_lite::pin_project;
7use std::pin::Pin;
8use std::task::Context;
9use std::task::Poll;
10use std::task::ready;
11use tokio::io::AsyncRead;
12use tokio::io::ReadBuf;
13
14type Aes128Ctr128BE = ctr::Ctr128BE<aes::Aes128>;
15
16pin_project! {
17 pub struct FileDownloadReader<R> {
19 #[pin]
20 reader: R,
21 cipher: Aes128Ctr128BE,
22 validator: Option<FileValidator>,
23 validation_result: Option<Result<(), FileValidationError>>,
24 }
25}
26
27impl<R> FileDownloadReader<R> {
28 pub(crate) fn new(reader: R, file_key: &FileKey, validate: bool) -> Self {
30 let cipher = Aes128Ctr128BE::new(
31 &file_key.key.to_be_bytes().into(),
32 &file_key.iv.to_be_bytes().into(),
33 );
34 let validator = if validate {
35 Some(FileValidator::new(file_key.clone()))
36 } else {
37 None
38 };
39
40 Self {
41 reader,
42 cipher,
43 validator,
44 validation_result: None,
45 }
46 }
47}
48
49impl<R> AsyncRead for FileDownloadReader<R>
50where
51 R: AsyncRead,
52{
53 fn poll_read(
54 mut self: Pin<&mut Self>,
55 cx: &mut Context<'_>,
56 buf: &mut ReadBuf<'_>,
57 ) -> Poll<std::io::Result<()>> {
58 const MAX_LEN: usize = 64 * 1024;
60
61 let this = self.as_mut().project();
62
63 let mut unfilled_buf = buf.take(MAX_LEN);
65
66 let result = ready!(this.reader.poll_read(cx, &mut unfilled_buf));
67 result?;
68
69 let new_bytes = unfilled_buf.filled_mut();
70 let new_bytes_len = new_bytes.len();
71 this.cipher.apply_keystream(new_bytes);
72 if let Some(validator) = this.validator.as_mut() {
73 if new_bytes_len == 0 {
74 let validation_result = match this.validation_result.clone() {
75 Some(validation_result) => validation_result,
76 None => {
77 let validation_result = validator.finish();
78 *this.validation_result = Some(validation_result.clone());
79 validation_result
80 }
81 };
82
83 validation_result.map_err(std::io::Error::other)?
84 } else {
85 validator.feed(new_bytes);
86 }
87 }
88 unsafe {
90 buf.assume_init(new_bytes_len);
91 }
92 buf.advance(new_bytes_len);
93
94 Poll::Ready(Ok(()))
95 }
96}