1use bytes::Bytes;
19use http::header;
20use http::Request;
21use http::Response;
22use serde::Deserialize;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26use super::backend::RepoType;
27use crate::raw::*;
28use crate::*;
29
/// Shared state for the Hugging Face service.
///
/// Holds the repository coordinates and optional credentials used to build every
/// Hub request, plus the accessor info whose HTTP client sends those requests.
pub struct HuggingfaceCore {
    /// Accessor metadata; its `http_client()` is used to send all requests.
    pub info: Arc<AccessorInfo>,

    /// Repository kind (model or dataset); selects which URL layout is used.
    pub repo_type: RepoType,
    /// Repository id, inserted verbatim into request URLs
    /// (presumably of the form `owner/name` — confirm).
    pub repo_id: String,
    /// Revision inserted into request URL paths (typically a branch, tag or commit id).
    pub revision: String,
    /// Root prefix joined onto every operation path via `build_abs_path`.
    pub root: String,
    /// Optional access token; when set it is sent as a bearer `Authorization` header.
    pub token: Option<String>,
}
39
40impl Debug for HuggingfaceCore {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("HuggingfaceCore")
43 .field("repo_type", &self.repo_type)
44 .field("repo_id", &self.repo_id)
45 .field("revision", &self.revision)
46 .field("root", &self.root)
47 .finish_non_exhaustive()
48 }
49}
50
51impl HuggingfaceCore {
52 pub async fn hf_path_info(&self, path: &str) -> Result<Response<Buffer>> {
53 let p = build_abs_path(&self.root, path)
54 .trim_end_matches('/')
55 .to_string();
56
57 let url = match self.repo_type {
58 RepoType::Model => format!(
59 "https://huggingface.co/api/models/{}/paths-info/{}",
60 &self.repo_id, &self.revision
61 ),
62 RepoType::Dataset => format!(
63 "https://huggingface.co/api/datasets/{}/paths-info/{}",
64 &self.repo_id, &self.revision
65 ),
66 };
67
68 let mut req = Request::post(&url);
69 req = req.extension(Operation::Stat);
71 if let Some(token) = &self.token {
72 let auth_header_content = format_authorization_by_bearer(token)?;
73 req = req.header(header::AUTHORIZATION, auth_header_content);
74 }
75
76 req = req.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded");
77
78 let req_body = format!("paths={}&expand=True", percent_encode_path(&p));
79
80 let req = req
81 .body(Buffer::from(Bytes::from(req_body)))
82 .map_err(new_request_build_error)?;
83
84 self.info.http_client().send(req).await
85 }
86
87 pub async fn hf_list(&self, path: &str, recursive: bool) -> Result<Response<Buffer>> {
88 let p = build_abs_path(&self.root, path)
89 .trim_end_matches('/')
90 .to_string();
91
92 let mut url = match self.repo_type {
93 RepoType::Model => format!(
94 "https://huggingface.co/api/models/{}/tree/{}/{}?expand=True",
95 &self.repo_id,
96 &self.revision,
97 percent_encode_path(&p)
98 ),
99 RepoType::Dataset => format!(
100 "https://huggingface.co/api/datasets/{}/tree/{}/{}?expand=True",
101 &self.repo_id,
102 &self.revision,
103 percent_encode_path(&p)
104 ),
105 };
106
107 if recursive {
108 url.push_str("&recursive=True");
109 }
110
111 let mut req = Request::get(&url);
112 req = req.extension(Operation::List);
114 if let Some(token) = &self.token {
115 let auth_header_content = format_authorization_by_bearer(token)?;
116 req = req.header(header::AUTHORIZATION, auth_header_content);
117 }
118
119 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
120
121 self.info.http_client().send(req).await
122 }
123
124 pub async fn hf_resolve(
125 &self,
126 path: &str,
127 range: BytesRange,
128 _args: &OpRead,
129 ) -> Result<Response<HttpBody>> {
130 let p = build_abs_path(&self.root, path)
131 .trim_end_matches('/')
132 .to_string();
133
134 let url = match self.repo_type {
135 RepoType::Model => format!(
136 "https://huggingface.co/{}/resolve/{}/{}",
137 &self.repo_id,
138 &self.revision,
139 percent_encode_path(&p)
140 ),
141 RepoType::Dataset => format!(
142 "https://huggingface.co/datasets/{}/resolve/{}/{}",
143 &self.repo_id,
144 &self.revision,
145 percent_encode_path(&p)
146 ),
147 };
148
149 let mut req = Request::get(&url);
150
151 if let Some(token) = &self.token {
152 let auth_header_content = format_authorization_by_bearer(token)?;
153 req = req.header(header::AUTHORIZATION, auth_header_content);
154 }
155
156 if !range.is_full() {
157 req = req.header(header::RANGE, range.to_header());
158 }
159 let req = req.extension(Operation::Read);
161 let req = req.body(Buffer::new()).map_err(new_request_build_error)?;
162
163 self.info.http_client().fetch(req).await
164 }
165}
166
/// One path entry as returned in the JSON arrays from the Hub `paths-info` / `tree` endpoints.
///
/// Keys are camelCase in JSON. Keys not modeled here are ignored, since serde does not
/// reject unknown fields by default.
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): `dead_code` is presumably allowed because not every field is read
// outside tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfaceStatus {
    /// Entry kind from the JSON `type` key (the fixtures show `"file"` and `"directory"`).
    #[serde(rename = "type")]
    pub type_: String,
    /// Object id reported for the entry.
    pub oid: String,
    /// Size reported for the entry (assumed to be bytes — confirm).
    pub size: u64,
    /// LFS metadata, present only for entries that carry an `lfs` object.
    pub lfs: Option<HuggingfaceLfs>,
    /// Path of the entry within the repository.
    pub path: String,
    /// Last commit touching the entry, when the response includes `lastCommit`.
    pub last_commit: Option<HuggingfaceLastCommit>,
    /// Security scan results, when the response includes `security`.
    pub security: Option<HuggingfaceSecurity>,
}
180
/// Git LFS metadata attached to an entry (JSON `lfs` object).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfaceLfs {
    /// LFS object id (64 hex digits in the fixtures, presumably SHA-256 — confirm).
    pub oid: String,
    /// Size of the LFS object (assumed to be bytes — confirm).
    pub size: u64,
    /// Size of the LFS pointer file; deserialized from `pointerSize`.
    pub pointer_size: u64,
}
189
/// Summary of the last commit that touched an entry (JSON `lastCommit` object).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfaceLastCommit {
    /// Commit id.
    pub id: String,
    /// Commit title.
    pub title: String,
    /// Commit timestamp, kept as the raw string from the API
    /// (the fixture uses `2023-11-17T23:50:28.000Z`).
    pub date: String,
}
198
/// Security scan summary for an entry (JSON `security` object).
///
/// The JSON also carries keys such as `name` that are not modeled here; serde ignores them.
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfaceSecurity {
    /// Blob id the scan refers to; deserialized from `blobId`.
    pub blob_id: String,
    /// Overall safety verdict reported by the Hub.
    pub safe: bool,
    /// Antivirus scan result, if present (`avScan`).
    pub av_scan: Option<HuggingfaceAvScan>,
    /// Pickle import analysis, if present (`pickleImportScan`).
    pub pickle_import_scan: Option<HuggingfacePickleImportScan>,
}
208
/// Antivirus scan result (JSON `avScan` object).
#[derive(Deserialize, Eq, PartialEq, Debug)]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
#[serde(rename_all = "camelCase")]
pub(super) struct HuggingfaceAvScan {
    /// Whether the scan reported a virus (`virusFound`).
    pub virus_found: bool,
    /// Names of detected viruses (`virusNames`); a JSON `null` decodes to `None`.
    pub virus_names: Option<Vec<String>>,
}
216
/// Result of the Hub's pickle import analysis (JSON `pickleImportScan` object).
#[derive(Deserialize, Eq, PartialEq, Debug)]
#[serde(rename_all = "camelCase")]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfacePickleImportScan {
    /// Highest safety level across all imports, as a raw string (fixture: `"innocuous"`).
    pub highest_safety_level: String,
    /// Individual imports found by the scan.
    pub imports: Vec<HuggingfaceImport>,
}
224
/// A single import found by the pickle import scan.
///
/// The JSON keys are single lowercase words, so no `rename_all` is needed.
#[derive(Deserialize, Eq, PartialEq, Debug)]
// NOTE(review): `dead_code` is presumably allowed for fields only read in tests — confirm.
#[allow(dead_code)]
pub(super) struct HuggingfaceImport {
    /// Module the imported item comes from (e.g. `torch._utils` in the fixture).
    pub module: String,
    /// Name of the imported item.
    pub name: String,
    /// Safety classification of this import, as a raw string.
    pub safety: String,
}
232
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::raw::new_json_deserialize_error;
    use crate::types::Result;

    /// Decodes a JSON array of Hub path entries, mapping serde failures to OpenDAL errors.
    fn decode_statuses(json: &'static str) -> Result<Vec<HuggingfaceStatus>> {
        let payload = Bytes::from(json);
        serde_json::from_slice(&payload).map_err(new_json_deserialize_error)
    }

    /// LFS metadata that both fixtures attach to `maelstrom/lib/maelstrom.jar`.
    fn maelstrom_jar_lfs() -> HuggingfaceLfs {
        HuggingfaceLfs {
            oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c".to_string(),
            size: 69512435,
            pointer_size: 133,
        }
    }

    /// A listing with one LFS-backed file and one directory decodes field-for-field;
    /// optional objects missing from the JSON become `None`.
    #[test]
    fn parse_list_response_test() -> Result<()> {
        let entries = decode_statuses(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar"
                },
                {
                    "type": "directory",
                    "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                    "size": 69512435,
                    "path": "maelstrom/lib/plugins"
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 2);

        let expected = [
            HuggingfaceStatus {
                type_: "file".to_string(),
                oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
                size: 69512435,
                lfs: Some(maelstrom_jar_lfs()),
                path: "maelstrom/lib/maelstrom.jar".to_string(),
                last_commit: None,
                security: None,
            },
            HuggingfaceStatus {
                type_: "directory".to_string(),
                oid: "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c"
                    .to_string(),
                size: 69512435,
                lfs: None,
                path: "maelstrom/lib/plugins".to_string(),
                last_commit: None,
                security: None,
            },
        ];

        assert_eq!(entries[0], expected[0]);
        assert_eq!(entries[1], expected[1]);

        Ok(())
    }

    /// A fully expanded file entry decodes `lastCommit` and the nested `security`
    /// objects; unmodeled keys (`security.name`) are ignored.
    #[test]
    fn parse_files_info_test() -> Result<()> {
        let entries = decode_statuses(
            r#"
            [
                {
                    "type": "file",
                    "oid": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                    "size": 69512435,
                    "lfs": {
                        "oid": "b43f4c2ea569da1d66ca74e26ca8ea4430dfc29195e97144b2d0b4f3f6cafa1c",
                        "size": 69512435,
                        "pointerSize": 133
                    },
                    "path": "maelstrom/lib/maelstrom.jar",
                    "lastCommit": {
                        "id": "bc1ef030bf3743290d5e190695ab94582e51ae2f",
                        "title": "Upload 141 files",
                        "date": "2023-11-17T23:50:28.000Z"
                    },
                    "security": {
                        "blobId": "45fa7c3d85ee7dd4139adbc056da25ae136a65f2",
                        "name": "maelstrom/lib/maelstrom.jar",
                        "safe": true,
                        "avScan": {
                            "virusFound": false,
                            "virusNames": null
                        },
                        "pickleImportScan": {
                            "highestSafetyLevel": "innocuous",
                            "imports": [
                                {"module": "torch", "name": "FloatStorage", "safety": "innocuous"},
                                {"module": "collections", "name": "OrderedDict", "safety": "innocuous"},
                                {"module": "torch", "name": "LongStorage", "safety": "innocuous"},
                                {"module": "torch._utils", "name": "_rebuild_tensor_v2", "safety": "innocuous"}
                            ]
                        }
                    }
                }
            ]
            "#,
        )?;

        assert_eq!(entries.len(), 1);

        // Every scanned import in the fixture is classified "innocuous".
        let expected_imports: Vec<HuggingfaceImport> = [
            ("torch", "FloatStorage"),
            ("collections", "OrderedDict"),
            ("torch", "LongStorage"),
            ("torch._utils", "_rebuild_tensor_v2"),
        ]
        .into_iter()
        .map(|(module, name)| HuggingfaceImport {
            module: module.to_string(),
            name: name.to_string(),
            safety: "innocuous".to_string(),
        })
        .collect();

        let expected_commit = HuggingfaceLastCommit {
            id: "bc1ef030bf3743290d5e190695ab94582e51ae2f".to_string(),
            title: "Upload 141 files".to_string(),
            date: "2023-11-17T23:50:28.000Z".to_string(),
        };

        let expected_security = HuggingfaceSecurity {
            blob_id: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            safe: true,
            av_scan: Some(HuggingfaceAvScan {
                virus_found: false,
                virus_names: None,
            }),
            pickle_import_scan: Some(HuggingfacePickleImportScan {
                highest_safety_level: "innocuous".to_string(),
                imports: expected_imports,
            }),
        };

        let expected = HuggingfaceStatus {
            type_: "file".to_string(),
            oid: "45fa7c3d85ee7dd4139adbc056da25ae136a65f2".to_string(),
            size: 69512435,
            lfs: Some(maelstrom_jar_lfs()),
            path: "maelstrom/lib/maelstrom.jar".to_string(),
            last_commit: Some(expected_commit),
            security: Some(expected_security),
        };

        assert_eq!(entries[0], expected);

        Ok(())
    }
}