nd_async_rusqlite/wal_pool/
builder.rs1use super::InnerWalPool;
2use super::Message;
3use super::WalPool;
4use crate::Error;
5use crate::SyncWrapper;
6use std::path::Path;
7use std::path::PathBuf;
8use std::sync::Arc;
9
/// Number of read-only connections a freshly created [`WalPoolBuilder`] will open.
const DEFAULT_NUM_READ_CONNECTIONS: usize = 4;
11
12type ConnectionInitFn =
13 Arc<dyn Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static>;
14
15pub struct WalPoolBuilder {
17 pub num_read_connections: usize,
19
20 pub reader_init_fn: Option<ConnectionInitFn>,
22
23 pub writer_init_fn: Option<ConnectionInitFn>,
25}
26
27impl WalPoolBuilder {
28 pub fn new() -> Self {
30 Self {
32 num_read_connections: DEFAULT_NUM_READ_CONNECTIONS,
33
34 writer_init_fn: None,
35 reader_init_fn: None,
36 }
37 }
38
39 pub fn num_read_connections(&mut self, num_read_connections: usize) -> &mut Self {
43 self.num_read_connections = num_read_connections;
44 self
45 }
46
47 pub fn writer_init_fn<F>(&mut self, writer_init_fn: F) -> &mut Self
49 where
50 F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
51 {
52 self.writer_init_fn = Some(Arc::new(writer_init_fn));
53 self
54 }
55
56 pub fn reader_init_fn<F>(&mut self, reader_init_fn: F) -> &mut Self
58 where
59 F: Fn(&mut rusqlite::Connection) -> Result<(), Error> + Send + Sync + 'static,
60 {
61 self.reader_init_fn = Some(Arc::new(reader_init_fn));
62 self
63 }
64
65 pub async fn open<P>(&self, path: P) -> Result<WalPool, Error>
67 where
68 P: AsRef<Path>,
69 {
70 let path = path.as_ref().to_path_buf();
71
72 let num_read_connections = self.num_read_connections;
74
75 let writer_tx = {
77 let (writer_tx, writer_rx) = crossbeam_channel::unbounded::<Message>();
78 let path = path.clone();
79 let flags = rusqlite::OpenFlags::default();
80 let writer_init_fn = self.writer_init_fn.clone();
81 let (open_write_tx, open_write_rx) = tokio::sync::oneshot::channel();
82 std::thread::spawn(move || {
83 connection_thread_impl(writer_rx, path, flags, writer_init_fn, open_write_tx)
84 });
85
86 open_write_rx
87 .await
88 .map_err(|_| Error::Aborted)
89 .and_then(std::convert::identity)?;
90
91 writer_tx
92 };
93
94 let (readers_tx, readers_rx) = crossbeam_channel::unbounded::<Message>();
96 let mut open_read_rx_list = Vec::with_capacity(num_read_connections);
97 for _ in 0..num_read_connections {
98 let readers_rx = readers_rx.clone();
99 let path = path.clone();
100
101 let mut flags = rusqlite::OpenFlags::default();
103 flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE);
104 flags.remove(rusqlite::OpenFlags::SQLITE_OPEN_CREATE);
105 flags.insert(rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY);
106
107 let reader_init_fn = self.reader_init_fn.clone();
108
109 let (open_read_tx, open_read_rx) = tokio::sync::oneshot::channel();
110 std::thread::spawn(move || {
111 connection_thread_impl(readers_rx, path, flags, reader_init_fn, open_read_tx)
112 });
113 open_read_rx_list.push(open_read_rx);
114 }
115 drop(readers_rx);
116
117 let wal_pool = WalPool {
120 inner: Arc::new(InnerWalPool {
121 writer_tx,
122 readers_tx,
123 }),
124 };
125
126 let mut last_error = Ok(());
127
128 for open_read_rx in open_read_rx_list {
129 if let Err(error) = open_read_rx
130 .await
131 .map_err(|_| Error::Aborted)
132 .and_then(std::convert::identity)
133 {
134 last_error = Err(error);
135 }
136 }
137
138 if let Err(error) = last_error {
139 let _ = wal_pool.close().await.is_ok();
142
143 return Err(error);
144 }
145
146 Ok(wal_pool)
147 }
148}
149
150impl Default for WalPoolBuilder {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156fn set_wal_journal_mode(connection: &rusqlite::Connection) -> Result<(), Error> {
161 let journal_mode: String =
162 connection.pragma_update_and_check(None, "journal_mode", "WAL", |row| row.get(0))?;
163
164 if journal_mode != "wal" {
165 return Err(Error::InvalidJournalMode(journal_mode));
166 }
167
168 Ok(())
169}
170
171fn connection_thread_impl(
173 rx: crossbeam_channel::Receiver<Message>,
174 path: PathBuf,
175 flags: rusqlite::OpenFlags,
176 init_fn: Option<ConnectionInitFn>,
177 connection_open_tx: tokio::sync::oneshot::Sender<Result<(), Error>>,
178) {
179 let open_result = rusqlite::Connection::open_with_flags(path, flags);
181 let mut connection = match open_result {
182 Ok(connection) => connection,
183 Err(error) => {
184 let _ = connection_open_tx.send(Err(Error::Rusqlite(error))).is_ok();
186 return;
187 }
188 };
189
190 if let Err(error) = set_wal_journal_mode(&connection) {
193 let _ = connection_open_tx.send(Err(error)).is_ok();
195 return;
196 }
197
198 if let Some(init_fn) = init_fn {
200 let init_fn = std::panic::AssertUnwindSafe(|| init_fn(&mut connection));
201 let init_result = std::panic::catch_unwind(init_fn);
202 let init_result =
203 init_result.map_err(|panic_data| Error::AccessPanic(SyncWrapper::new(panic_data)));
204 if let Err(error) = init_result {
205 let _ = connection_open_tx.send(Err(error)).is_ok();
207 return;
208 }
209 }
210
211 if connection_open_tx.send(Ok(())).is_err() {
213 return;
214 }
215
216 let mut close_tx = None;
217 for message in rx.iter() {
218 match message {
219 Message::Close { tx } => {
220 close_tx = Some(tx);
221 break;
222 }
223 Message::Access { func } => {
224 func(&mut connection);
225 }
226 }
227 }
228
229 drop(rx);
233
234 let result = connection.close();
235 if let Some(tx) = close_tx {
236 let _ = tx
237 .send(result.map_err(|(_connection, error)| Error::from(error)))
238 .is_ok();
239 }
240}