nd_async_rusqlite/wal_pool/
builder.rs1use super::InnerWalPool;
2use super::Message;
3use super::WalPool;
4use crate::Error;
5use crate::SyncWrapper;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9
/// Number of read-only connections a freshly created `WalPoolBuilder` will open.
const DEFAULT_READERS: usize = 4;
11
12type ConnectionSetupFn =
13 Arc<dyn Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static>;
14
15pub struct WalPoolBuilder {
17 pub readers: usize,
19
20 pub reader_setup: Option<ConnectionSetupFn>,
22
23 pub writer_setup: Option<ConnectionSetupFn>,
25}
26
27impl WalPoolBuilder {
28 pub fn new() -> Self {
30 Self {
32 readers: DEFAULT_READERS,
33
34 reader_setup: None,
35 writer_setup: None,
36 }
37 }
38
39 pub fn readers(&mut self, readers: usize) -> &mut Self {
43 self.readers = readers;
44 self
45 }
46
47 pub fn reader_setup<F>(&mut self, reader_setup: F) -> &mut Self
49 where
50 F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
51 {
52 self.reader_setup = Some(Arc::new(reader_setup));
53 self
54 }
55
56 pub fn writer_setup<F>(&mut self, writer_setup: F) -> &mut Self
58 where
59 F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
60 {
61 self.writer_setup = Some(Arc::new(writer_setup));
62 self
63 }
64
65 pub async fn open<P>(&self, path: P) -> Result<WalPool, Error>
67 where
68 P: AsRef<Path>,
69 {
70 let path = path.as_ref().to_path_buf();
71
72 let readers = self.readers;
73 if self.readers == 0 {
74 return Err(Error::Generic("`readers` cannot be 0"));
75 }
76
77 let writer_tx = {
79 let (writer_tx, writer_rx) = crossbeam_channel::unbounded::<Message>();
80 let path = path.clone();
81 let flags = rusqlite::OpenFlags::default();
82 let writer_setup = self.writer_setup.clone();
83 let (open_write_tx, open_write_rx) = tokio::sync::oneshot::channel();
84 std::thread::spawn(move || {
85 connection_thread_impl(writer_rx, path, flags, writer_setup, open_write_tx)
86 });
87
88 open_write_rx
89 .await
90 .map_err(|_| Error::Aborted)
91 .and_then(std::convert::identity)?;
92
93 writer_tx
94 };
95
96 let (readers_tx, readers_rx) = crossbeam_channel::unbounded::<Message>();
98 let mut open_read_rx_list = Vec::with_capacity(readers);
99 for _ in 0..readers {
100 let readers_rx = readers_rx.clone();
101 let path = path.clone();
102
103 let mut flags = rusqlite::OpenFlags::default();
105 flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE);
106 flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE);
107 flags.insert(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY);
108
109 let reader_setup = self.reader_setup.clone();
110
111 let (open_read_tx, open_read_rx) = tokio::sync::oneshot::channel();
112 std::thread::spawn(move || {
113 connection_thread_impl(readers_rx, path, flags, reader_setup, open_read_tx)
114 });
115 open_read_rx_list.push(open_read_rx);
116 }
117 drop(readers_rx);
118
119 let wal_pool = WalPool {
122 inner: Arc::new(InnerWalPool {
123 writer_tx,
124 readers_tx,
125 }),
126 };
127
128 let mut last_error = Ok(());
129
130 for open_read_rx in open_read_rx_list {
131 if let Err(error) = open_read_rx
132 .await
133 .map_err(|_| Error::Aborted)
134 .and_then(std::convert::identity)
135 {
136 last_error = Err(error);
137 }
138 }
139
140 if let Err(error) = last_error {
141 let _ = wal_pool.close().await.is_ok();
144
145 return Err(error);
146 }
147
148 Ok(wal_pool)
149 }
150}
151
152impl Default for WalPoolBuilder {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158fn set_wal_journal_mode(connection: &rusqlite::Connection) -> Result<(), Error> {
163 let journal_mode: String =
164 connection.pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))?;
165
166 if journal_mode != "wal" {
167 return Err(Error::InvalidJournalMode(journal_mode));
168 }
169
170 Ok(())
171}
172
173fn connection_thread_impl(
175 rx: crossbeam_channel::Receiver<Message>,
176 path: PathBuf,
177 flags: rusqlite::OpenFlags,
178 init_fn: Option<ConnectionSetupFn>,
179 connection_open_tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
180) {
181 let open_result = rusqlite::Connection::open_with_flags(path, flags);
183 let mut connection = match open_result {
184 Ok(connection) => connection,
185 Err(error) => {
186 let _ = connection_open_tx.send(Err(Error::Rusqlite(error))).is_ok();
188 return;
189 }
190 };
191
192 if let Err(error) = set_wal_journal_mode(&connection) {
195 let _ = connection_open_tx.send(Err(error)).is_ok();
197 return;
198 }
199
200 if let Some(init_fn) = init_fn {
202 let init_fn = std::panic::AssertUnwindSafe(|| init_fn(&mut connection));
203 let init_result = std::panic::catch_unwind(init_fn);
204 let init_result =
205 init_result.map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
206 if let Err(error) = init_result {
207 let _ = connection_open_tx.send(Err(error)).is_ok();
209 return;
210 }
211 }
212
213 if connection_open_tx.send(Ok(())).is_err() {
215 return;
216 }
217
218 let mut close_tx = None;
219 for message in rx.iter() {
220 match message {
221 Message::Close { tx } => {
222 close_tx = Some(tx);
223 break;
224 }
225 Message::Access { func } => {
226 func(&mut connection);
227 }
228 }
229 }
230
231 drop(rx);
235
236 let result = connection.close();
237 if let Some(tx) = close_tx {
238 let _ = tx
239 .send(result.map_err(|(_connection, error)| Error::from(error)))
240 .is_ok();
241 }
242}