opendal/services/d1/
backend.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19use std::fmt::Formatter;
20
21use http::header;
22use http::Request;
23use http::StatusCode;
24use serde_json::Value;
25
26use super::error::parse_error;
27use super::model::D1Response;
28use crate::raw::adapters::kv;
29use crate::raw::*;
30use crate::services::D1Config;
31use crate::ErrorKind;
32use crate::*;
33
34impl Configurator for D1Config {
35    type Builder = D1Builder;
36    fn into_builder(self) -> Self::Builder {
37        D1Builder {
38            config: self,
39            http_client: None,
40        }
41    }
42}
43
44#[doc = include_str!("docs.md")]
45#[derive(Default)]
46pub struct D1Builder {
47    config: D1Config,
48
49    http_client: Option<HttpClient>,
50}
51
52impl Debug for D1Builder {
53    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("D1Builder")
55            .field("config", &self.config)
56            .finish()
57    }
58}
59
60impl D1Builder {
61    /// Set api token for the cloudflare d1 service.
62    ///
63    /// create a api token from [here](https://dash.cloudflare.com/profile/api-tokens)
64    pub fn token(mut self, token: &str) -> Self {
65        if !token.is_empty() {
66            self.config.token = Some(token.to_string());
67        }
68        self
69    }
70
71    /// Set the account identifier for the cloudflare d1 service.
72    ///
73    /// get the account identifier from Workers & Pages -> Overview -> Account ID
74    /// If not specified, it will return an error when building.
75    pub fn account_id(mut self, account_id: &str) -> Self {
76        if !account_id.is_empty() {
77            self.config.account_id = Some(account_id.to_string());
78        }
79        self
80    }
81
82    /// Set the database identifier for the cloudflare d1 service.
83    ///
84    /// get the database identifier from Workers & Pages -> D1 -> [Your Database] -> Database ID
85    /// If not specified, it will return an error when building.
86    pub fn database_id(mut self, database_id: &str) -> Self {
87        if !database_id.is_empty() {
88            self.config.database_id = Some(database_id.to_string());
89        }
90        self
91    }
92
93    /// set the working directory, all operations will be performed under it.
94    ///
95    /// default: "/"
96    pub fn root(mut self, root: &str) -> Self {
97        self.config.root = if root.is_empty() {
98            None
99        } else {
100            Some(root.to_string())
101        };
102
103        self
104    }
105
106    /// Set the table name of the d1 service to read/write.
107    ///
108    /// If not specified, it will return an error when building.
109    pub fn table(mut self, table: &str) -> Self {
110        if !table.is_empty() {
111            self.config.table = Some(table.to_owned());
112        }
113        self
114    }
115
116    /// Set the key field name of the d1 service to read/write.
117    ///
118    /// Default to `key` if not specified.
119    pub fn key_field(mut self, key_field: &str) -> Self {
120        if !key_field.is_empty() {
121            self.config.key_field = Some(key_field.to_string());
122        }
123        self
124    }
125
126    /// Set the value field name of the d1 service to read/write.
127    ///
128    /// Default to `value` if not specified.
129    pub fn value_field(mut self, value_field: &str) -> Self {
130        if !value_field.is_empty() {
131            self.config.value_field = Some(value_field.to_string());
132        }
133        self
134    }
135}
136
137impl Builder for D1Builder {
138    type Config = D1Config;
139
140    fn build(self) -> Result<impl Access> {
141        let mut authorization = None;
142        let config = self.config;
143
144        if let Some(token) = config.token {
145            authorization = Some(format_authorization_by_bearer(&token)?)
146        }
147
148        let Some(account_id) = config.account_id else {
149            return Err(Error::new(
150                ErrorKind::ConfigInvalid,
151                "account_id is required",
152            ));
153        };
154
155        let Some(database_id) = config.database_id.clone() else {
156            return Err(Error::new(
157                ErrorKind::ConfigInvalid,
158                "database_id is required",
159            ));
160        };
161
162        let client = if let Some(client) = self.http_client {
163            client
164        } else {
165            HttpClient::new().map_err(|err| {
166                err.with_operation("Builder::build")
167                    .with_context("service", Scheme::D1)
168            })?
169        };
170
171        let Some(table) = config.table.clone() else {
172            return Err(Error::new(ErrorKind::ConfigInvalid, "table is required"));
173        };
174
175        let key_field = config
176            .key_field
177            .clone()
178            .unwrap_or_else(|| "key".to_string());
179
180        let value_field = config
181            .value_field
182            .clone()
183            .unwrap_or_else(|| "value".to_string());
184
185        let root = normalize_root(
186            config
187                .root
188                .clone()
189                .unwrap_or_else(|| "/".to_string())
190                .as_str(),
191        );
192        Ok(D1Backend::new(Adapter {
193            authorization,
194            account_id,
195            database_id,
196            client,
197            table,
198            key_field,
199            value_field,
200        })
201        .with_normalized_root(root))
202    }
203}
204
205pub type D1Backend = kv::Backend<Adapter>;
206
207#[derive(Clone)]
208pub struct Adapter {
209    authorization: Option<String>,
210    account_id: String,
211    database_id: String,
212
213    client: HttpClient,
214    table: String,
215    key_field: String,
216    value_field: String,
217}
218
219impl Debug for Adapter {
220    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
221        let mut ds = f.debug_struct("D1Adapter");
222        ds.field("table", &self.table);
223        ds.field("key_field", &self.key_field);
224        ds.field("value_field", &self.value_field);
225        ds.finish()
226    }
227}
228
229impl Adapter {
230    fn create_d1_query_request(&self, sql: &str, params: Vec<Value>) -> Result<Request<Buffer>> {
231        let p = format!(
232            "/accounts/{}/d1/database/{}/query",
233            self.account_id, self.database_id
234        );
235        let url: String = format!(
236            "{}{}",
237            "https://api.cloudflare.com/client/v4",
238            percent_encode_path(&p)
239        );
240
241        let mut req = Request::post(&url);
242        if let Some(auth) = &self.authorization {
243            req = req.header(header::AUTHORIZATION, auth);
244        }
245        req = req.header(header::CONTENT_TYPE, "application/json");
246
247        let json = serde_json::json!({
248            "sql": sql,
249            "params": params,
250        });
251
252        let body = serde_json::to_vec(&json).map_err(new_json_serialize_error)?;
253        req.body(Buffer::from(body))
254            .map_err(new_request_build_error)
255    }
256}
257
258impl kv::Adapter for Adapter {
259    type Scanner = ();
260
261    fn info(&self) -> kv::Info {
262        kv::Info::new(
263            Scheme::D1,
264            &self.table,
265            Capability {
266                read: true,
267                write: true,
268                // Cloudflare D1 supports 1MB as max in write_total.
269                // refer to https://developers.cloudflare.com/d1/platform/limits/
270                write_total_max_size: Some(1000 * 1000),
271                shared: true,
272                ..Default::default()
273            },
274        )
275    }
276
277    async fn get(&self, path: &str) -> Result<Option<Buffer>> {
278        let query = format!(
279            "SELECT {} FROM {} WHERE {} = ? LIMIT 1",
280            self.value_field, self.table, self.key_field
281        );
282        let req = self.create_d1_query_request(&query, vec![path.into()])?;
283
284        let resp = self.client.send(req).await?;
285        let status = resp.status();
286        match status {
287            StatusCode::OK | StatusCode::PARTIAL_CONTENT => {
288                let body = resp.into_body();
289                let bs = body.to_bytes();
290                let d1_response = D1Response::parse(&bs)?;
291                Ok(d1_response.get_result(&self.value_field))
292            }
293            _ => Err(parse_error(resp)),
294        }
295    }
296
297    async fn set(&self, path: &str, value: Buffer) -> Result<()> {
298        let table = &self.table;
299        let key_field = &self.key_field;
300        let value_field = &self.value_field;
301        let query = format!(
302            "INSERT INTO {table} ({key_field}, {value_field}) \
303                VALUES (?, ?) \
304                ON CONFLICT ({key_field}) \
305                    DO UPDATE SET {value_field} = EXCLUDED.{value_field}",
306        );
307
308        let params = vec![path.into(), value.to_vec().into()];
309        let req = self.create_d1_query_request(&query, params)?;
310
311        let resp = self.client.send(req).await?;
312        let status = resp.status();
313        match status {
314            StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()),
315            _ => Err(parse_error(resp)),
316        }
317    }
318
319    async fn delete(&self, path: &str) -> Result<()> {
320        let query = format!("DELETE FROM {} WHERE {} = ?", self.table, self.key_field);
321        let req = self.create_d1_query_request(&query, vec![path.into()])?;
322
323        let resp = self.client.send(req).await?;
324        let status = resp.status();
325        match status {
326            StatusCode::OK | StatusCode::PARTIAL_CONTENT => Ok(()),
327            _ => Err(parse_error(resp)),
328        }
329    }
330}