1use std::fmt::Debug;
19use std::sync::Arc;
20
21use bytes::Bytes;
22use http::Request;
23use http::Response;
24use http::header;
25use percent_encoding::{NON_ALPHANUMERIC, utf8_percent_encode};
26use serde::Deserialize;
27
28use super::backend::RepoType;
29use crate::raw::*;
30use crate::*;
31
32fn percent_encode_revision(revision: &str) -> String {
33 utf8_percent_encode(revision, NON_ALPHANUMERIC).to_string()
34}
35
/// Shared state for the Hugging Face service: the repository coordinates used to
/// build every request URL, plus the accessor info whose HTTP client sends them.
pub struct HuggingfaceCore {
    /// Accessor metadata; `info.http_client()` is used to send/fetch every request.
    pub info: Arc<AccessorInfo>,

    /// Repository kind; selects the `models` / `datasets` / `spaces` URL segment.
    pub repo_type: RepoType,
    /// Repository id (e.g. `user/model`); inserted into URLs as-is, not percent-encoded.
    pub repo_id: String,
    /// Revision such as `main` or `refs/convert/parquet`; fully percent-encoded
    /// into a single URL segment by `percent_encode_revision`.
    pub revision: String,
    /// Root that operation paths are resolved against via `build_abs_path`.
    pub root: String,
    /// Optional access token; when set it is turned into the `Authorization`
    /// header via `format_authorization_by_bearer`.
    pub token: Option<String>,
    /// Base URL of the Hub (e.g. `https://huggingface.co`); request paths are
    /// appended after a `/`, so presumably configured without a trailing slash.
    pub endpoint: String,
}
46
47impl Debug for HuggingfaceCore {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("HuggingfaceCore")
50 .field("repo_type", &self.repo_type)
51 .field("repo_id", &self.repo_id)
52 .field("revision", &self.revision)
53 .field("root", &self.root)
54 .field("endpoint", &self.endpoint)
55 .finish_non_exhaustive()
56 }
57}
58
59impl HuggingfaceCore {
60 pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
61 let p = build_abs_path(&self.root, path)
62 .trim_end_matches('/')
63 .to_string();
64
65 let url = match self.repo_type {
66 RepoType::Model => format!(
67 "{}/api/models/{}/paths-info/{}",
68 &self.endpoint,
69 &self.repo_id,
70 percent_encode_revision(&self.revision)
71 ),
72 RepoType::Dataset => format!(
73 "{}/api/datasets/{}/paths-info/{}",
74 &self.endpoint,
75 &self.repo_id,
76 percent_encode_revision(&self.revision)
77 ),
78 RepoType::Space => format!(
79 "{}/api/spaces/{}/paths-info/{}",
80 &self.endpoint,
81 &self.repo_id,
82 percent_encode_revision(&self.revision)
83 ),
84 };
85
86 let mut req = Request::post(&url);
87 req = req.extension(Operation::Stat);
89 if let Some(token) = &self.token {
90 let auth_header_content = format_authorization_by_bearer(token)?;
91 req = req.header(header::AUTHORIZATION, auth_header_content);
92 }
93
94 req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
95
96 let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
97
98 let req = req
99 .body(Buffer::from(Bytes::from(req_body)))
100 .map_err(new_request_build_error)?;
101
102 self.info.http_client().send(req).await
103 }
104
105 pub async fn hf_list(
106 &self,
107 path: &str,
108 recursive: bool,
109 cursor: Option<&str>,
110 ) -> Result<Response<Buffer>> {
111 let p = build_abs_path(&self.root, path)
112 .trim_end_matches('/')
113 .to_string();
114
115 let mut url = match self.repo_type {
116 RepoType::Model => format!(
117 "{}/api/models/{}/tree/{}/{}?expand=True",
118 &self.endpoint,
119 &self.repo_id,
120 percent_encode_revision(&self.revision),
121 percent_encode_path(&p)
122 ),
123 RepoType::Dataset => format!(
124 "{}/api/datasets/{}/tree/{}/{}?expand=True",
125 &self.endpoint,
126 &self.repo_id,
127 percent_encode_revision(&self.revision),
128 percent_encode_path(&p)
129 ),
130 RepoType::Space => format!(
131 "{}/api/spaces/{}/tree/{}/{}?expand=True",
132 &self.endpoint,
133 &self.repo_id,
134 percent_encode_revision(&self.revision),
135 percent_encode_path(&p)
136 ),
137 };
138
139 if recursive {
140 url.push_str("&recursive=True");
141 }
142
143 if let Some(cursor_val) = cursor {
144 url.push_str(&format!("&cursor={}", cursor_val));
145 }
146
147 let mut req = Request::get(&url);
148 req = req.extension(Operation::List);
150 if let Some(token) = &self.token {
151 let auth_header_content = format_authorization_by_bearer(token)?;
152 req = req.header(header::AUTHORIZATION, auth_header_content);
153 }
154
155 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
156
157 self.info.http_client().send(req).await
158 }
159
160 pub async fn hf_list_with_url(&self, url: &str) -> Result<Response<Buffer>> {
161 let mut req = Request::get(url);
162 req = req.extension(Operation::List);
164 if let Some(token) = &self.token {
165 let auth_header_content = format_authorization_by_bearer(token)?;
166 req = req.header(header::AUTHORIZATION, auth_header_content);
167 }
168
169 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
170
171 self.info.http_client().send(req).await
172 }
173
174 pub async fn hf_resolve(
175 &self,
176 path: &str,
177 range: BytesRange,
178 _args: &OpRead,
179 ) -> Result<Response<HttpBody>> {
180 let p = build_abs_path(&self.root, path)
181 .trim_end_matches('/')
182 .to_string();
183
184 let url = match self.repo_type {
185 RepoType::Model => format!(
186 "{}/{}/resolve/{}/{}",
187 &self.endpoint,
188 &self.repo_id,
189 percent_encode_revision(&self.revision),
190 percent_encode_path(&p)
191 ),
192 RepoType::Dataset => format!(
193 "{}/datasets/{}/resolve/{}/{}",
194 &self.endpoint,
195 &self.repo_id,
196 percent_encode_revision(&self.revision),
197 percent_encode_path(&p)
198 ),
199 RepoType::Space => format!(
200 "{}/spaces/{}/resolve/{}/{}",
201 &self.endpoint,
202 &self.repo_id,
203 percent_encode_revision(&self.revision),
204 percent_encode_path(&p)
205 ),
206 };
207
208 let mut req = Request::get(&url);
209
210 if let Some(token) = &self.token {
211 let auth_header_content = format_authorization_by_bearer(token)?;
212 req = req.header(header::AUTHORIZATION, auth_header_content);
213 }
214
215 if !range.is_full() {
216 req = req.header(header::RANGE, range.to_header());
217 }
218 let req = req.extension(Operation::Read);
220 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
221
222 self.info.http_client().fetch(req).await
223 }
224}
225
/// One entry of a Hub `tree` / `paths-info` JSON response; the responses are
/// deserialized as `Vec<HuggingfaceStatus>`.
///
/// Keys are camelCase on the wire. JSON keys without a matching field are
/// ignored (serde's default), so extra response fields do not break parsing.
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): presumably not every field is read by the backend; the allow
// keeps the full response shape without dead-code warnings.
#[allow(dead_code)]
pub(super) struct HuggingfaceStatus {
    /// Entry kind as sent by the server (`"file"` or `"directory"` in the tests);
    /// renamed because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub type_: String,
    /// Object id of the entry.
    pub oid: String,
    /// Size reported by the server for this entry.
    pub size: u64,
    /// Git LFS details; absent (`None`) when the key is missing, e.g. for directories.
    pub lfs: Option<HuggingfaceLfs>,
    /// Path of the entry within the repository (e.g. `maelstrom/lib/maelstrom.jar`).
    pub path: String,
    /// Last commit touching this entry, when the server includes it.
    pub last_commit: Option<HuggingfaceLastCommit>,
    /// Security-scan results, when the server includes them.
    pub security: Option<HuggingfaceSecurity>,
}
239
/// Git LFS metadata attached to a file entry (`lfs` object in the response).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
pub(super) struct HuggingfaceLfs {
    /// LFS object id.
    pub oid: String,
    /// Size reported for the LFS object.
    pub size: u64,
    /// `pointerSize` on the wire; presumably the byte size of the LFS pointer
    /// file stored in git (TODO confirm against Hub API docs).
    pub pointer_size: u64,
}
248
/// Last commit that touched an entry (`lastCommit` object in the response).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
pub(super) struct HuggingfaceLastCommit {
    /// Commit hash.
    pub id: String,
    /// Commit title/message.
    pub title: String,
    /// Commit date kept as the raw string (e.g. `2023-11-17T23:50:28.000Z`).
    pub date: String,
}
257
/// Security-scan summary for a file (`security` object in the response).
///
/// Response keys not modeled here (such as `name`) are ignored by serde.
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
pub(super) struct HuggingfaceSecurity {
    /// `blobId` on the wire.
    pub blob_id: String,
    /// Overall verdict reported by the server.
    pub safe: bool,
    /// Antivirus scan result, when present.
    pub av_scan: Option<HuggingfaceAvScan>,
    /// Pickle-import scan result, when present.
    pub pickle_import_scan: Option<HuggingfacePickleImportScan>,
}
267
/// Antivirus scan result (`avScan` object in the response).
#[derive(Deserialize, Eq, PartialEq, Debug)]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
#[serde(rename_all = "camelCase")]
pub(super) struct HuggingfaceAvScan {
    /// `virusFound` on the wire.
    pub virus_found: bool,
    /// `virusNames` on the wire; JSON `null` deserializes to `None`.
    pub virus_names: Option<Vec<String>>,
}
275
/// Result of scanning a pickle file's imports (`pickleImportScan` object).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
pub(super) struct HuggingfacePickleImportScan {
    /// `highestSafetyLevel` on the wire (e.g. `"innocuous"`).
    pub highest_safety_level: String,
    /// Every import found by the scan.
    pub imports: Vec<HuggingfaceImport>,
}
283
/// A single import found by the pickle-import scan. Keys are already
/// lowercase single words, so no `rename_all` is needed.
#[derive(Deserialize, Eq, PartialEq, Debug)]
// Some fields may be unused by the backend; kept to mirror the response.
#[allow(dead_code)]
pub(super) struct HuggingfaceImport {
    /// Python module of the import (e.g. `torch._utils`).
    pub module: String,
    /// Imported name within the module (e.g. `_rebuild_tensor_v2`).
    pub name: String,
    /// Safety level assigned by the scan (e.g. `"innocuous"`).
    pub safety: String,
}
291
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::raw::new_json_deserialize_error;
    use crate::types::Result;
    use http::{Request, Response, StatusCode};
    use std::sync::{Arc, Mutex};

    const NO_REQUEST: &str = "MockHttpClient has not captured any request yet";

    /// `HttpFetch` test double that records the URI and headers of the most
    /// recent request and answers every request with an empty `200 OK`.
    #[derive(Clone)]
    struct MockHttpClient {
        url: Arc<Mutex<Option<String>>>,
        headers: Arc<Mutex<Option<http::HeaderMap>>>,
    }

    impl MockHttpClient {
        fn new() -> Self {
            Self {
                url: Arc::new(Mutex::new(None)),
                headers: Arc::new(Mutex::new(None)),
            }
        }

        /// URI of the last captured request; panics if nothing was sent.
        fn get_captured_url(&self) -> String {
            self.url.lock().unwrap().clone().expect(NO_REQUEST)
        }

        /// Headers of the last captured request; panics if nothing was sent.
        fn get_captured_headers(&self) -> http::HeaderMap {
            self.headers.lock().unwrap().clone().expect(NO_REQUEST)
        }
    }

    impl HttpFetch for MockHttpClient {
        async fn fetch(&self, req: Request<Buffer>) -> Result<Response<HttpBody>> {
            *self.url.lock().unwrap() = Some(req.uri().to_string());
            *self.headers.lock().unwrap() = Some(req.headers().clone());

            Ok(Response::builder()
                .status(StatusCode::OK)
                .body(HttpBody::new(futures::stream::empty(), Some(0)))
                .unwrap())
        }
    }

    /// Builds a core rooted at `/`, without a token, whose HTTP client is a
    /// fresh `MockHttpClient` (returned alongside so tests can inspect requests).
    fn create_test_core(
        repo_type: RepoType,
        repo_id: &str,
        revision: &str,
        endpoint: &str,
    ) -> (HuggingfaceCore, MockHttpClient) {
        let mock_client = MockHttpClient::new();
        let http_client = HttpClient::with(mock_client.clone());

        let info = AccessorInfo::default();
        info.set_scheme("huggingface")
            .set_native_capability(Capability::default());
        info.update_http_client(|_| http_client);

        let core = HuggingfaceCore {
            info: Arc::new(info),
            repo_type,
            repo_id: repo_id.to_string(),
            revision: revision.to_string(),
            root: "/".to_string(),
            token: None,
            endpoint: endpoint.to_string(),
        };

        (core, mock_client)
    }

    #[tokio::test]
    async fn test_hf_path_info_url_model() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "test-user/test-repo",
            "main",
            "https://huggingface.co",
        );

        core.hf_path_info("test.txt").await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/models/test-user/test-repo/paths-info/main"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_path_info_url_dataset() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Dataset,
            "test-org/test-dataset",
            "v1.0.0",
            "https://huggingface.co",
        );

        core.hf_path_info("data/file.csv").await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/datasets/test-org/test-dataset/paths-info/v1%2E0%2E0"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_path_info_url_custom_endpoint() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "test-org/test-dataset",
            "refs/convert/parquet",
            "https://custom-hf.example.com",
        );

        core.hf_path_info("model.bin").await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://custom-hf.example.com/api/models/test-org/test-dataset/paths-info/refs%2Fconvert%2Fparquet"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_path_info_sends_form_content_type() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "test-user/test-repo",
            "main",
            "https://huggingface.co",
        );

        core.hf_path_info("test.txt").await?;

        let headers = mock_client.get_captured_headers();
        assert_eq!(
            headers.get(http::header::CONTENT_TYPE).unwrap(),
            "application/x-www-form-urlencoded"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_url_non_recursive() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "org/model",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("path1", false, None).await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/models/org/model/tree/main/path1?expand=True"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_url_recursive() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "org/model",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("path2", true, None).await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/models/org/model/tree/main/path2?expand=True&recursive=True"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_url_with_cursor() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "org/model",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("path3", false, Some("abc123")).await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/models/org/model/tree/main/path3?expand=True&cursor=abc123"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_url_dataset() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Dataset,
            "org/data",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("train", true, None).await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/datasets/org/data/tree/main/train?expand=True&recursive=True"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_resolve_url_model() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "user/model",
            "main",
            "https://huggingface.co",
        );

        let args = OpRead::default();
        core.hf_resolve("config.json", BytesRange::default(), &args)
            .await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/user/model/resolve/main/config.json"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_resolve_url_dataset() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Dataset,
            "org/data",
            "v1.0",
            "https://huggingface.co",
        );

        let args = OpRead::default();
        core.hf_resolve("train.csv", BytesRange::default(), &args)
            .await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/datasets/org/data/resolve/v1%2E0/train.csv"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_path_info_url_space() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Space,
            "test-user/test-space",
            "main",
            "https://huggingface.co",
        );

        core.hf_path_info("app.py").await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/spaces/test-user/test-space/paths-info/main"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_url_space() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Space,
            "org/space",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("static", false, None).await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/api/spaces/org/space/tree/main/static?expand=True"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_resolve_url_space() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Space,
            "user/space",
            "main",
            "https://huggingface.co",
        );

        let args = OpRead::default();
        core.hf_resolve("README.md", BytesRange::default(), &args)
            .await?;

        let url = mock_client.get_captured_url();
        assert_eq!(
            url,
            "https://huggingface.co/spaces/user/space/resolve/main/README.md"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_resolve_with_range() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "user/model",
            "main",
            "https://huggingface.co",
        );

        let args = OpRead::default();
        let range = BytesRange::new(0, Some(1024));
        core.hf_resolve("large_file.bin", range, &args).await?;

        let url = mock_client.get_captured_url();
        let headers = mock_client.get_captured_headers();
        assert_eq!(
            url,
            "https://huggingface.co/user/model/resolve/main/large_file.bin"
        );
        assert_eq!(headers.get(http::header::RANGE).unwrap(), "bytes=0-1023");

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_resolve_sends_bearer_token() -> Result<()> {
        let (mut core, mock_client) = create_test_core(
            RepoType::Model,
            "user/model",
            "main",
            "https://huggingface.co",
        );
        core.token = Some("hf_test_token".to_string());

        let args = OpRead::default();
        core.hf_resolve("config.json", BytesRange::default(), &args)
            .await?;

        let headers = mock_client.get_captured_headers();
        assert_eq!(
            headers.get(http::header::AUTHORIZATION).unwrap(),
            "Bearer hf_test_token"
        );
        // A full range must not produce a Range header.
        assert!(headers.get(http::header::RANGE).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_hf_list_without_token_omits_authorization() -> Result<()> {
        let (core, mock_client) = create_test_core(
            RepoType::Model,
            "org/model",
            "main",
            "https://huggingface.co",
        );

        core.hf_list("path1", false, None).await?;

        let headers = mock_client.get_captured_headers();
        assert!(headers.get(http::header::AUTHORIZATION).is_none());

        Ok(())
    }

    #[test]
    fn parse_list_response_test() -> Result<()> {
        let resp = Bytes::from(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar"
                },
                {
                    "type": "directory",
                    "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                    "size": 69512435,
                    "path": "maelstrom/lib/plugins"
                }
            ]
            "#,
        );

        let decoded_response = serde_json::from_slice::<Vec<HuggingfaceStatus>>(&resp)
            .map_err(new_json_deserialize_error)?;

        assert_eq!(decoded_response.len(), 2);

        let file_entry = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(HuggingfaceLfs {
                oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
                size: 69512435,
                pointer_size: 133,
            }),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: None,
            security: None,
        };

        assert_eq!(decoded_response[0], file_entry);

        let dir_entry = HuggingfaceStatus {
            type_: "directory".to_string(),
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            lfs: None,
            path: "maelstrom/lib/plugins".to_string(),
            last_commit: None,
            security: None,
        };

        assert_eq!(decoded_response[1], dir_entry);

        Ok(())
    }

    #[test]
    fn parse_files_info_test() -> Result<()> {
        let resp = Bytes::from(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar",
                    "lastCommit": {
                        "id": "bc1ef030bf3743290d5e190695ab94582e51ae2f",
                        "title": "Upload 141 files",
                        "date": "2023-11-17T23:50:28.000Z"
                    },
                    "security": {
                        "blobId": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                        "name": "maelstrom/lib/maelstrom.jar",
                        "safe": true,
                        "avScan": {
                            "virusFound": false,
                            "virusNames": null
                        },
                        "pickleImportScan": {
                            "highestSafetyLevel": "innocuous",
                            "imports": [
                                {"module": "torch", "name": "FloatStorage", "safety": "innocuous"},
                                {"module": "collections", "name": "OrderedDict", "safety": "innocuous"},
                                {"module": "torch", "name": "LongStorage", "safety": "innocuous"},
                                {"module": "torch._utils", "name": "_rebuild_tensor_v2", "safety": "innocuous"}
                            ]
                        }
                    }
                }
            ]
            "#,
        );

        let decoded_response = serde_json::from_slice::<Vec<HuggingfaceStatus>>(&resp)
            .map_err(new_json_deserialize_error)?;

        assert_eq!(decoded_response.len(), 1);

        let file_info = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(HuggingfaceLfs {
                oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
                size: 69512435,
                pointer_size: 133,
            }),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: Some(HuggingfaceLastCommit {
                id: "bc1ef030bf3743290d5e190695ab94582e51ae2f".to_string(),
                title: "Upload 141 files".to_string(),
                date: "2023-11-17T23:50:28.000Z".to_string(),
            }),
            security: Some(HuggingfaceSecurity {
                blob_id: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
                safe: true,
                av_scan: Some(HuggingfaceAvScan {
                    virus_found: false,
                    virus_names: None,
                }),
                pickle_import_scan: Some(HuggingfacePickleImportScan {
                    highest_safety_level: "innocuous".to_string(),
                    imports: vec![
                        HuggingfaceImport {
                            module: "torch".to_string(),
                            name: "FloatStorage".to_string(),
                            safety: "innocuous".to_string(),
                        },
                        HuggingfaceImport {
                            module: "collections".to_string(),
                            name: "OrderedDict".to_string(),
                            safety: "innocuous".to_string(),
                        },
                        HuggingfaceImport {
                            module: "torch".to_string(),
                            name: "LongStorage".to_string(),
                            safety: "innocuous".to_string(),
                        },
                        HuggingfaceImport {
                            module: "torch._utils".to_string(),
                            name: "_rebuild_tensor_v2".to_string(),
                            safety: "innocuous".to_string(),
                        },
                    ],
                }),
            }),
        };

        assert_eq!(decoded_response[0], file_info);

        Ok(())
    }
}
782}