1use std::collections::HashMap;
25
26use reqsign::{AzureStorageConfig, AzureStorageCredential};
27
28use crate::{Error, ErrorKind, Result};
29
30pub(crate) fn azure_config_from_connection_string(
40 conn_str: &str,
41 storage: AzureStorageService,
42) -> Result<AzureStorageConfig> {
43 let key_values = parse_connection_string(conn_str)?;
44
45 if storage == AzureStorageService::Blob {
46 if let Some(development_config) = collect_blob_development_config(&key_values, &storage) {
48 return Ok(AzureStorageConfig {
49 account_name: Some(development_config.account_name),
50 account_key: Some(development_config.account_key),
51 endpoint: Some(development_config.endpoint),
52 ..Default::default()
53 });
54 }
55 }
56
57 let mut config = AzureStorageConfig {
58 account_name: key_values.get("AccountName").cloned(),
59 endpoint: collect_endpoint(&key_values, &storage)?,
60 ..Default::default()
61 };
62
63 if let Some(creds) = collect_credentials(&key_values) {
64 set_credentials(&mut config, creds);
65 };
66
67 Ok(config)
68}
69
/// The Azure Storage service a connection string is parsed for.
///
/// The variant decides which explicit endpoint key is read, and which service label is
/// used when the endpoint is built from `AccountName` + `EndpointSuffix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AzureStorageService {
    /// Azure Blob Storage: `BlobEndpoint`, or built with the `blob` label.
    Blob,

    /// Azure File Storage: `FileEndpoint`, or built with the `file` label.
    File,

    /// Azure Data Lake Storage: endpoint is only built from parts, with the `dfs` label.
    Adls,
}
84
/// Extracts the storage account name from a well-known Azure Storage endpoint.
///
/// Accepts endpoints such as `https://myaccount.blob.core.windows.net` (the `http://` or
/// `https://` scheme is optional) and returns `Some("myaccount")` when the host has the
/// shape `<account>.<service>.<known suffix>`. Any port, path, query or fragment after
/// the host is ignored.
///
/// Returns `None` when the host does not have that shape, when the suffix is not one
/// of the known Azure suffixes, or when the account label is empty.
pub(crate) fn azure_account_name_from_endpoint(endpoint: &str) -> Option<String> {
    const KNOWN_ENDPOINT_SUFFIXES: &[&str] = &[
        "core.windows.net",       // Azure public cloud
        "core.usgovcloudapi.net", // Azure US Government cloud
        "core.chinacloudapi.cn",  // Azure China cloud
    ];

    let endpoint = endpoint
        .strip_prefix("http://")
        .or_else(|| endpoint.strip_prefix("https://"))
        .unwrap_or(endpoint);

    // Keep only the host. A port, path, query or fragment would otherwise end up in the
    // suffix and defeat the comparison below. Cutting at ':' also means an unstripped
    // scheme (e.g. "ftp://") can never leak into the returned account name.
    let host = endpoint
        .split(|c: char| matches!(c, ':' | '/' | '?' | '#'))
        .next()
        .unwrap_or_default();

    let (account_name, service_endpoint) = host.split_once('.')?;
    let (_storage_service, endpoint_suffix) = service_endpoint.split_once('.')?;

    if account_name.is_empty() || !KNOWN_ENDPOINT_SUFFIXES.contains(&endpoint_suffix) {
        return None;
    }

    Some(account_name.to_string())
}
107
108fn parse_connection_string(conn_str: &str) -> Result<HashMap<String, String>> {
111 conn_str
112 .trim()
113 .replace("\n", "")
114 .split(';')
115 .filter(|&field| !field.is_empty())
116 .map(|field| {
117 let (key, value) = field.trim().split_once('=').ok_or(Error::new(
118 ErrorKind::ConfigInvalid,
119 format!("Invalid connection string, expected '=' in field: {field}"),
120 ))?;
121 Ok((key.to_string(), value.to_string()))
122 })
123 .collect()
124}
125
126fn collect_blob_development_config(
127 key_values: &HashMap<String, String>,
128 storage: &AzureStorageService,
129) -> Option<DevelopmentStorageConfig> {
130 debug_assert!(
131 storage == &AzureStorageService::Blob,
132 "Azurite Development Storage only supports Blob Storage"
133 );
134
135 const AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME: &str = "devstoreaccount1";
137 const AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY: &str =
138 "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
139
140 const AZURITE_DEFAULT_BLOB_URI: &str = "http://127.0.0.1:10000";
141
142 if key_values.get("UseDevelopmentStorage") != Some(&"true".to_string()) {
143 return None; }
145
146 let account_name = key_values
147 .get("AccountName")
148 .cloned()
149 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME.to_string());
150 let account_key = key_values
151 .get("AccountKey")
152 .cloned()
153 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY.to_string());
154 let development_proxy_uri = key_values
155 .get("DevelopmentStorageProxyUri")
156 .cloned()
157 .unwrap_or(AZURITE_DEFAULT_BLOB_URI.to_string());
158
159 Some(DevelopmentStorageConfig {
160 endpoint: format!("{development_proxy_uri}/{account_name}"),
161 account_name,
162 account_key,
163 })
164}
165
/// Settings resolved for Azurite development storage (Blob only).
struct DevelopmentStorageConfig {
    /// Endpoint formed as the development base URI followed by `/` and the account name.
    endpoint: String,
    /// Storage account name (Azurite default unless overridden).
    account_name: String,
    /// Storage account key (Azurite default unless overridden).
    account_key: String,
}
172
173fn collect_endpoint(
178 key_values: &HashMap<String, String>,
179 storage: &AzureStorageService,
180) -> Result<Option<String>> {
181 match storage {
182 AzureStorageService::Blob => collect_or_build_endpoint(key_values, "BlobEndpoint", "blob"),
183 AzureStorageService::File => collect_or_build_endpoint(key_values, "FileEndpoint", "file"),
184 AzureStorageService::Adls => {
185 if let Some(dfs_endpoint) = collect_endpoint_from_parts(key_values, "dfs")? {
188 Ok(Some(dfs_endpoint.clone()))
189 } else {
190 Ok(None)
191 }
192 }
193 }
194}
195
196fn collect_credentials(key_values: &HashMap<String, String>) -> Option<AzureStorageCredential> {
197 if let Some(sas_token) = key_values.get("SharedAccessSignature") {
198 Some(AzureStorageCredential::SharedAccessSignature(
199 sas_token.clone(),
200 ))
201 } else if let (Some(account_name), Some(account_key)) =
202 (key_values.get("AccountName"), key_values.get("AccountKey"))
203 {
204 Some(AzureStorageCredential::SharedKey(
205 account_name.clone(),
206 account_key.clone(),
207 ))
208 } else {
209 None
214 }
215}
216
217fn set_credentials(config: &mut AzureStorageConfig, creds: AzureStorageCredential) {
218 match creds {
219 AzureStorageCredential::SharedAccessSignature(sas_token) => {
220 config.sas_token = Some(sas_token);
221 }
222 AzureStorageCredential::SharedKey(account_name, account_key) => {
223 config.account_name = Some(account_name);
224 config.account_key = Some(account_key);
225 }
226 AzureStorageCredential::BearerToken(_, _) => {
227 }
229 }
230}
231
232fn collect_or_build_endpoint(
233 key_values: &HashMap<String, String>,
234 endpoint_key: &str,
235 service_name: &str,
236) -> Result<Option<String>> {
237 if let Some(endpoint) = key_values.get(endpoint_key) {
238 Ok(Some(endpoint.clone()))
239 } else if let Some(built_endpoint) = collect_endpoint_from_parts(key_values, service_name)? {
240 Ok(Some(built_endpoint.clone()))
241 } else {
242 Ok(None)
243 }
244}
245
246fn collect_endpoint_from_parts(
247 key_values: &HashMap<String, String>,
248 storage_endpoint_name: &str,
249) -> Result<Option<String>> {
250 let (account_name, endpoint_suffix) = match (
251 key_values.get("AccountName"),
252 key_values.get("EndpointSuffix"),
253 ) {
254 (Some(name), Some(suffix)) => (name, suffix),
255 _ => return Ok(None), };
257
258 let protocol = key_values
259 .get("DefaultEndpointsProtocol")
260 .map(String::as_str)
261 .unwrap_or("https"); if protocol != "http" && protocol != "https" {
263 return Err(Error::new(
264 ErrorKind::ConfigInvalid,
265 format!("Invalid DefaultEndpointsProtocol: {protocol}"),
266 ));
267 }
268
269 Ok(Some(format!(
270 "{protocol}://{account_name}.{storage_endpoint_name}.{endpoint_suffix}"
271 )))
272}
273
#[cfg(test)]
mod tests {
    use reqsign::AzureStorageConfig;

    use super::{
        azure_account_name_from_endpoint, azure_config_from_connection_string, AzureStorageService,
    };

    /// Table-driven test for `azure_config_from_connection_string`.
    ///
    /// Each case is `(name, (service, connection string), expected)`. An `expected` of
    /// `None` means the call must return an error. `Some(config)` is compared field by
    /// field with `assert_azure_storage_config_eq`.
    #[test]
    fn test_azure_config_from_connection_string() {
        let test_cases = vec![
            // An explicit endpoint alone is enough; everything else stays unset.
            ("minimal fields",
                (AzureStorageService::Blob, "BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("basic creds and blob endpoint",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("SAS token",
                (AzureStorageService::Blob, "SharedAccessSignature=blablabla"),
                Some(AzureStorageConfig{
                    sas_token: Some("blablabla".to_string()),
                    ..Default::default()
                }),
            ),
            // Endpoint assembled from AccountName + EndpointSuffix (no trailing slash).
            ("endpoint from parts",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // The protocol defaults to https when DefaultEndpointsProtocol is absent.
            ("endpoint from parts and no protocol",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // ADLS uses the `dfs` service label when building from parts.
            ("adls endpoint from parts",
                (AzureStorageService::Adls, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.dfs.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            ("file endpoint from field",
                (AzureStorageService::File, "FileEndpoint=https://testaccount.file.core.windows.net"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                })
            ),
            ("file endpoint from parts",
                (AzureStorageService::File, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            // When both are present the SAS token wins, so account_key stays unset;
            // account_name is still taken from AccountName.
            ("prefers sas over key",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;SharedAccessSignature=sas_token"),
                Some(AzureStorageConfig{
                    sas_token: Some("sas_token".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // Azurite defaults: well-known account name/key and local endpoint.
            ("development storage",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true",),
                Some(AzureStorageConfig{
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    endpoint: Some("http://127.0.0.1:10000/devstoreaccount1".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom account values",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;AccountName=myAccount;AccountKey=myKey"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:10000/myAccount".to_string()),
                    account_name: Some("myAccount".to_string()),
                    account_key: Some("myKey".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom uri",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;DevelopmentStorageProxyUri=http://127.0.0.1:12345"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:12345/devstoreaccount1".to_string()),
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    ..Default::default()
                }),
            ),
            ("unknown key is ignored",
                (AzureStorageService::Blob, "SomeUnknownKey=123;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            // Empty fields produced by leading/trailing separators are skipped.
            ("leading and trailing `;`",
                (AzureStorageService::Blob, ";AccountName=testaccount;"),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // Line feeds are stripped and surrounding whitespace is trimmed per field.
            ("line breaks",
                (AzureStorageService::Blob, r#"
                AccountName=testaccount;
                AccountKey=testkey;
                EndpointSuffix=core.windows.net;
                DefaultEndpointsProtocol=https"#),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            // A field without '=' is a parse error.
            ("missing equals",
                (AzureStorageService::Blob, "AccountNameexample;AccountKey=example;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https",),
                None,
            ),
            // Only http and https are accepted when building the endpoint from parts.
            ("with invalid protocol",
                (AzureStorageService::Blob, "DefaultEndpointsProtocol=ftp;AccountName=example;EndpointSuffix=core.windows.net",),
                None,
            ),
            // UseDevelopmentStorage is only honored for Blob; for ADLS nothing is set.
            ("azdls development storage",
                (AzureStorageService::Adls, "UseDevelopmentStorage=true"),
                Some(AzureStorageConfig::default()),
            ),
        ];

        for (name, (storage, conn_str), expected) in test_cases {
            let actual = azure_config_from_connection_string(conn_str, storage);

            if let Some(expected) = expected {
                assert_azure_storage_config_eq(&actual.expect(name), &expected, name);
            } else {
                assert!(actual.is_err(), "Expected error for case: {name}");
            }
        }
    }

    /// Checks account-name extraction for known Azure suffixes and rejects unknown ones.
    #[test]
    fn test_azure_account_name_from_endpoint() {
        let test_cases = vec![
            ("https://account.blob.core.windows.net", Some("account")),
            (
                "https://account.blob.core.usgovcloudapi.net",
                Some("account"),
            ),
            (
                "https://account.blob.core.chinacloudapi.cn",
                Some("account"),
            ),
            ("https://account.dfs.core.windows.net", Some("account")),
            ("https://account.blob.core.windows.net/", Some("account")),
            ("https://account.blob.unknown.suffix.com", None),
            // Only two labels precede the suffix here, so no account can be extracted.
            ("http://blob.core.windows.net", None),
        ];
        for (endpoint, expected_account_name) in test_cases {
            let account_name = azure_account_name_from_endpoint(endpoint);
            assert_eq!(
                account_name,
                expected_account_name.map(|s| s.to_string()),
                "Endpoint: {endpoint}"
            );
        }
    }

    /// Compares the config fields populated by connection-string parsing: account_name,
    /// account_key, endpoint and sas_token. Other fields of `AzureStorageConfig` are not
    /// checked. `name` identifies the test case in failure messages.
    fn assert_azure_storage_config_eq(
        actual: &AzureStorageConfig,
        expected: &AzureStorageConfig,
        name: &str,
    ) {
        assert_eq!(
            actual.account_name, expected.account_name,
            "account_name mismatch: {name}"
        );
        assert_eq!(
            actual.account_key, expected.account_key,
            "account_key mismatch: {name}"
        );
        assert_eq!(
            actual.endpoint, expected.endpoint,
            "endpoint mismatch: {name}"
        );
        assert_eq!(
            actual.sas_token, expected.sas_token,
            "sas_token mismatch: {name}"
        );
    }
}