opendal/services/huggingface/
core.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::fmt::Debug;
19use std::sync::Arc;
20
21use bytes::Bytes;
22use http::header;
23use http::Request;
24use http::Response;
25use serde::Deserialize;
26
27use super::backend::RepoType;
28use crate::raw::*;
29use crate::*;
30
31pub struct HuggingfaceCore {
32    pub info: Arc<AccessorInfo>,
33
34    pub repo_type: RepoType,
35    pub repo_id: String,
36    pub revision: String,
37    pub root: String,
38    pub token: Option<String>,
39}
40
41impl Debug for HuggingfaceCore {
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("HuggingfaceCore")
44            .field("repo_type", &self.repo_type)
45            .field("repo_id", &self.repo_id)
46            .field("revision", &self.revision)
47            .field("root", &self.root)
48            .finish_non_exhaustive()
49    }
50}
51
52impl HuggingfaceCore {
53    pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
54        let p = build_abs_path(&self.root, path)
55            .trim_end_matches('/')
56            .to_string();
57
58        let url = match self.repo_type {
59            RepoType::Model => format!(
60                "https://huggingface.co/api/models/{}/paths-info/{}",
61                &self.repo_id, &self.revision
62            ),
63            RepoType::Dataset => format!(
64                "https://huggingface.co/api/datasets/{}/paths-info/{}",
65                &self.repo_id, &self.revision
66            ),
67        };
68
69        let mut req = Request::post(&url);
70        // Inject operation to the request.
71        req = req.extension(Operation::Stat);
72        if let Some(token) = &self.token {
73            let auth_header_content = format_authorization_by_bearer(token)?;
74            req = req.header(header::AUTHORIZATION, auth_header_content);
75        }
76
77        req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
78
79        let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
80
81        let req = req
82            .body(Buffer::from(Bytes::from(req_body)))
83            .map_err(new_request_build_error)?;
84
85        self.info.http_client().send(req).await
86    }
87
88    pub async fn hf_list(&self, path: &str, recursive: bool) -> Result<Response<Buffer>> {
89        let p = build_abs_path(&self.root, path)
90            .trim_end_matches('/')
91            .to_string();
92
93        let mut url = match self.repo_type {
94            RepoType::Model => format!(
95                "https://huggingface.co/api/models/{}/tree/{}/{}?expand=True",
96                &self.repo_id,
97                &self.revision,
98                percent_encode_path(&p)
99            ),
100            RepoType::Dataset => format!(
101                "https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True",
102                &self.repo_id,
103                &self.revision,
104                percent_encode_path(&p)
105            ),
106        };
107
108        if recursive {
109            url.push_str("&recursive=True");
110        }
111
112        let mut req = Request::get(&url);
113        // Inject operation to the request.
114        req = req.extension(Operation::List);
115        if let Some(token) = &self.token {
116            let auth_header_content = format_authorization_by_bearer(token)?;
117            req = req.header(header::AUTHORIZATION, auth_header_content);
118        }
119
120        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
121
122        self.info.http_client().send(req).await
123    }
124
125    pub async fn hf_resolve(
126        &self,
127        path: &str,
128        range: BytesRange,
129        _args: &OpRead,
130    ) -> Result<Response<HttpBody>> {
131        let p = build_abs_path(&self.root, path)
132            .trim_end_matches('/')
133            .to_string();
134
135        let url = match self.repo_type {
136            RepoType::Model => format!(
137                "https://huggingface.co/{}/resolve/{}/{}",
138                &self.repo_id,
139                &self.revision,
140                percent_encode_path(&p)
141            ),
142            RepoType::Dataset => format!(
143                "https://huggingface.co/datasets/{}/resolve/{}/{}",
144                &self.repo_id,
145                &self.revision,
146                percent_encode_path(&p)
147            ),
148        };
149
150        let mut req = Request::get(&url);
151
152        if let Some(token) = &self.token {
153            let auth_header_content = format_authorization_by_bearer(token)?;
154            req = req.header(header::AUTHORIZATION, auth_header_content);
155        }
156
157        if !range.is_full() {
158            req = req.header(header::RANGE, range.to_header());
159        }
160        // Inject operation to the request.
161        let req = req.extension(Operation::Read);
162        let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
163
164        self.info.http_client().fetch(req).await
165    }
166}
167
168#[derive(Deserialize, Eq, PartialEq, Debug)]
169#[serde(rename_all = "camelCase")]
170#[allow(dead_code)]
171pub(super) struct HuggingfaceStatus {
172    #[serde(rename = "type")]
173    pub type_: String,
174    pub oid: String,
175    pub size: u64,
176    pub lfs: Option<HuggingfaceLfs>,
177    pub path: String,
178    pub last_commit: Option<HuggingfaceLastCommit>,
179    pub security: Option<HuggingfaceSecurity>,
180}
181
182#[derive(Deserialize, Eq, PartialEq, Debug)]
183#[serde(rename_all = "camelCase")]
184#[allow(dead_code)]
185pub(super) struct HuggingfaceLfs {
186    pub oid: String,
187    pub size: u64,
188    pub pointer_size: u64,
189}
190
191#[derive(Deserialize, Eq, PartialEq, Debug)]
192#[serde(rename_all = "camelCase")]
193#[allow(dead_code)]
194pub(super) struct HuggingfaceLastCommit {
195    pub id: String,
196    pub title: String,
197    pub date: String,
198}
199
200#[derive(Deserialize, Eq, PartialEq, Debug)]
201#[serde(rename_all = "camelCase")]
202#[allow(dead_code)]
203pub(super) struct HuggingfaceSecurity {
204    pub blob_id: String,
205    pub safe: bool,
206    pub av_scan: Option<HuggingfaceAvScan>,
207    pub pickle_import_scan: Option<HuggingfacePickleImportScan>,
208}
209
210#[derive(Deserialize, Eq, PartialEq, Debug)]
211#[allow(dead_code)]
212#[serde(rename_all = "camelCase")]
213pub(super) struct HuggingfaceAvScan {
214    pub virus_found: bool,
215    pub virus_names: Option<Vec<String>>,
216}
217
218#[derive(Deserialize, Eq, PartialEq, Debug)]
219#[serde(rename_all = "camelCase")]
220#[allow(dead_code)]
221pub(super) struct HuggingfacePickleImportScan {
222    pub highest_safety_level: String,
223    pub imports: Vec<HuggingfaceImport>,
224}
225
226#[derive(Deserialize, Eq, PartialEq, Debug)]
227#[allow(dead_code)]
228pub(super) struct HuggingfaceImport {
229    pub module: String,
230    pub name: String,
231    pub safety: String,
232}
233
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::raw::new_json_deserialize_error;
    use crate::types::Result;

    // Values shared by the fixtures and by the expected structs.
    const JAR_OID: &str = "45fa7c3d85ee7dd4139adbc056da25ae136a65f2";
    const JAR_LFS_OID: &str = "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c";
    const JAR_PATH: &str = "maelstrom/lib/maelstrom.jar";
    const JAR_SIZE: u64 = 69512435;

    /// Decodes a JSON array of entries, mapping serde failures to the crate error type.
    fn decode(json: &'static str) -> Result<Vec<HuggingfaceStatus>> {
        let raw = Bytes::from(json);
        serde_json::from_slice(&raw).map_err(new_json_deserialize_error)
    }

    /// The `lfs` object both fixtures attach to the jar entry.
    fn jar_lfs() -> HuggingfaceLfs {
        HuggingfaceLfs {
            oid: JAR_LFS_OID.to_string(),
            size: JAR_SIZE,
            pointer_size: 133,
        }
    }

    /// Entries that carry neither `lastCommit` nor `security` decode with those fields `None`.
    #[test]
    fn parse_list_response_test() -> Result<()> {
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar"
                },
                {
                    "type": "directory",
                    "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                    "size": 69512435,
                    "path": "maelstrom/lib/plugins"
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 2);

        let expected_file = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: JAR_OID.to_string(),
            size: JAR_SIZE,
            lfs: Some(jar_lfs()),
            path: JAR_PATH.to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[0], expected_file);

        // The directory fixture has no `lfs` key and reuses the LFS oid string as its `oid`.
        let expected_dir = HuggingfaceStatus {
            type_: "directory".to_string(),
            oid: JAR_LFS_OID.to_string(),
            size: JAR_SIZE,
            lfs: None,
            path: "maelstrom/lib/plugins".to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[1], expected_dir);

        Ok(())
    }

    /// A fully populated entry decodes every nested object.
    #[test]
    fn parse_files_info_test() -> Result<()> {
        // `security.name` has no matching struct field and is skipped while decoding.
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar",
                    "lastCommit": {
                        "id": "bc1ef030bf3743290d5e190695ab94582e51ae2f",
                        "title": "Upload 141 files",
                        "date": "2023-11-17T23:50:28.000Z"
                    },
                    "security": {
                        "blobId": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                        "name": "maelstrom/lib/maelstrom.jar",
                        "safe": true,
                        "avScan": {
                            "virusFound": false,
                            "virusNames": null
                        },
                        "pickleImportScan": {
                            "highestSafetyLevel": "innocuous",
                            "imports": [
                                {"module": "torch", "name": "FloatStorage", "safety": "innocuous"},
                                {"module": "collections", "name": "OrderedDict", "safety": "innocuous"},
                                {"module": "torch", "name": "LongStorage", "safety": "innocuous"},
                                {"module": "torch._utils", "name": "_rebuild_tensor_v2", "safety": "innocuous"}
                            ]
                        }
                    }
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 1);

        // Every import in the fixture shares the same `safety` value.
        let expected_imports: Vec<HuggingfaceImport> = [
            ("torch", "FloatStorage"),
            ("collections", "OrderedDict"),
            ("torch", "LongStorage"),
            ("torch._utils", "_rebuild_tensor_v2"),
        ]
        .iter()
        .map(|&(module, name)| HuggingfaceImport {
            module: module.to_string(),
            name: name.to_string(),
            safety: "innocuous".to_string(),
        })
        .collect();

        let expected = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: JAR_OID.to_string(),
            size: JAR_SIZE,
            lfs: Some(jar_lfs()),
            path: JAR_PATH.to_string(),
            last_commit: Some(HuggingfaceLastCommit {
                id: "bc1ef030bf3743290d5e190695ab94582e51ae2f".to_string(),
                title: "Upload 141 files".to_string(),
                date: "2023-11-17T23:50:28.000Z".to_string(),
            }),
            security: Some(HuggingfaceSecurity {
                blob_id: JAR_OID.to_string(),
                safe: true,
                av_scan: Some(HuggingfaceAvScan {
                    virus_found: false,
                    virus_names: None,
                }),
                pickle_import_scan: Some(HuggingfacePickleImportScan {
                    highest_safety_level: "innocuous".to_string(),
                    imports: expected_imports,
                }),
            }),
        };

        assert_eq!(entries[0], expected);

        Ok(())
    }
}