// opendal/services/d1/backend.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::fmt::Debug;
19use std::fmt::Formatter;
20
21use http::header;
22use http::Request;
23use http::StatusCode;
24use serde_json::Value;
25
26use super::error::parse_error;
27use super::model::D1Response;
28use crate::raw::adapters::kv;
29use crate::raw::*;
30use crate::services::D1Config;
31use crate::ErrorKind;
32use crate::*;
33
34impl Configurator for D1Config {
35    type Builder = D1Builder;
36    fn into_builder(self) -> Self::Builder {
37        D1Builder {
38            config: self,
39            http_client: None,
40        }
41    }
42}
43
44#[doc = include_str!("docs.md")]
45#[derive(Default)]
46pub struct D1Builder {
47    config: D1Config,
48
49    http_client: Option<HttpClient>,
50}
51
52impl Debug for D1Builder {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("D1Builder")
55            .field("config", &self.config)
56            .finish()
57    }
58}
59
60impl D1Builder {
61    /// Set api token for the cloudflare d1 service.
62    ///
63    /// create a api token from [here](https://dash.cloudflare.com/profile/api-tokens)
64    pub fn token(mut self, token: &str) -> Self {
65        if !token.is_empty() {
66            self.config.token = Some(token.to_string());
67        }
68        self
69    }
70
71    /// Set the account identifier for the cloudflare d1 service.
72    ///
73    /// get the account identifier from Workers & Pages -> Overview -> Account ID
74    /// If not specified, it will return an error when building.
75    pub fn account_id(mut self, account_id: &str) -> Self {
76        if !account_id.is_empty() {
77            self.config.account_id = Some(account_id.to_string());
78        }
79        self
80    }
81
82    /// Set the database identifier for the cloudflare d1 service.
83    ///
84    /// get the database identifier from Workers & Pages -> D1 -> [Your Database] -> Database ID
85    /// If not specified, it will return an error when building.
86    pub fn database_id(mut self, database_id: &str) -> Self {
87        if !database_id.is_empty() {
88            self.config.database_id = Some(database_id.to_string());
89        }
90        self
91    }
92
93    /// set the working directory, all operations will be performed under it.
94    ///
95    /// default: "/"
96    pub fn root(mut self, root: &str) -> Self {
97        self.config.root = if root.is_empty() {
98            None
99        } else {
100            Some(root.to_string())
101        };
102
103        self
104    }
105
106    /// Set the table name of the d1 service to read/write.
107    ///
108    /// If not specified, it will return an error when building.
109    pub fn table(mut self, table: &str) -> Self {
110        if !table.is_empty() {
111            self.config.table = Some(table.to_owned());
112        }
113        self
114    }
115
116    /// Set the key field name of the d1 service to read/write.
117    ///
118    /// Default to `key` if not specified.
119    pub fn key_field(mut self, key_field: &str) -> Self {
120        if !key_field.is_empty() {
121            self.config.key_field = Some(key_field.to_string());
122        }
123        self
124    }
125
126    /// Set the value field name of the d1 service to read/write.
127    ///
128    /// Default to `value` if not specified.
129    pub fn value_field(mut self, value_field: &str) -> Self {
130        if !value_field.is_empty() {
131            self.config.value_field = Some(value_field.to_string());
132        }
133        self
134    }
135}
136
137impl Builder for D1Builder {
138    const SCHEME: Scheme = Scheme::D1;
139    type Config = D1Config;
140
141    fn build(self) -> Result<impl Access> {
142        let mut authorization = None;
143        let config = self.config;
144
145        if let Some(token) = config.token {
146            authorization = Some(format_authorization_by_bearer(&token)?)
147        }
148
149        let Some(account_id) = config.account_id else {
150            return Err(Error::new(
151                ErrorKind::ConfigInvalid,
152                "account_id is required",
153            ));
154        };
155
156        let Some(database_id) = config.database_id.clone() else {
157            return Err(Error::new(
158                ErrorKind::ConfigInvalid,
159                "database_id is required",
160            ));
161        };
162
163        let client = if let Some(client) = self.http_client {
164            client
165        } else {
166            HttpClient::new().map_err(|err| {
167                err.with_operation("Builder::build")
168                    .with_context("service", Scheme::D1)
169            })?
170        };
171
172        let Some(table) = config.table.clone() else {
173            return Err(Error::new(ErrorKind::ConfigInvalid, "table is required"));
174        };
175
176        let key_field = config
177            .key_field
178            .clone()
179            .unwrap_or_else(|| "key".to_string());
180
181        let value_field = config
182            .value_field
183            .clone()
184            .unwrap_or_else(|| "value".to_string());
185
186        let root = normalize_root(
187            config
188                .root
189                .clone()
190                .unwrap_or_else(|| "/".to_string())
191                .as_str(),
192        );
193        Ok(D1Backend::new(Adapter {
194            authorization,
195            account_id,
196            database_id,
197            client,
198            table,
199            key_field,
200            value_field,
201        })
202        .with_normalized_root(root))
203    }
204}
205
206pub type D1Backend = kv::Backend<Adapter>;
207
208#[derive(Clone)]
209pub struct Adapter {
210    authorization: Option<String>,
211    account_id: String,
212    database_id: String,
213
214    client: HttpClient,
215    table: String,
216    key_field: String,
217    value_field: String,
218}
219
220impl Debug for Adapter {
221    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
222        let mut ds = f.debug_struct("D1Adapter");
223        ds.field("table", &self.table);
224        ds.field("key_field", &self.key_field);
225        ds.field("value_field", &self.value_field);
226        ds.finish()
227    }
228}
229
230impl Adapter {
231    fn create_d1_query_request(&self, sql: &str, params: Vec<Value>) -> Result<Request<Buffer>> {
232        let p = format!(
233            "/accounts/{}/d1/database/{}/query",
234            self.account_id, self.database_id
235        );
236        let url: String = format!(
237            "{}{}",
238            "https://api.cloudflare.com/client/v4",
239            percent_encode_path(&p)
240        );
241
242        let mut req = Request::post(&url);
243        if let Some(auth) = &self.authorization {
244            req = req.header(header::AUTHORIZATION, auth);
245        }
246        req = req.header(header::CONTENT_TYPE, "application/json");
247
248        let json = serde_json::json!({
249            "sql": sql,
250            "params": params,
251        });
252
253        let body = serde_json::to_vec(&json).map_err(new_json_serialize_error)?;
254        req.body(Buffer::from(body))
255            .map_err(new_request_build_error)
256    }
257}
258
259impl kv::Adapter for Adapter {
260    type Scanner = ();
261
262    fn info(&self) -> kv::Info {
263        kv::Info::new(
264            Scheme::D1,
265            &self.table,
266            Capability {
267                read: true,
268                write: true,
269                // Cloudflare D1 supports 1MB as max in write_total.
270                // refer to https://developers.cloudflare.com/d1/platform/limits/
271                write_total_max_size: Some(1000 * 1000),
272                shared: true,
273                ..Default::default()
274            },
275        )
276    }
277
278    async fn get(&self, path: &str) -> Result<Option<Buffer>> {
279        let query = format!(
280            "SELECT {} FROM {} WHERE {} = ? LIMIT 1",
281            self.value_field, self.table, self.key_field
282        );
283        let req = self.create_d1_query_request(&query, vec![path.into()])?;
284
285        let resp = self.client.send(req).await?;
286        let status = resp.status();
287        match status {
288            StatusCode::OK | StatusCode::PARTIAL_CONTENT => {
289                let body = resp.into_body();
290                let bs = body.to_bytes();
291                let d1_response = D1Response::parse(&bs)?;
292                Ok(d1_response.get_result(&self.value_field))
293            }
294            _ => Err(parse_error(resp)),
295        }
296    }
297
298    async fn set(&self, path: &str, value: Buffer) -> Result<()> {
299        let table = &self.table;
300        let key_field = &self.key_field;
301        let value_field = &self.value_field;
302        let query = format!(
303            "INSERT INTO {table} ({key_field}, {value_field}) \
304                VALUES (?, ?) \
305                ON CONFLICT ({key_field}) \
306                    DO UPDATE SET {value_field} = EXCLUDED.{value_field}",
307        );
308
309        let params = vec![path.into(), value.to_vec().into()];
310        let req = self.create_d1_query_request(&query, params)?;
311
312        let resp = self.client.send(req).await?;
313        let status = resp.status();
314        match status {
315            StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()),
316            _ => Err(parse_error(resp)),
317        }
318    }
319
320    async fn delete(&self, path: &str) -> Result<()> {
321        let query = format!("DELETE FROM {} WHERE {} = ?", self.table, self.key_field);
322        let req = self.create_d1_query_request(&query, vec![path.into()])?;
323
324        let resp = self.client.send(req).await?;
325        let status = resp.status();
326        match status {
327            StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()),
328            _ => Err(parse_error(resp)),
329        }
330    }
331}