// nd_async_rusqlite/wal_pool/builder.rs

1use super::InnerWalPool;
2use super::Message;
3use super::WalPool;
4use crate::Error;
5use crate::SyncWrapper;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9
/// Number of reader connections a freshly constructed [`WalPoolBuilder`] is configured with.
const DEFAULT_READERS: usize = 4;

12type ConnectionSetupFn =
13    Arc<dyn Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static>;
14
15/// A builder for a [`WalPool`].
16pub struct WalPoolBuilder {
17    /// The number of read connections
18    pub readers: usize,
19
20    /// A function to be called to initialize each reader.
21    pub reader_setup: Option<ConnectionSetupFn>,
22
23    /// A function to be called to initialize the writer.
24    pub writer_setup: Option<ConnectionSetupFn>,
25}
26
27impl WalPoolBuilder {
28    /// Make a new [`WalPoolBuilder`].
29    pub fn new() -> Self {
30        // TODO: Try to find some sane defaults experimentally.
31        Self {
32            readers: DEFAULT_READERS,
33
34            reader_setup: None,
35            writer_setup: None,
36        }
37    }
38
39    /// Set the number of read connections.
40    ///
41    /// This must be greater than 0.
42    pub fn readers(&mut self, readers: usize) -> &mut Self {
43        self.readers = readers;
44        self
45    }
46
47    /// Add a function to be called when a reader connection initializes.
48    pub fn reader_setup<F>(&mut self, reader_setup: F) -> &mut Self
49    where
50        F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
51    {
52        self.reader_setup = Some(Arc::new(reader_setup));
53        self
54    }
55
56    /// Add a function to be called when the writer connection initializes.
57    pub fn writer_setup<F>(&mut self, writer_setup: F) -> &mut Self
58    where
59        F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
60    {
61        self.writer_setup = Some(Arc::new(writer_setup));
62        self
63    }
64
65    /// Open the pool.
66    pub async fn open<P>(&self, path: P) -> Result<WalPool, Error>
67    where
68        P: AsRef<Path>,
69    {
70        let path = path.as_ref().to_path_buf();
71
72        let readers = self.readers;
73        if self.readers == 0 {
74            return Err(Error::Generic("`readers` cannot be 0"));
75        }
76
77        // Only the writer can create the database, make sure it does so before doing anything else.
78        let writer_tx = {
79            let (writer_tx, writer_rx) = crossbeam_channel::unbounded::<Message>();
80            let path = path.clone();
81            let flags = rusqlite::OpenFlags::default();
82            let writer_setup = self.writer_setup.clone();
83            let (open_write_tx, open_write_rx) = tokio::sync::oneshot::channel();
84            std::thread::spawn(move || {
85                connection_thread_impl(writer_rx, path, flags, writer_setup, open_write_tx)
86            });
87
88            open_write_rx
89                .await
90                .map_err(|_| Error::Aborted)
91                .and_then(std::convert::identity)?;
92
93            writer_tx
94        };
95
96        // Bring reader connections up all at once for speed.
97        let (readers_tx, readers_rx) = crossbeam_channel::unbounded::<Message>();
98        let mut open_read_rx_list = Vec::with_capacity(readers);
99        for _ in 0..readers {
100            let readers_rx = readers_rx.clone();
101            let path = path.clone();
102
103            // We cannot allow writing in reader connections, forcibly set bits.
104            let mut flags = rusqlite::OpenFlags::default();
105            flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE);
106            flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE);
107            flags.insert(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY);
108
109            let reader_setup = self.reader_setup.clone();
110
111            let (open_read_tx, open_read_rx) = tokio::sync::oneshot::channel();
112            std::thread::spawn(move || {
113                connection_thread_impl(readers_rx, path, flags, reader_setup, open_read_tx)
114            });
115            open_read_rx_list.push(open_read_rx);
116        }
117        drop(readers_rx);
118
119        // Create the wal pool here.
120        // This lets us at least attempt to close it later.
121        let wal_pool = WalPool {
122            inner: Arc::new(InnerWalPool {
123                writer_tx,
124                readers_tx,
125            }),
126        };
127
128        let mut last_error = Ok(());
129
130        for open_read_rx in open_read_rx_list {
131            if let Err(error) = open_read_rx
132                .await
133                .map_err(|_| Error::Aborted)
134                .and_then(std::convert::identity)
135            {
136                last_error = Err(error);
137            }
138        }
139
140        if let Err(error) = last_error {
141            // At least try to bring it down nicely.
142            // We ignore the error, since the original error is much more important.
143            let _ = wal_pool.close().await.is_ok();
144
145            return Err(error);
146        }
147
148        Ok(wal_pool)
149    }
150}
151
152impl Default for WalPoolBuilder {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158/// Set the journal_mode to WAL.
159///
160/// # References
161/// * https://www.sqlite.org/wal.html#activating_and_configuring_wal_mode
162fn set_wal_journal_mode(connection: &rusqlite::Connection) -> Result<(), Error> {
163    let journal_mode: String =
164        connection.pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))?;
165
166    if journal_mode != "wal" {
167        return Err(Error::InvalidJournalMode(journal_mode));
168    }
169
170    Ok(())
171}
172
173/// The impl for the connection background thread.
174fn connection_thread_impl(
175    rx: crossbeam_channel::Receiver<Message>,
176    path: PathBuf,
177    flags: rusqlite::OpenFlags,
178    init_fn: Option<ConnectionSetupFn>,
179    connection_open_tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
180) {
181    // Open the database, reporting errors as necessary.
182    let open_result = rusqlite::Connection::open_with_flags(path, flags);
183    let mut connection = match open_result {
184        Ok(connection) => connection,
185        Err(error) => {
186            // Don't care if we succed since we should exit in either case.
187            let _ = connection_open_tx.send(Err(Error::Rusqlite(error))).is_ok();
188            return;
189        }
190    };
191
192    // If WAL mode fails to enable, we should exit.
193    // This abstraction is fairly worthless outside of WAL mode.
194    if let Err(error) = set_wal_journal_mode(&connection) {
195        // Don't care if we succed since we should exit in either case.
196        let _ = connection_open_tx.send(Err(error)).is_ok();
197        return;
198    }
199
200    // Run init fn.
201    if let Some(init_fn) = init_fn {
202        let init_fn = std::panic::AssertUnwindSafe(|| init_fn(&mut connection));
203        let init_result = std::panic::catch_unwind(init_fn);
204        let init_result =
205            init_result.map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
206        if let Err(error) = init_result {
207            // Don't care if we succeed since we should exit in either case.
208            let _ = connection_open_tx.send(Err(error)).is_ok();
209            return;
210        }
211    }
212
213    // Check if the user cancelled the opening of the database connection and return early if needed.
214    if connection_open_tx.send(Ok(())).is_err() {
215        return;
216    }
217
218    let mut close_tx = None;
219    for message in rx.iter() {
220        match message {
221            Message::Close { tx } => {
222                close_tx = Some(tx);
223                break;
224            }
225            Message::Access { func } => {
226                func(&mut connection);
227            }
228        }
229    }
230
231    // Drop rx.
232    // This will abort all queued messages, dropping them without sending a response.
233    // This is considered aborting the request.
234    drop(rx);
235
236    let result = connection.close();
237    if let Some(tx) = close_tx {
238        let _ = tx
239            .send(result.map_err(|(_connection, error)| Error::from(error)))
240            .is_ok();
241    }
242}