opendal/services/huggingface/
config.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19
20use serde::Deserialize;
21use serde::Serialize;
22
23use super::HUGGINGFACE_SCHEME;
24use super::backend::HuggingfaceBuilder;
25
26/// Configuration for Huggingface service support.
27#[derive(Default, Serialize, Deserialize, Clone, PartialEq, Eq)]
28#[serde(default)]
29#[non_exhaustive]
30pub struct HuggingfaceConfig {
31    /// Repo type of this backend. Default is model.
32    ///
33    /// Available values:
34    /// - model
35    /// - dataset
36    pub repo_type: Option<String>,
37    /// Repo id of this backend.
38    ///
39    /// This is required.
40    pub repo_id: Option<String>,
41    /// Revision of this backend.
42    ///
43    /// Default is main.
44    pub revision: Option<String>,
45    /// Root of this backend. Can be "/path/to/dir".
46    ///
47    /// Default is "/".
48    pub root: Option<String>,
49    /// Token of this backend.
50    ///
51    /// This is optional.
52    pub token: Option<String>,
53}
54
55impl Debug for HuggingfaceConfig {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("HuggingfaceConfig")
58            .field("repo_type", &self.repo_type)
59            .field("repo_id", &self.repo_id)
60            .field("revision", &self.revision)
61            .field("root", &self.root)
62            .finish_non_exhaustive()
63    }
64}
65
66impl crate::Configurator for HuggingfaceConfig {
67    type Builder = HuggingfaceBuilder;
68
69    fn from_uri(uri: &crate::types::OperatorUri) -> crate::Result<Self> {
70        let mut map = uri.options().clone();
71
72        if let Some(repo_type) = uri.name() {
73            if !repo_type.is_empty() {
74                map.insert("repo_type".to_string(), repo_type.to_string());
75            }
76        }
77
78        let raw_path = uri.root().ok_or_else(|| {
79            crate::Error::new(
80                crate::ErrorKind::ConfigInvalid,
81                "uri path must include owner and repo",
82            )
83            .with_context("service", HUGGINGFACE_SCHEME)
84        })?;
85
86        let mut segments = raw_path.splitn(4, '/');
87        let owner = segments.next().filter(|s| !s.is_empty()).ok_or_else(|| {
88            crate::Error::new(
89                crate::ErrorKind::ConfigInvalid,
90                "repository owner is required in uri path",
91            )
92            .with_context("service", HUGGINGFACE_SCHEME)
93        })?;
94        let repo = segments.next().filter(|s| !s.is_empty()).ok_or_else(|| {
95            crate::Error::new(
96                crate::ErrorKind::ConfigInvalid,
97                "repository name is required in uri path",
98            )
99            .with_context("service", HUGGINGFACE_SCHEME)
100        })?;
101
102        map.insert("repo_id".to_string(), format!("{owner}/{repo}"));
103
104        if let Some(segment) = segments.next() {
105            if map.contains_key("revision") {
106                let mut root_value = segment.to_string();
107                if let Some(rest) = segments.next() {
108                    if !rest.is_empty() {
109                        if !root_value.is_empty() {
110                            root_value.push('/');
111                            root_value.push_str(rest);
112                        } else {
113                            root_value = rest.to_string();
114                        }
115                    }
116                }
117                if !root_value.is_empty() {
118                    map.insert("root".to_string(), root_value);
119                }
120            } else {
121                if !segment.is_empty() {
122                    map.insert("revision".to_string(), segment.to_string());
123                }
124                if let Some(rest) = segments.next() {
125                    if !rest.is_empty() {
126                        map.insert("root".to_string(), rest.to_string());
127                    }
128                }
129            }
130        }
131
132        Self::from_iter(map)
133    }
134
135    fn into_builder(self) -> Self::Builder {
136        HuggingfaceBuilder { config: self }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Configurator;
    use crate::types::OperatorUri;

    /// Parses `uri` with the extra `options` and runs it through `from_uri`.
    ///
    /// Panics if the URI itself cannot be parsed, so each test only has to
    /// look at the config result.
    fn parse(uri: &str, options: Vec<(String, String)>) -> crate::Result<HuggingfaceConfig> {
        let uri = OperatorUri::new(uri, options).unwrap();
        HuggingfaceConfig::from_uri(&uri)
    }

    #[test]
    fn from_uri_sets_repo_type_id_and_revision() {
        // Without a `revision` option the third path segment is the revision.
        let cfg = parse("huggingface://model/opendal/sample/main/dataset", Vec::new()).unwrap();

        assert_eq!(cfg.repo_type.as_deref(), Some("model"));
        assert_eq!(cfg.repo_id.as_deref(), Some("opendal/sample"));
        assert_eq!(cfg.revision.as_deref(), Some("main"));
        assert_eq!(cfg.root.as_deref(), Some("dataset"));
    }

    #[test]
    fn from_uri_uses_existing_revision_and_sets_root() {
        // With a `revision` option everything after the repo name is the root.
        let options = vec![("revision".to_string(), "dev".to_string())];
        let cfg = parse("huggingface://dataset/opendal/sample/data/train", options).unwrap();

        assert_eq!(cfg.repo_type.as_deref(), Some("dataset"));
        assert_eq!(cfg.repo_id.as_deref(), Some("opendal/sample"));
        assert_eq!(cfg.revision.as_deref(), Some("dev"));
        assert_eq!(cfg.root.as_deref(), Some("data/train"));
    }

    #[test]
    fn from_uri_requires_owner_and_repo() {
        // Only the owner is present; the repository name is missing.
        assert!(parse("huggingface://model/opendal", Vec::new()).is_err());
    }
}