nd_async_rusqlite/
wal_pool.rs1mod builder;
2
3pub use self::builder::WalPoolBuilder;
4use crate::Error;
5use crate::SyncWrapper;
6use std::sync::Arc;
7
8enum Message {
10 Access {
11 func: Box<dyn FnOnce(&mut rusqlite::Connection) + Send + 'static>,
12 },
13 Close {
14 tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
15 },
16}
17
18#[derive(Debug, Clone)]
20pub struct WalPool {
21 inner: Arc<InnerWalPool>,
22}
23
24impl WalPool {
25 pub fn builder() -> WalPoolBuilder {
27 WalPoolBuilder::new()
28 }
29
30 pub async fn close(&self) -> Result<(), Error> {
42 let mut last_error = Ok(());
43
44 loop {
45 let (tx, rx) = tokio::sync::oneshot::channel();
46 let send_result = self.inner.readers_tx.send(Message::Close { tx });
47
48 if let Err(_error) = send_result {
49 break;
51 }
52
53 let close_result = rx
54 .await
55 .map_err(|_| Error::Aborted)
56 .and_then(std::convert::identity);
57
58 if let Err(close_error) = close_result {
59 last_error = Err(close_error);
60 }
61 }
62
63 let close_writer_result = async {
65 let (tx, rx) = tokio::sync::oneshot::channel();
66 self.inner
67 .writer_tx
68 .send(Message::Close { tx })
69 .map_err(|_| Error::Aborted)?;
70 rx.await.map_err(|_| Error::Aborted)??;
71 Ok(())
72 }
73 .await;
74
75 if let Err(close_writer_error) = close_writer_result {
76 last_error = Err(close_writer_error);
77 }
78
79 last_error
80 }
81
82 pub async fn read<F, T>(&self, func: F) -> Result<T, Error>
86 where
87 F: FnOnce(&mut rusqlite::Connection) -> T + Send + 'static,
88 T: Send + 'static,
89 {
90 let (tx, rx) = tokio::sync::oneshot::channel();
94 self.inner
95 .readers_tx
96 .send(Message::Access {
97 func: Box::new(move |connection| {
98 let func = std::panic::AssertUnwindSafe(|| func(connection));
101 let result = std::panic::catch_unwind(func);
102 let result = result
103 .map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
104 let _ = tx.send(result).is_ok();
105 }),
106 })
107 .map_err(|_| Error::Aborted)?;
108 let result = rx.await.map_err(|_| Error::Aborted)??;
109
110 Ok(result)
111 }
112
113 pub async fn write<F, T>(&self, func: F) -> Result<T, Error>
117 where
118 F: FnOnce(&mut rusqlite::Connection) -> T + Send + 'static,
119 T: Send + 'static,
120 {
121 let (tx, rx) = tokio::sync::oneshot::channel();
125 self.inner
126 .writer_tx
127 .send(Message::Access {
128 func: Box::new(move |connection| {
129 let func = std::panic::AssertUnwindSafe(|| func(connection));
132 let result = std::panic::catch_unwind(func);
133 let result = result
134 .map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
135 let _ = tx.send(result).is_ok();
136 }),
137 })
138 .map_err(|_| Error::Aborted)?;
139 let result = rx.await.map_err(|_| Error::Aborted)??;
140 Ok(result)
141 }
142}
143
144#[derive(Debug)]
146struct InnerWalPool {
147 writer_tx: crossbeam_channel::Sender<Message>,
148 readers_tx: crossbeam_channel::Sender<Message>,
149}
150
#[cfg(test)]
mod test {
    use super::*;
    use std::path::Path;
    use std::path::PathBuf;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    /// Directory (relative to the working directory) holding test databases.
    const TEMP_DIR: &str = "test-temp";

    /// Return the path of a fresh database file called `name` in [`TEMP_DIR`].
    ///
    /// Creates the directory if needed and deletes any database left over
    /// from a previous run. Tests run in any order (and in parallel), so each
    /// test must prepare its own path instead of relying on another test
    /// having created the directory.
    fn fresh_db_path(name: &str) -> PathBuf {
        let temp_path = Path::new(TEMP_DIR);
        std::fs::create_dir_all(temp_path).expect("failed to create temp dir");

        let connection_path = temp_path.join(name);
        match std::fs::remove_file(&connection_path) {
            Ok(()) => {}
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
            Err(error) => {
                panic!("failed to remove old database: {error:?}");
            }
        }
        connection_path
    }

    #[tokio::test]
    async fn dir() {
        let connection_error = WalPool::builder()
            .open(".")
            .await
            .expect_err("pool should not open on a directory");
        assert!(matches!(connection_error, Error::Rusqlite(_)));
    }

    #[tokio::test]
    async fn sanity() {
        let connection_path = fresh_db_path("wal-pool-sanity.db");

        let writer_init_fn_called = Arc::new(AtomicBool::new(false));
        let num_reader_init_fn_called = Arc::new(AtomicUsize::new(0));
        let num_read_connections = 4;
        let connection = {
            let writer_init_fn_called = writer_init_fn_called.clone();
            let num_reader_init_fn_called = num_reader_init_fn_called.clone();
            WalPool::builder()
                .num_read_connections(num_read_connections)
                .writer_init_fn(move |_connection| {
                    writer_init_fn_called.store(true, Ordering::SeqCst);
                    Ok(())
                })
                .reader_init_fn(move |_connection| {
                    num_reader_init_fn_called.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                })
                .open(&connection_path)
                .await
                .expect("connection should be open")
        };

        assert!(writer_init_fn_called.load(Ordering::SeqCst));
        assert_eq!(
            num_reader_init_fn_called.load(Ordering::SeqCst),
            num_read_connections
        );

        let _connection1 = connection.clone();

        let panic_error = connection
            .write(|_connection| panic!("the connection should survive the panic"))
            .await
            .expect_err("the access should have failed");

        assert!(matches!(panic_error, Error::AccessPanic(_)));

        let setup_sql = "PRAGMA foreign_keys = ON; CREATE TABLE USERS (id INTEGER PRIMARY KEY, first_name TEXT NOT NULL, last_name TEXT NOT NULL) STRICT;";
        // `write`/`read` require `'static` closures: `move` copies the
        // `&'static str` in instead of borrowing the local `setup_sql`.
        connection
            .write(move |connection| connection.execute_batch(setup_sql))
            .await
            .expect("failed to create tables")
            .expect("failed to execute");

        // Reader connections must reject writes.
        connection
            .read(move |connection| connection.execute_batch(setup_sql))
            .await
            .expect("failed to access")
            .expect_err("write should have failed");

        connection
            .close()
            .await
            .expect("an error occured while closing");
    }

    #[tokio::test]
    async fn init_fn_panics() {
        // Without creating the directory first, `open` could fail merely
        // because `test-temp` is missing, and the panic path would go untested.
        let connection_path = fresh_db_path("wal-pool-init_fn_panics.db");

        WalPool::builder()
            .writer_init_fn(move |_connection| {
                panic!("user panic");
            })
            .open(&connection_path)
            .await
            .expect_err("panic should become an error");

        WalPool::builder()
            .reader_init_fn(move |_connection| {
                panic!("user panic");
            })
            .open(&connection_path)
            .await
            .expect_err("panic should become an error");
    }
}