1use std::collections::HashMap;
25
26use http::response::Parts;
27use http::Uri;
28use reqsign::{AzureStorageConfig, AzureStorageCredential};
29
30use crate::{Error, ErrorKind, Result};
31
32pub(crate) fn azure_config_from_connection_string(
42 conn_str: &str,
43 storage: AzureStorageService,
44) -> Result<AzureStorageConfig> {
45 let key_values = parse_connection_string(conn_str)?;
46
47 if storage == AzureStorageService::Blob {
48 if let Some(development_config) = collect_blob_development_config(&key_values, &storage) {
50 return Ok(AzureStorageConfig {
51 account_name: Some(development_config.account_name),
52 account_key: Some(development_config.account_key),
53 endpoint: Some(development_config.endpoint),
54 ..Default::default()
55 });
56 }
57 }
58
59 let mut config = AzureStorageConfig {
60 account_name: key_values.get("AccountName").cloned(),
61 endpoint: collect_endpoint(&key_values, &storage)?,
62 ..Default::default()
63 };
64
65 if let Some(creds) = collect_credentials(&key_values) {
66 set_credentials(&mut config, creds);
67 };
68
69 Ok(config)
70}
71
/// The Azure Storage service a connection string is being resolved for.
///
/// The service selects which explicit endpoint key (e.g. `BlobEndpoint`) and
/// which host label (e.g. `blob`, `file`, `dfs`) are used to derive the
/// endpoint from a connection string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AzureStorageService {
    /// Azure Blob Storage.
    Blob,

    /// Azure File Storage.
    #[cfg(feature = "services-azfile")]
    File,

    /// Azure Data Lake Storage Gen2 (uses the `dfs` host label).
    #[cfg(feature = "services-azdls")]
    Adls,
}
88
/// Extracts the storage account name from a well-known Azure endpoint.
///
/// The endpoint must look like `[scheme://]<account>.<service>.<suffix>`, where
/// `<suffix>` is one of the known Azure cloud suffixes. Any path, query,
/// fragment or explicit port after the host is ignored.
///
/// Returns `None` when the host does not match that shape, when the suffix is
/// unknown (e.g. custom domains or Azurite's `127.0.0.1:10000/<account>`), or
/// when the account label is empty.
pub(crate) fn azure_account_name_from_endpoint(endpoint: &str) -> Option<String> {
    const KNOWN_ENDPOINT_SUFFIXES: &[&str] = &[
        "core.windows.net",       // Azure public cloud
        "core.usgovcloudapi.net", // Azure US Government
        "core.chinacloudapi.cn",  // Azure China
    ];

    let endpoint: &str = endpoint
        .strip_prefix("http://")
        .or_else(|| endpoint.strip_prefix("https://"))
        .unwrap_or(endpoint);

    // Keep only the host: drop everything from the first path/query/fragment
    // delimiter, then drop an explicit `:port`.
    let authority = endpoint
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    let host = authority
        .split_once(':')
        .map_or(authority, |(host, _port)| host);

    let (account_name, service_endpoint) = host.split_once('.')?;
    let (_storage_service, endpoint_suffix) = service_endpoint.split_once('.')?;

    if account_name.is_empty() || !KNOWN_ENDPOINT_SUFFIXES.contains(&endpoint_suffix) {
        return None;
    }
    Some(account_name.to_string())
}
111
112fn parse_connection_string(conn_str: &str) -> Result<HashMap<String, String>> {
115 conn_str
116 .trim()
117 .replace("\n", "")
118 .split(';')
119 .filter(|&field| !field.is_empty())
120 .map(|field| {
121 let (key, value) = field.trim().split_once('=').ok_or(Error::new(
122 ErrorKind::ConfigInvalid,
123 format!("Invalid connection string, expected '=' in field: {field}"),
124 ))?;
125 Ok((key.to_string(), value.to_string()))
126 })
127 .collect()
128}
129
130fn collect_blob_development_config(
131 key_values: &HashMap<String, String>,
132 storage: &AzureStorageService,
133) -> Option<DevelopmentStorageConfig> {
134 debug_assert!(
135 storage == &AzureStorageService::Blob,
136 "Azurite Development Storage only supports Blob Storage"
137 );
138
139 const AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME: &str = "devstoreaccount1";
141 const AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY: &str =
142 "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
143
144 const AZURITE_DEFAULT_BLOB_URI: &str = "http://127.0.0.1:10000";
145
146 if key_values.get("UseDevelopmentStorage") != Some(&"true".to_string()) {
147 return None; }
149
150 let account_name = key_values
151 .get("AccountName")
152 .cloned()
153 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME.to_string());
154 let account_key = key_values
155 .get("AccountKey")
156 .cloned()
157 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY.to_string());
158 let development_proxy_uri = key_values
159 .get("DevelopmentStorageProxyUri")
160 .cloned()
161 .unwrap_or(AZURITE_DEFAULT_BLOB_URI.to_string());
162
163 Some(DevelopmentStorageConfig {
164 endpoint: format!("{development_proxy_uri}/{account_name}"),
165 account_name,
166 account_key,
167 })
168}
169
/// Resolved Azurite (development storage) settings for Blob storage.
struct DevelopmentStorageConfig {
    /// Full service endpoint, i.e. `<proxy uri>/<account name>`.
    endpoint: String,
    /// Storage account name (Azurite default or from the connection string).
    account_name: String,
    /// Shared key for the account (Azurite default or from the connection string).
    account_key: String,
}
176
177fn collect_endpoint(
182 key_values: &HashMap<String, String>,
183 storage: &AzureStorageService,
184) -> Result<Option<String>> {
185 match storage {
186 AzureStorageService::Blob => collect_or_build_endpoint(key_values, "BlobEndpoint", "blob"),
187 #[cfg(feature = "services-azfile")]
188 AzureStorageService::File => collect_or_build_endpoint(key_values, "FileEndpoint", "file"),
189 #[cfg(feature = "services-azdls")]
190 AzureStorageService::Adls => {
191 if let Some(dfs_endpoint) = collect_endpoint_from_parts(key_values, "dfs")? {
194 Ok(Some(dfs_endpoint.clone()))
195 } else {
196 Ok(None)
197 }
198 }
199 }
200}
201
202fn collect_credentials(key_values: &HashMap<String, String>) -> Option<AzureStorageCredential> {
203 if let Some(sas_token) = key_values.get("SharedAccessSignature") {
204 Some(AzureStorageCredential::SharedAccessSignature(
205 sas_token.clone(),
206 ))
207 } else if let (Some(account_name), Some(account_key)) =
208 (key_values.get("AccountName"), key_values.get("AccountKey"))
209 {
210 Some(AzureStorageCredential::SharedKey(
211 account_name.clone(),
212 account_key.clone(),
213 ))
214 } else {
215 None
220 }
221}
222
223fn set_credentials(config: &mut AzureStorageConfig, creds: AzureStorageCredential) {
224 match creds {
225 AzureStorageCredential::SharedAccessSignature(sas_token) => {
226 config.sas_token = Some(sas_token);
227 }
228 AzureStorageCredential::SharedKey(account_name, account_key) => {
229 config.account_name = Some(account_name);
230 config.account_key = Some(account_key);
231 }
232 AzureStorageCredential::BearerToken(_, _) => {
233 }
235 }
236}
237
238fn collect_or_build_endpoint(
239 key_values: &HashMap<String, String>,
240 endpoint_key: &str,
241 service_name: &str,
242) -> Result<Option<String>> {
243 if let Some(endpoint) = key_values.get(endpoint_key) {
244 Ok(Some(endpoint.clone()))
245 } else if let Some(built_endpoint) = collect_endpoint_from_parts(key_values, service_name)? {
246 Ok(Some(built_endpoint.clone()))
247 } else {
248 Ok(None)
249 }
250}
251
252fn collect_endpoint_from_parts(
253 key_values: &HashMap<String, String>,
254 storage_endpoint_name: &str,
255) -> Result<Option<String>> {
256 let (account_name, endpoint_suffix) = match (
257 key_values.get("AccountName"),
258 key_values.get("EndpointSuffix"),
259 ) {
260 (Some(name), Some(suffix)) => (name, suffix),
261 _ => return Ok(None), };
263
264 let protocol = key_values
265 .get("DefaultEndpointsProtocol")
266 .map(String::as_str)
267 .unwrap_or("https"); if protocol != "http" && protocol != "https" {
269 return Err(Error::new(
270 ErrorKind::ConfigInvalid,
271 format!("Invalid DefaultEndpointsProtocol: {protocol}"),
272 ));
273 }
274
275 Ok(Some(format!(
276 "{protocol}://{account_name}.{storage_endpoint_name}.{endpoint_suffix}"
277 )))
278}
279
280pub fn with_azure_error_response_context(mut err: Error, mut parts: Parts) -> Error {
288 if let Some(uri) = parts.extensions.get::<Uri>() {
289 err = err.with_context("uri", censor_sas_uri(uri));
290 }
291
292 parts.headers.remove("Set-Cookie");
294 parts.headers.remove("WWW-Authenticate");
295 parts.headers.remove("Proxy-Authenticate");
296
297 err = err.with_context("response", format!("{parts:?}"));
298
299 err
300}
301
302fn censor_sas_uri(uri: &Uri) -> String {
303 if let Some(query) = uri.query() {
304 let path = uri.path();
314 let new_query: String = query
315 .split("&")
316 .filter(|p| !p.starts_with("sig="))
317 .collect::<Vec<_>>()
318 .join("&");
319 let mut parts = uri.clone().into_parts();
320 parts.path_and_query = Some(format!("{path}?{new_query}").try_into().unwrap());
321 Uri::from_parts(parts).unwrap().to_string()
322 } else {
323 uri.to_string()
324 }
325}
326
#[cfg(test)]
mod tests {
    use http::Uri;
    use reqsign::AzureStorageConfig;

    // Everything under test lives in the parent module; import it uniformly via
    // `super` instead of mixing in an absolute `crate::` path.
    use super::{
        azure_account_name_from_endpoint, azure_config_from_connection_string, censor_sas_uri,
        AzureStorageService,
    };

    /// Table-driven test for `azure_config_from_connection_string`.
    ///
    /// Each case is `(name, (service, connection string), expected)`. An
    /// `expected` of `None` means the call must return an error.
    #[test]
    fn test_azure_config_from_connection_string() {
        // `test_cases` is only extended when the `services-azdls` or
        // `services-azfile` features are enabled; otherwise `mut` is unused.
        #[allow(unused_mut)]
        let mut test_cases = vec![
            ("minimal fields",
                (AzureStorageService::Blob, "BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("basic creds and blob endpoint",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("SAS token",
                (AzureStorageService::Blob, "SharedAccessSignature=blablabla"),
                Some(AzureStorageConfig{
                    sas_token: Some("blablabla".to_string()),
                    ..Default::default()
                }),
            ),
            ("endpoint from parts",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("endpoint from parts and no protocol",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
                Some(AzureStorageConfig{
                    // Defaults to https if protocol is not specified.
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("prefers sas over key",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;SharedAccessSignature=sas_token"),
                Some(AzureStorageConfig{
                    sas_token: Some("sas_token".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true",),
                Some(AzureStorageConfig{
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    endpoint: Some("http://127.0.0.1:10000/devstoreaccount1".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom account values",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;AccountName=myAccount;AccountKey=myKey"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:10000/myAccount".to_string()),
                    account_name: Some("myAccount".to_string()),
                    account_key: Some("myKey".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom uri",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;DevelopmentStorageProxyUri=http://127.0.0.1:12345"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:12345/devstoreaccount1".to_string()),
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    ..Default::default()
                }),
            ),
            ("unknown key is ignored",
                (AzureStorageService::Blob, "SomeUnknownKey=123;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("leading and trailing `;`",
                (AzureStorageService::Blob, ";AccountName=testaccount;"),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("line breaks",
                (AzureStorageService::Blob, r#"
                    AccountName=testaccount;
                    AccountKey=testkey;
                    EndpointSuffix=core.windows.net;
                    DefaultEndpointsProtocol=https"#),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            ("missing equals",
                (AzureStorageService::Blob, "AccountNameexample;AccountKey=example;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https",),
                None, // This should fail due to missing '='
            ),
            ("with invalid protocol",
                (AzureStorageService::Blob, "DefaultEndpointsProtocol=ftp;AccountName=example;EndpointSuffix=core.windows.net",),
                None, // This should fail due to invalid protocol
            ),
        ];

        #[cfg(feature = "services-azdls")]
        test_cases.push(
            ("adls endpoint from parts",
                (AzureStorageService::Adls, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.dfs.core.windows.net".to_string()),
                    ..Default::default()
                }),
            )
        );

        #[cfg(feature = "services-azfile")]
        test_cases.extend(vec![
            (
                "file endpoint from field",
                (
                    AzureStorageService::File,
                    "FileEndpoint=https://testaccount.file.core.windows.net",
                ),
                Some(AzureStorageConfig {
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            (
                "file endpoint from parts",
                (
                    AzureStorageService::File,
                    "AccountName=testaccount;EndpointSuffix=core.windows.net",
                ),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
        ]);

        // Development storage is only honoured for Blob; for ADLS the flag is
        // ignored and, with no other keys, the default config is produced.
        #[cfg(feature = "services-azdls")]
        test_cases.push((
            "azdls development storage",
            (AzureStorageService::Adls, "UseDevelopmentStorage=true"),
            Some(AzureStorageConfig::default()),
        ));

        for (name, (storage, conn_str), expected) in test_cases {
            let actual = azure_config_from_connection_string(conn_str, storage);

            if let Some(expected) = expected {
                assert_azure_storage_config_eq(&actual.expect(name), &expected, name);
            } else {
                assert!(actual.is_err(), "Expected error for case: {name}");
            }
        }
    }

    /// Checks account-name inference for known and unknown endpoint shapes.
    #[test]
    fn test_azure_account_name_from_endpoint() {
        let test_cases = vec![
            ("https://account.blob.core.windows.net", Some("account")),
            (
                "https://account.blob.core.usgovcloudapi.net",
                Some("account"),
            ),
            (
                "https://account.blob.core.chinacloudapi.cn",
                Some("account"),
            ),
            ("https://account.dfs.core.windows.net", Some("account")),
            ("https://account.blob.core.windows.net/", Some("account")),
            ("https://account.blob.unknown.suffix.com", None),
            ("http://blob.core.windows.net", None),
        ];
        for (endpoint, expected_account_name) in test_cases {
            let account_name = azure_account_name_from_endpoint(endpoint);
            assert_eq!(
                account_name,
                expected_account_name.map(|s| s.to_string()),
                "Endpoint: {endpoint}"
            );
        }
    }

    /// The `sig` query parameter must not survive into error context.
    #[test]
    fn test_azure_uri_context_removes_sig() {
        let uri: Uri = "https://foo.bar/path?foo=foo&sig=SENSITIVE&bar=bar"
            .parse()
            .unwrap();
        let expected = "https://foo.bar/path?foo=foo&bar=bar";
        assert_eq!(censor_sas_uri(&uri), expected);
    }

    /// Compares the `AzureStorageConfig` fields that connection-string parsing
    /// can populate (`account_name`, `account_key`, `endpoint`, `sas_token`),
    /// reporting the failing case `name` on mismatch.
    fn assert_azure_storage_config_eq(
        actual: &AzureStorageConfig,
        expected: &AzureStorageConfig,
        name: &str,
    ) {
        assert_eq!(
            actual.account_name, expected.account_name,
            "account_name mismatch: {name}"
        );
        assert_eq!(
            actual.account_key, expected.account_key,
            "account_key mismatch: {name}"
        );
        assert_eq!(
            actual.endpoint, expected.endpoint,
            "endpoint mismatch: {name}"
        );
        assert_eq!(
            actual.sas_token, expected.sas_token,
            "sas_token mismatch: {name}"
        );
    }
}