opendal/services/sqlite/
backend.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19use std::fmt::Formatter;
20use std::pin::Pin;
21use std::str::FromStr;
22use std::task::Context;
23use std::task::Poll;
24
25use futures::stream::BoxStream;
26use futures::Stream;
27use futures::StreamExt;
28use ouroboros::self_referencing;
29use sqlx::sqlite::SqliteConnectOptions;
30use sqlx::SqlitePool;
31use tokio::sync::OnceCell;
32
33use crate::raw::adapters::kv;
34use crate::raw::*;
35use crate::services::SqliteConfig;
36use crate::*;
37
38impl Configurator for SqliteConfig {
39    type Builder = SqliteBuilder;
40    fn into_builder(self) -> Self::Builder {
41        SqliteBuilder { config: self }
42    }
43}
44
// The public documentation of this type is pulled in from `docs.md`.
#[doc = include_str!("docs.md")]
#[derive(Default)]
pub struct SqliteBuilder {
    // Settings collected by the builder methods; consumed by `Builder::build`.
    config: SqliteConfig,
}
50
51impl Debug for SqliteBuilder {
52    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53        let mut ds = f.debug_struct("SqliteBuilder");
54
55        ds.field("config", &self.config);
56        ds.finish()
57    }
58}
59
60impl SqliteBuilder {
61    /// Set the connection_string of the sqlite service.
62    ///
63    /// This connection string is used to connect to the sqlite service. There are url based formats:
64    ///
65    /// ## Url
66    ///
67    /// This format resembles the url format of the sqlite client:
68    ///
69    /// - `sqlite::memory:`
70    /// - `sqlite:data.db`
71    /// - `sqlite://data.db`
72    ///
73    /// For more information, please visit <https://docs.rs/sqlx/latest/sqlx/sqlite/struct.SqliteConnectOptions.html>.
74    pub fn connection_string(mut self, v: &str) -> Self {
75        if !v.is_empty() {
76            self.config.connection_string = Some(v.to_string());
77        }
78        self
79    }
80
81    /// set the working directory, all operations will be performed under it.
82    ///
83    /// default: "/"
84    pub fn root(mut self, root: &str) -> Self {
85        self.config.root = if root.is_empty() {
86            None
87        } else {
88            Some(root.to_string())
89        };
90
91        self
92    }
93
94    /// Set the table name of the sqlite service to read/write.
95    pub fn table(mut self, table: &str) -> Self {
96        if !table.is_empty() {
97            self.config.table = Some(table.to_string());
98        }
99        self
100    }
101
102    /// Set the key field name of the sqlite service to read/write.
103    ///
104    /// Default to `key` if not specified.
105    pub fn key_field(mut self, key_field: &str) -> Self {
106        if !key_field.is_empty() {
107            self.config.key_field = Some(key_field.to_string());
108        }
109        self
110    }
111
112    /// Set the value field name of the sqlite service to read/write.
113    ///
114    /// Default to `value` if not specified.
115    pub fn value_field(mut self, value_field: &str) -> Self {
116        if !value_field.is_empty() {
117            self.config.value_field = Some(value_field.to_string());
118        }
119        self
120    }
121}
122
123impl Builder for SqliteBuilder {
124    type Config = SqliteConfig;
125
126    fn build(self) -> Result<impl Access> {
127        let conn = match self.config.connection_string {
128            Some(v) => v,
129            None => {
130                return Err(Error::new(
131                    ErrorKind::ConfigInvalid,
132                    "connection_string is required but not set",
133                )
134                .with_context("service", Scheme::Sqlite));
135            }
136        };
137
138        let config = SqliteConnectOptions::from_str(&conn).map_err(|err| {
139            Error::new(ErrorKind::ConfigInvalid, "connection_string is invalid")
140                .with_context("service", Scheme::Sqlite)
141                .set_source(err)
142        })?;
143
144        let table = match self.config.table {
145            Some(v) => v,
146            None => {
147                return Err(Error::new(ErrorKind::ConfigInvalid, "table is empty")
148                    .with_context("service", Scheme::Sqlite));
149            }
150        };
151
152        let key_field = self.config.key_field.unwrap_or_else(|| "key".to_string());
153
154        let value_field = self
155            .config
156            .value_field
157            .unwrap_or_else(|| "value".to_string());
158
159        let root = normalize_root(self.config.root.as_deref().unwrap_or("/"));
160
161        Ok(SqliteBackend::new(Adapter {
162            pool: OnceCell::new(),
163            config,
164            table,
165            key_field,
166            value_field,
167        })
168        .with_normalized_root(root))
169    }
170}
171
/// The sqlite service backend: the generic key-value backend driven by the
/// sqlite [`Adapter`] defined in this file.
pub type SqliteBackend = kv::Backend<Adapter>;
173
/// Key-value adapter that stores entries as rows of a single sqlite table.
#[derive(Debug, Clone)]
pub struct Adapter {
    // Connection pool; starts empty and is filled in by `get_client` on first use.
    pool: OnceCell<SqlitePool>,
    // Connect options parsed from the configured connection string.
    config: SqliteConnectOptions,

    // Table and column names; they are interpolated (backtick-quoted) into
    // the SQL statements issued by the `kv::Adapter` implementation.
    table: String,
    key_field: String,
    value_field: String,
}
183
184impl Adapter {
185    async fn get_client(&self) -> Result<&SqlitePool> {
186        self.pool
187            .get_or_try_init(|| async {
188                let pool = SqlitePool::connect_with(self.config.clone())
189                    .await
190                    .map_err(parse_sqlite_error)?;
191                Ok(pool)
192            })
193            .await
194    }
195}
196
/// Stream of keys produced by a scan query.
///
/// This is a self-referencing struct (via `ouroboros`): `stream` borrows the
/// `pool` and `query` fields stored alongside it, so the three are kept
/// together in one value.
#[self_referencing]
pub struct SqliteScanner {
    // Pool handle the query runs against.
    pool: SqlitePool,
    // SQL text of the scan statement.
    query: String,

    // Row stream borrowing from `pool` and `query`; yields one key per item.
    #[borrows(pool, query)]
    #[covariant]
    stream: BoxStream<'this, Result<String>>,
}
206
207impl Stream for SqliteScanner {
208    type Item = Result<String>;
209
210    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
211        self.with_stream_mut(|s| s.poll_next_unpin(cx))
212    }
213}
214
// NOTE(review): this manually asserts that `SqliteScanner` can be shared
// between threads. Presumably it is needed because the boxed stream field is
// not `Sync` on its own, and presumably it is sound because the stream is only
// driven through `&mut self` (`poll_next` / `next`) — TODO confirm that no
// `&self` accessor exposes the inner stream, and record that as the SAFETY
// justification.
unsafe impl Sync for SqliteScanner {}
216
217impl kv::Scan for SqliteScanner {
218    async fn next(&mut self) -> Result<Option<String>> {
219        <Self as StreamExt>::next(self).await.transpose()
220    }
221}
222
223impl kv::Adapter for Adapter {
224    type Scanner = SqliteScanner;
225
226    fn info(&self) -> kv::Info {
227        kv::Info::new(
228            Scheme::Sqlite,
229            &self.table,
230            Capability {
231                read: true,
232                write: true,
233                delete: true,
234                list: true,
235                shared: false,
236                ..Default::default()
237            },
238        )
239    }
240
241    async fn get(&self, path: &str) -> Result<Option<Buffer>> {
242        let pool = self.get_client().await?;
243
244        let value: Option<Vec<u8>> = sqlx::query_scalar(&format!(
245            "SELECT `{}` FROM `{}` WHERE `{}` = $1 LIMIT 1",
246            self.value_field, self.table, self.key_field
247        ))
248        .bind(path)
249        .fetch_optional(pool)
250        .await
251        .map_err(parse_sqlite_error)?;
252
253        Ok(value.map(Buffer::from))
254    }
255
256    async fn set(&self, path: &str, value: Buffer) -> Result<()> {
257        let pool = self.get_client().await?;
258
259        sqlx::query(&format!(
260            "INSERT OR REPLACE INTO `{}` (`{}`, `{}`) VALUES ($1, $2)",
261            self.table, self.key_field, self.value_field,
262        ))
263        .bind(path)
264        .bind(value.to_vec())
265        .execute(pool)
266        .await
267        .map_err(parse_sqlite_error)?;
268
269        Ok(())
270    }
271
272    async fn delete(&self, path: &str) -> Result<()> {
273        let pool = self.get_client().await?;
274
275        sqlx::query(&format!(
276            "DELETE FROM `{}` WHERE `{}` = $1",
277            self.table, self.key_field
278        ))
279        .bind(path)
280        .execute(pool)
281        .await
282        .map_err(parse_sqlite_error)?;
283
284        Ok(())
285    }
286
287    async fn scan(&self, path: &str) -> Result<Self::Scanner> {
288        let pool = self.get_client().await?;
289        let stream = SqliteScannerBuilder {
290            pool: pool.clone(),
291            query: format!(
292                "SELECT `{}` FROM `{}` WHERE `{}` LIKE $1",
293                self.key_field, self.table, self.key_field
294            ),
295            stream_builder: |pool, query| {
296                sqlx::query_scalar(query)
297                    .bind(format!("{path}%"))
298                    .fetch(pool)
299                    .map(|v| v.map_err(parse_sqlite_error))
300                    .boxed()
301            },
302        }
303        .build();
304
305        Ok(stream)
306    }
307}
308
309fn parse_sqlite_error(err: sqlx::Error) -> Error {
310    let is_temporary = matches!(
311        &err,
312        sqlx::Error::Database(db_err) if db_err.code().is_some_and(|c| c == "5" || c == "6")
313    );
314
315    let message = if is_temporary {
316        "database is locked or busy"
317    } else {
318        "unhandled error from sqlite"
319    };
320
321    let mut error = Error::new(ErrorKind::Unexpected, message).set_source(err);
322    if is_temporary {
323        error = error.set_temporary();
324    }
325    error
326}