nd_async_rusqlite/wal_pool/
builder.rs

1use super::InnerWalPool;
2use super::Message;
3use super::WalPool;
4use crate::Error;
5use crate::SyncWrapper;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10const DEFAULT_NUM_READ_CONNECTIONS: usize = 4;
11
12type ConnectionInitFn =
13    Arc<dyn Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static>;
14
15/// A builder for a [`WalPool`].
16pub struct WalPoolBuilder {
17    /// The number of read connections
18    pub num_read_connections: usize,
19
20    /// A function to be called to initialize each reader.
21    pub reader_init_fn: Option<ConnectionInitFn>,
22
23    /// A function to be called to initialize the writer.
24    pub writer_init_fn: Option<ConnectionInitFn>,
25}
26
27impl WalPoolBuilder {
28    /// Make a new [`WalPoolBuilder`].
29    pub fn new() -> Self {
30        // TODO: Try to find some sane defaults experimentally.
31        Self {
32            num_read_connections: DEFAULT_NUM_READ_CONNECTIONS,
33
34            writer_init_fn: None,
35            reader_init_fn: None,
36        }
37    }
38
39    /// Set the number of read connections.
40    ///
41    /// This must be greater than 0.
42    pub fn num_read_connections(&mut self, num_read_connections: usize) -> &mut Self {
43        self.num_read_connections = num_read_connections;
44        self
45    }
46
47    /// Add a function to be called when the writer connection initializes.
48    pub fn writer_init_fn<F>(&mut self, writer_init_fn: F) -> &mut Self
49    where
50        F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
51    {
52        self.writer_init_fn = Some(Arc::new(writer_init_fn));
53        self
54    }
55
56    /// Add a function to be called when a reader connection initializes.
57    pub fn reader_init_fn<F>(&mut self, reader_init_fn: F) -> &mut Self
58    where
59        F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
60    {
61        self.reader_init_fn = Some(Arc::new(reader_init_fn));
62        self
63    }
64
65    /// Open the pool.
66    pub async fn open<P>(&self, path: P) -> Result<WalPool, Error>
67    where
68        P: AsRef<Path>,
69    {
70        let path = path.as_ref().to_path_buf();
71
72        // TODO: Validate these values are not 0.
73        let num_read_connections = self.num_read_connections;
74
75        // Only the writer can create the database, make sure it does so before doing anything else.
76        let writer_tx = {
77            let (writer_tx, writer_rx) = crossbeam_channel::unbounded::<Message>();
78            let path = path.clone();
79            let flags = rusqlite::OpenFlags::default();
80            let writer_init_fn = self.writer_init_fn.clone();
81            let (open_write_tx, open_write_rx) = tokio::sync::oneshot::channel();
82            std::thread::spawn(move || {
83                connection_thread_impl(writer_rx, path, flags, writer_init_fn, open_write_tx)
84            });
85
86            open_write_rx
87                .await
88                .map_err(|_| Error::Aborted)
89                .and_then(std::convert::identity)?;
90
91            writer_tx
92        };
93
94        // Bring reader connections up all at once for speed.
95        let (readers_tx, readers_rx) = crossbeam_channel::unbounded::<Message>();
96        let mut open_read_rx_list = Vec::with_capacity(num_read_connections);
97        for _ in 0..num_read_connections {
98            let readers_rx = readers_rx.clone();
99            let path = path.clone();
100
101            // We cannot allow writing in reader connections, forcibly set bits.
102            let mut flags = rusqlite::OpenFlags::default();
103            flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE);
104            flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE);
105            flags.insert(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY);
106
107            let reader_init_fn = self.reader_init_fn.clone();
108
109            let (open_read_tx, open_read_rx) = tokio::sync::oneshot::channel();
110            std::thread::spawn(move || {
111                connection_thread_impl(readers_rx, path, flags, reader_init_fn, open_read_tx)
112            });
113            open_read_rx_list.push(open_read_rx);
114        }
115        drop(readers_rx);
116
117        // Create the wal pool here.
118        // This lets us at least attempt to close it later.
119        let wal_pool = WalPool {
120            inner: Arc::new(InnerWalPool {
121                writer_tx,
122                readers_tx,
123            }),
124        };
125
126        let mut last_error = Ok(());
127
128        for open_read_rx in open_read_rx_list {
129            if let Err(error) = open_read_rx
130                .await
131                .map_err(|_| Error::Aborted)
132                .and_then(std::convert::identity)
133            {
134                last_error = Err(error);
135            }
136        }
137
138        if let Err(error) = last_error {
139            // At least try to bring it down nicely.
140            // We ignore the error, since the original error is much more important.
141            let _ = wal_pool.close().await.is_ok();
142
143            return Err(error);
144        }
145
146        Ok(wal_pool)
147    }
148}
149
150impl Default for WalPoolBuilder {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156/// Set the journal_mode to WAL.
157///
158/// # References
159/// * https://www.sqlite.org/wal.html#activating_and_configuring_wal_mode
160fn set_wal_journal_mode(connection: &rusqlite::Connection) -> Result<(), Error> {
161    let journal_mode: String =
162        connection.pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))?;
163
164    if journal_mode != "wal" {
165        return Err(Error::InvalidJournalMode(journal_mode));
166    }
167
168    Ok(())
169}
170
171/// The impl for the connection background thread.
172fn connection_thread_impl(
173    rx: crossbeam_channel::Receiver<Message>,
174    path: PathBuf,
175    flags: rusqlite::OpenFlags,
176    init_fn: Option<ConnectionInitFn>,
177    connection_open_tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
178) {
179    // Open the database, reporting errors as necessary.
180    let open_result = rusqlite::Connection::open_with_flags(path, flags);
181    let mut connection = match open_result {
182        Ok(connection) => connection,
183        Err(error) => {
184            // Don't care if we succed since we should exit in either case.
185            let _ = connection_open_tx.send(Err(Error::Rusqlite(error))).is_ok();
186            return;
187        }
188    };
189
190    // If WAL mode fails to enable, we should exit.
191    // This abstraction is fairly worthless outside of WAL mode.
192    if let Err(error) = set_wal_journal_mode(&connection) {
193        // Don't care if we succed since we should exit in either case.
194        let _ = connection_open_tx.send(Err(error)).is_ok();
195        return;
196    }
197
198    // Run init fn.
199    if let Some(init_fn) = init_fn {
200        let init_fn = std::panic::AssertUnwindSafe(|| init_fn(&mut connection));
201        let init_result = std::panic::catch_unwind(init_fn);
202        let init_result =
203            init_result.map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
204        if let Err(error) = init_result {
205            // Don't care if we succeed since we should exit in either case.
206            let _ = connection_open_tx.send(Err(error)).is_ok();
207            return;
208        }
209    }
210
211    // Check if the user cancelled the opening of the database connection and return early if needed.
212    if connection_open_tx.send(Ok(())).is_err() {
213        return;
214    }
215
216    let mut close_tx = None;
217    for message in rx.iter() {
218        match message {
219            Message::Close { tx } => {
220                close_tx = Some(tx);
221                break;
222            }
223            Message::Access { func } => {
224                func(&mut connection);
225            }
226        }
227    }
228
229    // Drop rx.
230    // This will abort all queued messages, dropping them without sending a response.
231    // This is considered aborting the request.
232    drop(rx);
233
234    let result = connection.close();
235    if let Some(tx) = close_tx {
236        let _ = tx
237            .send(result.map_err(|(_connection, error)| Error::from(error)))
238            .is_ok();
239    }
240}