opendal_core/services/surrealdb/
core.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::fmt::Debug;
19use std::sync::Arc;
20
21use mea::once::OnceCell;
22use surrealdb::Surreal;
23use surrealdb::engine::any::Any;
24use surrealdb::opt::auth::Database;
25
26use crate::*;
27
28#[derive(Clone)]
29pub struct SurrealdbCore {
30    pub db: OnceCell<Arc<Surreal<Any>>>,
31    pub connection_string: String,
32
33    pub username: String,
34    pub password: String,
35    pub namespace: String,
36    pub database: String,
37
38    pub table: String,
39    pub key_field: String,
40    pub value_field: String,
41}
42
43impl Debug for SurrealdbCore {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("SurrealdbCore")
46            .field("connection_string", &self.connection_string)
47            .field("username", &self.username)
48            .field("namespace", &self.namespace)
49            .field("database", &self.database)
50            .field("table", &self.table)
51            .field("key_field", &self.key_field)
52            .field("value_field", &self.value_field)
53            .finish_non_exhaustive()
54    }
55}
56
57impl SurrealdbCore {
58    async fn get_connection(&self) -> Result<&Surreal<Any>> {
59        self.db
60            .get_or_try_init(|| async {
61                let namespace = self.namespace.as_str();
62                let database = self.database.as_str();
63
64                let db: Surreal<Any> = Surreal::init();
65                db.connect(self.connection_string.clone())
66                    .await
67                    .map_err(parse_surrealdb_error)?;
68
69                if !self.username.is_empty() && !self.password.is_empty() {
70                    db.signin(Database {
71                        namespace,
72                        database,
73                        username: self.username.as_str(),
74                        password: self.password.as_str(),
75                    })
76                    .await
77                    .map_err(parse_surrealdb_error)?;
78                }
79                db.use_ns(namespace)
80                    .use_db(database)
81                    .await
82                    .map_err(parse_surrealdb_error)?;
83
84                Ok(Arc::new(db))
85            })
86            .await
87            .map(|v| v.as_ref())
88    }
89
90    pub async fn get(&self, path: &str) -> Result<Option<Buffer>> {
91        let query: String = if self.key_field == "id" {
92            "SELECT type::field($value_field) FROM type::thing($table, $path)".to_string()
93        } else {
94            format!(
95                "SELECT type::field($value_field) FROM type::table($table) WHERE {} = $path LIMIT 1",
96                self.key_field
97            )
98        };
99
100        let mut result = self
101            .get_connection()
102            .await?
103            .query(query)
104            .bind(("namespace", "opendal"))
105            .bind(("path", path.to_string()))
106            .bind(("table", self.table.to_string()))
107            .bind(("value_field", self.value_field.to_string()))
108            .await
109            .map_err(parse_surrealdb_error)?;
110
111        let value: Option<Vec<u8>> = result
112            .take((0, self.value_field.as_str()))
113            .map_err(parse_surrealdb_error)?;
114
115        Ok(value.map(Buffer::from))
116    }
117
118    pub async fn set(&self, path: &str, value: Buffer) -> Result<()> {
119        let query = format!(
120            "INSERT INTO {} ({}, {}) \
121            VALUES ($path, $value) \
122            ON DUPLICATE KEY UPDATE {} = $value",
123            self.table, self.key_field, self.value_field, self.value_field
124        );
125        self.get_connection()
126            .await?
127            .query(query)
128            .bind(("path", path.to_string()))
129            .bind(("value", value.to_vec()))
130            .await
131            .map_err(parse_surrealdb_error)?;
132        Ok(())
133    }
134
135    pub async fn delete(&self, path: &str) -> Result<()> {
136        let query: String = if self.key_field == "id" {
137            "DELETE FROM type::thing($table, $path)".to_string()
138        } else {
139            format!(
140                "DELETE FROM type::table($table) WHERE {} = $path",
141                self.key_field
142            )
143        };
144
145        self.get_connection()
146            .await?
147            .query(query.as_str())
148            .bind(("path", path.to_string()))
149            .bind(("table", self.table.to_string()))
150            .await
151            .map_err(parse_surrealdb_error)?;
152        Ok(())
153    }
154}
155
156fn parse_surrealdb_error(err: surrealdb::Error) -> Error {
157    Error::new(ErrorKind::Unexpected, "unhandled error from surrealdb").set_source(err)
158}