nd_async_rusqlite/
wal_pool.rs

1mod builder;
2
3pub use self::builder::WalPoolBuilder;
4use crate::Error;
5use crate::SyncWrapper;
6use std::sync::Arc;
7
8/// The channel message
9enum Message {
10    Access {
11        func: Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>,
12    },
13    Close {
14        tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
15    },
16}
17
18/// A handle to a pool of connections, designed for a database in WAL mode.
19#[derive(Debug, Clone)]
20pub struct WalPool {
21    inner: Arc<InnerWalPool>,
22}
23
24impl WalPool {
25    /// Get a builder for a [`WalPool`].
26    pub fn builder() -> WalPoolBuilder {
27        WalPoolBuilder::new()
28    }
29
30    /// Close the pool.
31    ///
32    /// This will queue a close request to each thread.
33    /// When each thread processes the close request,
34    /// it will shut down.
35    /// When the last thread processed the close message,
36    /// all current queued requests will be aborted.
37    ///
38    /// When this function returns,
39    /// the pool will be closed no matter the value of the return.
40    /// The return value will return the last error that occured while closing.
41    pub async fn close(&self) -> Result<(), Error> {
42        let mut last_error = Ok(());
43
44        loop {
45            let (tx, rx) = tokio::sync::oneshot::channel();
46            let send_result = self.inner.readers_tx.send(Message::Close { tx });
47
48            if let Err(_error) = send_result {
49                // All receivers closed, we can stop now.
50                break;
51            }
52
53            let close_result = rx
54                .await
55                .map_err(|_| Error::Aborted)
56                .and_then(std::convert::identity);
57
58            if let Err(close_error) = close_result {
59                last_error = Err(close_error);
60            }
61        }
62
63        // Close the writer
64        let close_writer_result = async {
65            let (tx, rx) = tokio::sync::oneshot::channel();
66            self.inner
67                .writer_tx
68                .send(Message::Close { tx })
69                .map_err(|_| Error::Aborted)?;
70            rx.await.map_err(|_| Error::Aborted)??;
71            Ok(())
72        }
73        .await;
74
75        if let Err(close_writer_error) = close_writer_result {
76            last_error = Err(close_writer_error);
77        }
78
79        last_error
80    }
81
82    /// Access the database with a reader connection.
83    ///
84    /// Note that dropping the returned future will not cancel the database access.
85    pub async fn read<F, T>(&self, func: F) -> Result<T, Error>
86    where
87        F: FnOnce(&mut rusqlite::Connection) -> T + Send + 'static,
88        T: Send + 'static,
89    {
90        // TODO: We should make this a function and have it return a named Future.
91        // This will allow users to avoid spawning a seperate task for each database call.
92
93        let (tx, rx) = tokio::sync::oneshot::channel();
94        self.inner
95            .readers_tx
96            .send(Message::Access {
97                func: Box::new(move |connection| {
98                    // TODO: Consider aborting if rx hung up.
99
100                    let func = std::panic::AssertUnwindSafe(|| func(connection));
101                    let result = std::panic::catch_unwind(func);
102                    let result = result
103                        .map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
104                    let _ = tx.send(result).is_ok();
105                }),
106            })
107            .map_err(|_| Error::Aborted)?;
108        let result = rx.await.map_err(|_| Error::Aborted)??;
109
110        Ok(result)
111    }
112
113    /// Access the database with a writer connection.
114    ///
115    /// Note that dropping the returned future will not cancel the database access.
116    pub async fn write<F, T>(&self, func: F) -> Result<T, Error>
117    where
118        F: FnOnce(&mut rusqlite::Connection) -> T + Send + 'static,
119        T: Send + 'static,
120    {
121        // TODO: We should make this a function and have it return a named Future.
122        // This will allow users to avoid spawning a seperate task for each database call.
123
124        let (tx, rx) = tokio::sync::oneshot::channel();
125        self.inner
126            .writer_tx
127            .send(Message::Access {
128                func: Box::new(move |connection| {
129                    // TODO: Consider aborting if rx hung up.
130
131                    let func = std::panic::AssertUnwindSafe(|| func(connection));
132                    let result = std::panic::catch_unwind(func);
133                    let result = result
134                        .map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
135                    let _ = tx.send(result).is_ok();
136                }),
137            })
138            .map_err(|_| Error::Aborted)?;
139        let result = rx.await.map_err(|_| Error::Aborted)??;
140        Ok(result)
141    }
142}
143
144/// The inner wal pool
145#[derive(Debug)]
146struct InnerWalPool {
147    writer_tx: crossbeam_channel::Sender<Message>,
148    readers_tx: crossbeam_channel::Sender<Message>,
149}
150
#[cfg(test)]
mod test {
    use super::*;
    use std::path::Path;
    use std::path::PathBuf;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Directory holding the test databases, relative to the test run's working directory.
    const TEMP_DIR: &str = "test-temp";

    /// Prepare a clean database path named `file_name` inside [`TEMP_DIR`].
    ///
    /// The directory is created if needed, so no test depends on another test
    /// having run first. Any database left over from an earlier run is deleted,
    /// together with its WAL-mode `-wal`/`-shm` side files, so each test starts
    /// from an empty database.
    fn fresh_database_path(file_name: &str) -> PathBuf {
        let temp_path = Path::new(TEMP_DIR);
        std::fs::create_dir_all(temp_path).expect("failed to create temp dir");

        let connection_path = temp_path.join(file_name);
        for suffix in ["", "-wal", "-shm"] {
            let mut path = connection_path.clone().into_os_string();
            path.push(suffix);
            let path = PathBuf::from(path);

            match std::fs::remove_file(&path) {
                Ok(()) => {}
                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
                Err(error) => {
                    panic!("failed to remove old database file {path:?}: {error:?}");
                }
            }
        }

        connection_path
    }

    #[tokio::test]
    async fn dir() {
        let connection_error = WalPool::builder()
            .open(".")
            .await
            .expect_err("pool should not open on a directory");
        assert!(matches!(connection_error, Error::Rusqlite(_)));
    }

    #[tokio::test]
    async fn sanity() {
        let connection_path = fresh_database_path("wal-pool-sanity.db");

        let writer_init_fn_called = Arc::new(AtomicBool::new(false));
        let num_reader_init_fn_called = Arc::new(AtomicUsize::new(0));
        let num_read_connections = 4;
        let connection = {
            let writer_init_fn_called = writer_init_fn_called.clone();
            let num_reader_init_fn_called = num_reader_init_fn_called.clone();
            WalPool::builder()
                .num_read_connections(num_read_connections)
                .writer_init_fn(move |_connection| {
                    writer_init_fn_called.store(true, Ordering::SeqCst);
                    Ok(())
                })
                .reader_init_fn(move |_connection| {
                    num_reader_init_fn_called.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
                .open(&connection_path)
                .await
                .expect("connection should be open")
        };
        let writer_init_fn_called = writer_init_fn_called.load(Ordering::SeqCst);
        let num_reader_init_fn_called = num_reader_init_fn_called.load(Ordering::SeqCst);

        assert!(writer_init_fn_called);
        assert_eq!(num_reader_init_fn_called, num_read_connections);

        // Ensure connection is clone
        let _connection1 = connection.clone();

        // Ensure write connection survives panic
        let panic_error = connection
            .write(|_connection| panic!("the connection should survive the panic"))
            .await
            .expect_err("the access should have failed");
        assert!(matches!(panic_error, Error::AccessPanic(_)));

        // Ensure read connections survive panic
        let panic_error = connection
            .read(|_connection| panic!("the connection should survive the panic"))
            .await
            .expect_err("the access should have failed");
        assert!(matches!(panic_error, Error::AccessPanic(_)));

        let setup_sql = "PRAGMA foreign_keys = ON; CREATE TABLE USERS (id INTEGER PRIMARY KEY, first_name TEXT NOT NULL, last_name TEXT NOT NULL) STRICT;";
        connection
            .write(|connection| connection.execute_batch(setup_sql))
            .await
            .expect("failed to create tables")
            .expect("failed to execute");

        // Reader should not be able to write.
        // Use a table that does not exist yet: re-running `setup_sql` would fail
        // because `USERS` already exists, even on a writable connection.
        let reader_write_sql = "CREATE TABLE readers_cannot_write (id INTEGER PRIMARY KEY);";
        connection
            .read(|connection| connection.execute_batch(reader_write_sql))
            .await
            .expect("failed to access")
            .expect_err("write should have failed");

        connection
            .close()
            .await
            .expect("an error occured while closing");

        // `close` only returns once every reader has hung up, so new reads are rejected.
        let closed_error = connection
            .read(|_connection| ())
            .await
            .expect_err("reads should fail on a closed pool");
        assert!(matches!(closed_error, Error::Aborted));
    }

    #[tokio::test]
    async fn init_fn_panics() {
        // The directory must exist; otherwise `open` fails on the path and
        // the init fn never runs, so the panic path would go untested.
        let writer_init_fn_called = Arc::new(AtomicBool::new(false));
        {
            let connection_path = fresh_database_path("wal-pool-init_fn_panics.db");
            let writer_init_fn_called = writer_init_fn_called.clone();
            WalPool::builder()
                .writer_init_fn(move |_connection| {
                    writer_init_fn_called.store(true, Ordering::SeqCst);
                    panic!("user panic");
                })
                .open(&connection_path)
                .await
                .expect_err("panic should become an error");
        }
        assert!(
            writer_init_fn_called.load(Ordering::SeqCst),
            "open should have failed because of the writer init fn panic"
        );

        // Use a separate file so this case neither reuses nor deletes a database
        // that the failed writer attempt may still have open.
        let reader_init_fn_called = Arc::new(AtomicBool::new(false));
        {
            let connection_path = fresh_database_path("wal-pool-init_fn_panics-reader.db");
            let reader_init_fn_called = reader_init_fn_called.clone();
            WalPool::builder()
                .reader_init_fn(move |_connection| {
                    reader_init_fn_called.store(true, Ordering::SeqCst);
                    panic!("user panic");
                })
                .open(&connection_path)
                .await
                .expect_err("panic should become an error");
        }
        assert!(
            reader_init_fn_called.load(Ordering::SeqCst),
            "open should have failed because of the reader init fn panic"
        );
    }
}