mega/easy/
reader.rs

1use crate::FileKey;
2use crate::FileValidationError;
3use crate::FileValidator;
4use cbc::cipher::KeyIvInit;
5use cbc::cipher::StreamCipher;
6use pin_project_lite::pin_project;
7use std::pin::Pin;
8use std::task::Context;
9use std::task::Poll;
10use std::task::ready;
11use tokio::io::AsyncRead;
12use tokio::io::ReadBuf;
13
/// AES-128 in CTR mode with a 128-bit big-endian counter; the keystream cipher
/// used by [`FileDownloadReader`].
type Aes128Ctr128BE = ctr::Ctr128BE<aes::Aes128>;
15
pin_project! {
    /// An [`AsyncRead`] adapter for downloaded file data.
    ///
    /// Every byte read from the inner reader is passed through an AES-128-CTR
    /// keystream initialized from a [`FileKey`]. When validation is enabled, the
    /// processed bytes are also fed to a [`FileValidator`], whose result is
    /// checked once the inner reader reports EOF.
    pub struct FileDownloadReader<R> {
        // The wrapped reader; its bytes are run through `cipher` before being returned.
        #[pin]
        reader: R,
        // Keystream cipher created from the file key and IV.
        cipher: Aes128Ctr128BE,
        // `Some` only when validation was requested at construction time.
        validator: Option<FileValidator>,
        // Cached outcome of `FileValidator::finish`, set on the first EOF read so
        // later EOF reads replay the same result instead of finishing again.
        validation_result: Option<Result<(), FileValidationError>>,
    }
}
26
27impl<R> FileDownloadReader<R> {
28    /// Make a new reader.
29    pub(crate) fn new(reader: R, file_key: &FileKey, validate: bool) -> Self {
30        let cipher = Aes128Ctr128BE::new(
31            &file_key.key.to_be_bytes().into(),
32            &file_key.iv.to_be_bytes().into(),
33        );
34        let validator = if validate {
35            Some(FileValidator::new(file_key.clone()))
36        } else {
37            None
38        };
39
40        Self {
41            reader,
42            cipher,
43            validator,
44            validation_result: None,
45        }
46    }
47}
48
49impl<R> AsyncRead for FileDownloadReader<R>
50where
51    R: AsyncRead,
52{
53    fn poll_read(
54        mut self: Pin<&mut Self>,
55        cx: &mut Context<'_>,
56        buf: &mut ReadBuf<'_>,
57    ) -> Poll<std::io::Result<()>> {
58        // See: https://users.rust-lang.org/t/blocking-permit/36865/5
59        const MAX_LEN: usize = 64 * 1024;
60
61        let this = self.as_mut().project();
62
63        // Limit max chunk processed at a time to avoid blocking.
64        let mut unfilled_buf = buf.take(MAX_LEN);
65
66        let result = ready!(this.reader.poll_read(cx, &mut unfilled_buf));
67        result?;
68
69        let new_bytes = unfilled_buf.filled_mut();
70        let new_bytes_len = new_bytes.len();
71        this.cipher.apply_keystream(new_bytes);
72        if let Some(validator) = this.validator.as_mut() {
73            if new_bytes_len == 0 {
74                let validation_result = match this.validation_result.clone() {
75                    Some(validation_result) => validation_result,
76                    None => {
77                        let validation_result = validator.finish();
78                        *this.validation_result = Some(validation_result.clone());
79                        validation_result
80                    }
81                };
82
83                validation_result.map_err(std::io::Error::other)?
84            } else {
85                validator.feed(new_bytes);
86            }
87        }
88        // Safety: This was already initialized via the unfilled_buf sub-buffer.
89        unsafe {
90            buf.assume_init(new_bytes_len);
91        }
92        buf.advance(new_bytes_len);
93
94        Poll::Ready(Ok(()))
95    }
96}