opendal/services/huggingface/
core.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use bytes::Bytes;
19use http::header;
20use http::Request;
21use http::Response;
22use serde::Deserialize;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26use super::backend::RepoType;
27use crate::raw::*;
28use crate::*;
29
30pub struct HuggingfaceCore {
31    pub info: Arc<AccessorInfo>,
32
33    pub repo_type: RepoType,
34    pub repo_id: String,
35    pub revision: String,
36    pub root: String,
37    pub token: Option<String>,
38}
39
40impl Debug for HuggingfaceCore {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        f.debug_struct("HuggingfaceCore")
43            .field("repo_type", &self.repo_type)
44            .field("repo_id", &self.repo_id)
45            .field("revision", &self.revision)
46            .field("root", &self.root)
47            .finish_non_exhaustive()
48    }
49}
50
51impl HuggingfaceCore {
52    pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
53        let p = build_abs_path(&self.root, path)
54            .trim_end_matches('/')
55            .to_string();
56
57        let url = match self.repo_type {
58            RepoType::Model => format!(
59                "https://huggingface.co/api/models/{}/paths-info/{}",
60                &self.repo_id, &self.revision
61            ),
62            RepoType::Dataset => format!(
63                "https://huggingface.co/api/datasets/{}/paths-info/{}",
64                &self.repo_id, &self.revision
65            ),
66        };
67
68        let mut req = Request::post(&url);
69        // Inject operation to the request.
70        req = req.extension(Operation::Stat);
71        if let Some(token) = &self.token {
72            let auth_header_content = format_authorization_by_bearer(token)?;
73            req = req.header(header::AUTHORIZATION, auth_header_content);
74        }
75
76        req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
77
78        let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
79
80        let req = req
81            .body(Buffer::from(Bytes::from(req_body)))
82            .map_err(new_request_build_error)?;
83
84        self.info.http_client().send(req).await
85    }
86
87    pub async fn hf_list(&self, path: &str, recursive: bool) -> Result<Response<Buffer>> {
88        let p = build_abs_path(&self.root, path)
89            .trim_end_matches('/')
90            .to_string();
91
92        let mut url = match self.repo_type {
93            RepoType::Model => format!(
94                "https://huggingface.co/api/models/{}/tree/{}/{}?expand=True",
95                &self.repo_id,
96                &self.revision,
97                percent_encode_path(&p)
98            ),
99            RepoType::Dataset => format!(
100                "https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True",
101                &self.repo_id,
102                &self.revision,
103                percent_encode_path(&p)
104            ),
105        };
106
107        if recursive {
108            url.push_str("&recursive=True");
109        }
110
111        let mut req = Request::get(&url);
112        // Inject operation to the request.
113        req = req.extension(Operation::List);
114        if let Some(token) = &self.token {
115            let auth_header_content = format_authorization_by_bearer(token)?;
116            req = req.header(header::AUTHORIZATION, auth_header_content);
117        }
118
119        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
120
121        self.info.http_client().send(req).await
122    }
123
124    pub async fn hf_resolve(
125        &self,
126        path: &str,
127        range: BytesRange,
128        _args: &OpRead,
129    ) -> Result<Response<HttpBody>> {
130        let p = build_abs_path(&self.root, path)
131            .trim_end_matches('/')
132            .to_string();
133
134        let url = match self.repo_type {
135            RepoType::Model => format!(
136                "https://huggingface.co/{}/resolve/{}/{}",
137                &self.repo_id,
138                &self.revision,
139                percent_encode_path(&p)
140            ),
141            RepoType::Dataset => format!(
142                "https://huggingface.co/datasets/{}/resolve/{}/{}",
143                &self.repo_id,
144                &self.revision,
145                percent_encode_path(&p)
146            ),
147        };
148
149        let mut req = Request::get(&url);
150
151        if let Some(token) = &self.token {
152            let auth_header_content = format_authorization_by_bearer(token)?;
153            req = req.header(header::AUTHORIZATION, auth_header_content);
154        }
155
156        if !range.is_full() {
157            req = req.header(header::RANGE, range.to_header());
158        }
159        // Inject operation to the request.
160        let req = req.extension(Operation::Read);
161        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
162
163        self.info.http_client().fetch(req).await
164    }
165}
166
167#[derive(Deserialize, Eq, PartialEq, Debug)]
168#[serde(rename_all = "camelCase")]
169#[allow(dead_code)]
170pub(super) struct HuggingfaceStatus {
171    #[serde(rename = "type")]
172    pub type_: String,
173    pub oid: String,
174    pub size: u64,
175    pub lfs: Option<HuggingfaceLfs>,
176    pub path: String,
177    pub last_commit: Option<HuggingfaceLastCommit>,
178    pub security: Option<HuggingfaceSecurity>,
179}
180
181#[derive(Deserialize, Eq, PartialEq, Debug)]
182#[serde(rename_all = "camelCase")]
183#[allow(dead_code)]
184pub(super) struct HuggingfaceLfs {
185    pub oid: String,
186    pub size: u64,
187    pub pointer_size: u64,
188}
189
190#[derive(Deserialize, Eq, PartialEq, Debug)]
191#[serde(rename_all = "camelCase")]
192#[allow(dead_code)]
193pub(super) struct HuggingfaceLastCommit {
194    pub id: String,
195    pub title: String,
196    pub date: String,
197}
198
199#[derive(Deserialize, Eq, PartialEq, Debug)]
200#[serde(rename_all = "camelCase")]
201#[allow(dead_code)]
202pub(super) struct HuggingfaceSecurity {
203    pub blob_id: String,
204    pub safe: bool,
205    pub av_scan: Option<HuggingfaceAvScan>,
206    pub pickle_import_scan: Option<HuggingfacePickleImportScan>,
207}
208
209#[derive(Deserialize, Eq, PartialEq, Debug)]
210#[allow(dead_code)]
211#[serde(rename_all = "camelCase")]
212pub(super) struct HuggingfaceAvScan {
213    pub virus_found: bool,
214    pub virus_names: Option<Vec<String>>,
215}
216
217#[derive(Deserialize, Eq, PartialEq, Debug)]
218#[serde(rename_all = "camelCase")]
219#[allow(dead_code)]
220pub(super) struct HuggingfacePickleImportScan {
221    pub highest_safety_level: String,
222    pub imports: Vec<HuggingfaceImport>,
223}
224
225#[derive(Deserialize, Eq, PartialEq, Debug)]
226#[allow(dead_code)]
227pub(super) struct HuggingfaceImport {
228    pub module: String,
229    pub name: String,
230    pub safety: String,
231}
232
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::raw::new_json_deserialize_error;
    use crate::types::Result;

    /// Decode a JSON array of entries the same way API response bodies are decoded.
    fn decode(raw: &'static str) -> Result<Vec<HuggingfaceStatus>> {
        let body = Bytes::from(raw);
        serde_json::from_slice::<Vec<HuggingfaceStatus>>(&body)
            .map_err(new_json_deserialize_error)
    }

    /// LFS pointer metadata shared by the `maelstrom.jar` fixtures.
    fn maelstrom_jar_lfs() -> HuggingfaceLfs {
        HuggingfaceLfs {
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            pointer_size: 133,
        }
    }

    /// A pickle import entry whose safety level is `innocuous`.
    fn innocuous_import(module: &str, name: &str) -> HuggingfaceImport {
        HuggingfaceImport {
            module: module.to_string(),
            name: name.to_string(),
            safety: "innocuous".to_string(),
        }
    }

    #[test]
    fn parse_list_response_test() -> Result<()> {
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar"
                },
                {
                    "type": "directory",
                    "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                    "size": 69512435,
                    "path": "maelstrom/lib/plugins"
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 2);

        let expected_file = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(maelstrom_jar_lfs()),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[0], expected_file);

        let expected_dir = HuggingfaceStatus {
            type_: "directory".to_string(),
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            lfs: None,
            path: "maelstrom/lib/plugins".to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[1], expected_dir);

        Ok(())
    }

    #[test]
    fn parse_files_info_test() -> Result<()> {
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar",
                    "lastCommit": {
                        "id": "bc1ef030bf3743290d5e190695ab94582e51ae2f",
                        "title": "Upload 141 files",
                        "date": "2023-11-17T23:50:28.000Z"
                    },
                    "security": {
                        "blobId": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                        "name": "maelstrom/lib/maelstrom.jar",
                        "safe": true,
                        "avScan": {
                            "virusFound": false,
                            "virusNames": null
                        },
                        "pickleImportScan": {
                            "highestSafetyLevel": "innocuous",
                            "imports": [
                                {"module": "torch", "name": "FloatStorage", "safety": "innocuous"},
                                {"module": "collections", "name": "OrderedDict", "safety": "innocuous"},
                                {"module": "torch", "name": "LongStorage", "safety": "innocuous"},
                                {"module": "torch._utils", "name": "_rebuild_tensor_v2", "safety": "innocuous"}
                            ]
                        }
                    }
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 1);

        let expected_imports: Vec<HuggingfaceImport> = [
            ("torch", "FloatStorage"),
            ("collections", "OrderedDict"),
            ("torch", "LongStorage"),
            ("torch._utils", "_rebuild_tensor_v2"),
        ]
        .iter()
        .map(|&(module, name)| innocuous_import(module, name))
        .collect();

        let expected = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(maelstrom_jar_lfs()),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: Some(HuggingfaceLastCommit {
                id: "bc1ef030bf3743290d5e190695ab94582e51ae2f".to_string(),
                title: "Upload 141 files".to_string(),
                date: "2023-11-17T23:50:28.000Z".to_string(),
            }),
            security: Some(HuggingfaceSecurity {
                blob_id: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
                safe: true,
                av_scan: Some(HuggingfaceAvScan {
                    virus_found: false,
                    virus_names: None,
                }),
                pickle_import_scan: Some(HuggingfacePickleImportScan {
                    highest_safety_level: "innocuous".to_string(),
                    imports: expected_imports,
                }),
            }),
        };
        assert_eq!(entries[0], expected);

        Ok(())
    }
}