1use std::fmt::Debug;
19use std::sync::Arc;
20
21use bytes::Bytes;
22use http::header;
23use http::Request;
24use http::Response;
25use serde::Deserialize;
26
27use super::backend::RepoType;
28use crate::raw::*;
29use crate::*;
30
31pub struct HuggingfaceCore {
32 pub info: Arc<AccessorInfo>,
33
34 pub repo_type: RepoType,
35 pub repo_id: String,
36 pub revision: String,
37 pub root: String,
38 pub token: Option<String>,
39}
40
41impl Debug for HuggingfaceCore {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("HuggingfaceCore")
44 .field("repo_type", &self.repo_type)
45 .field("repo_id", &self.repo_id)
46 .field("revision", &self.revision)
47 .field("root", &self.root)
48 .finish_non_exhaustive()
49 }
50}
51
52impl HuggingfaceCore {
53 pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
54 let p = build_abs_path(&self.root, path)
55 .trim_end_matches('/')
56 .to_string();
57
58 let url = match self.repo_type {
59 RepoType::Model => format!(
60 "https://huggingface.co/api/models/{}/paths-info/{}",
61 &self.repo_id, &self.revision
62 ),
63 RepoType::Dataset => format!(
64 "https://huggingface.co/api/datasets/{}/paths-info/{}",
65 &self.repo_id, &self.revision
66 ),
67 };
68
69 let mut req = Request::post(&url);
70 req = req.extension(Operation::Stat);
72 if let Some(token) = &self.token {
73 let auth_header_content = format_authorization_by_bearer(token)?;
74 req = req.header(header::AUTHORIZATION, auth_header_content);
75 }
76
77 req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
78
79 let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
80
81 let req = req
82 .body(Buffer::from(Bytes::from(req_body)))
83 .map_err(new_request_build_error)?;
84
85 self.info.http_client().send(req).await
86 }
87
88 pub async fn hf_list(&self, path: &str, recursive: bool) -> Result<Response<Buffer>> {
89 let p = build_abs_path(&self.root, path)
90 .trim_end_matches('/')
91 .to_string();
92
93 let mut url = match self.repo_type {
94 RepoType::Model => format!(
95 "https://huggingface.co/api/models/{}/tree/{}/{}?expand=True",
96 &self.repo_id,
97 &self.revision,
98 percent_encode_path(&p)
99 ),
100 RepoType::Dataset => format!(
101 "https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True",
102 &self.repo_id,
103 &self.revision,
104 percent_encode_path(&p)
105 ),
106 };
107
108 if recursive {
109 url.push_str("&recursive=True");
110 }
111
112 let mut req = Request::get(&url);
113 req = req.extension(Operation::List);
115 if let Some(token) = &self.token {
116 let auth_header_content = format_authorization_by_bearer(token)?;
117 req = req.header(header::AUTHORIZATION, auth_header_content);
118 }
119
120 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
121
122 self.info.http_client().send(req).await
123 }
124
125 pub async fn hf_resolve(
126 &self,
127 path: &str,
128 range: BytesRange,
129 _args: &OpRead,
130 ) -> Result<Response<HttpBody>> {
131 let p = build_abs_path(&self.root, path)
132 .trim_end_matches('/')
133 .to_string();
134
135 let url = match self.repo_type {
136 RepoType::Model => format!(
137 "https://huggingface.co/{}/resolve/{}/{}",
138 &self.repo_id,
139 &self.revision,
140 percent_encode_path(&p)
141 ),
142 RepoType::Dataset => format!(
143 "https://huggingface.co/datasets/{}/resolve/{}/{}",
144 &self.repo_id,
145 &self.revision,
146 percent_encode_path(&p)
147 ),
148 };
149
150 let mut req = Request::get(&url);
151
152 if let Some(token) = &self.token {
153 let auth_header_content = format_authorization_by_bearer(token)?;
154 req = req.header(header::AUTHORIZATION, auth_header_content);
155 }
156
157 if !range.is_full() {
158 req = req.header(header::RANGE, range.to_header());
159 }
160 let req = req.extension(Operation::Read);
162 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
163
164 self.info.http_client().fetch(req).await
165 }
166}
167
168#[derive(Deserialize, Eq, PartialEq, Debug)]
169#[serde(rename_all = "camelCase")]
170#[allow(dead_code)]
171pub(super) struct HuggingfaceStatus {
172 #[serde(rename = "type")]
173 pub type_: String,
174 pub oid: String,
175 pub size: u64,
176 pub lfs: Option<HuggingfaceLfs>,
177 pub path: String,
178 pub last_commit: Option<HuggingfaceLastCommit>,
179 pub security: Option<HuggingfaceSecurity>,
180}
181
182#[derive(Deserialize, Eq, PartialEq, Debug)]
183#[serde(rename_all = "camelCase")]
184#[allow(dead_code)]
185pub(super) struct HuggingfaceLfs {
186 pub oid: String,
187 pub size: u64,
188 pub pointer_size: u64,
189}
190
191#[derive(Deserialize, Eq, PartialEq, Debug)]
192#[serde(rename_all = "camelCase")]
193#[allow(dead_code)]
194pub(super) struct HuggingfaceLastCommit {
195 pub id: String,
196 pub title: String,
197 pub date: String,
198}
199
200#[derive(Deserialize, Eq, PartialEq, Debug)]
201#[serde(rename_all = "camelCase")]
202#[allow(dead_code)]
203pub(super) struct HuggingfaceSecurity {
204 pub blob_id: String,
205 pub safe: bool,
206 pub av_scan: Option<HuggingfaceAvScan>,
207 pub pickle_import_scan: Option<HuggingfacePickleImportScan>,
208}
209
210#[derive(Deserialize, Eq, PartialEq, Debug)]
211#[allow(dead_code)]
212#[serde(rename_all = "camelCase")]
213pub(super) struct HuggingfaceAvScan {
214 pub virus_found: bool,
215 pub virus_names: Option<Vec<String>>,
216}
217
218#[derive(Deserialize, Eq, PartialEq, Debug)]
219#[serde(rename_all = "camelCase")]
220#[allow(dead_code)]
221pub(super) struct HuggingfacePickleImportScan {
222 pub highest_safety_level: String,
223 pub imports: Vec<HuggingfaceImport>,
224}
225
226#[derive(Deserialize, Eq, PartialEq, Debug)]
227#[allow(dead_code)]
228pub(super) struct HuggingfaceImport {
229 pub module: String,
230 pub name: String,
231 pub safety: String,
232}
233
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::raw::new_json_deserialize_error;
    use crate::types::Result;

    /// Decodes a JSON array of Hub entries into `HuggingfaceStatus` values.
    fn decode(raw: &'static str) -> Result<Vec<HuggingfaceStatus>> {
        let body = Bytes::from(raw);
        serde_json::from_slice::<Vec<HuggingfaceStatus>>(&body).map_err(new_json_deserialize_error)
    }

    /// LFS metadata shared by the `maelstrom.jar` fixtures below.
    fn maelstrom_jar_lfs() -> HuggingfaceLfs {
        HuggingfaceLfs {
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            pointer_size: 133,
        }
    }

    #[test]
    fn parse_list_response_test() -> Result<()> {
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar"
                },
                {
                    "type": "directory",
                    "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                    "size": 69512435,
                    "path": "maelstrom/lib/plugins"
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 2);

        let expected_file = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(maelstrom_jar_lfs()),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[0], expected_file);

        let expected_dir = HuggingfaceStatus {
            type_: "directory".to_string(),
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            lfs: None,
            path: "maelstrom/lib/plugins".to_string(),
            last_commit: None,
            security: None,
        };
        assert_eq!(entries[1], expected_dir);

        Ok(())
    }

    #[test]
    fn parse_files_info_test() -> Result<()> {
        let entries = decode(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar",
                    "lastCommit": {
                        "id": "bc1ef030bf3743290d5e190695ab94582e51ae2f",
                        "title": "Upload 141 files",
                        "date": "2023-11-17T23:50:28.000Z"
                    },
                    "security": {
                        "blobId": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                        "name": "maelstrom/lib/maelstrom.jar",
                        "safe": true,
                        "avScan": {
                            "virusFound": false,
                            "virusNames": null
                        },
                        "pickleImportScan": {
                            "highestSafetyLevel": "innocuous",
                            "imports": [
                                {"module": "torch", "name": "FloatStorage", "safety": "innocuous"},
                                {"module": "collections", "name": "OrderedDict", "safety": "innocuous"},
                                {"module": "torch", "name": "LongStorage", "safety": "innocuous"},
                                {"module": "torch._utils", "name": "_rebuild_tensor_v2", "safety": "innocuous"}
                            ]
                        }
                    }
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 1);

        // Every import in the fixture is classified "innocuous"; only the
        // (module, name) pairs differ.
        let imports: Vec<HuggingfaceImport> = [
            ("torch", "FloatStorage"),
            ("collections", "OrderedDict"),
            ("torch", "LongStorage"),
            ("torch._utils", "_rebuild_tensor_v2"),
        ]
        .iter()
        .map(|&(module, name)| HuggingfaceImport {
            module: module.to_string(),
            name: name.to_string(),
            safety: "innocuous".to_string(),
        })
        .collect();

        let expected = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(maelstrom_jar_lfs()),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: Some(HuggingfaceLastCommit {
                id: "bc1ef030bf3743290d5e190695ab94582e51ae2f".to_string(),
                title: "Upload 141 files".to_string(),
                date: "2023-11-17T23:50:28.000Z".to_string(),
            }),
            security: Some(HuggingfaceSecurity {
                blob_id: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
                safe: true,
                av_scan: Some(HuggingfaceAvScan {
                    virus_found: false,
                    virus_names: None,
                }),
                pickle_import_scan: Some(HuggingfacePickleImportScan {
                    highest_safety_level: "innocuous".to_string(),
                    imports,
                }),
            }),
        };
        assert_eq!(entries[0], expected);

        Ok(())
    }
}