1use std::collections::HashMap;
25
26use reqsign::{AzureStorageConfig, AzureStorageCredential};
27
28use crate::{Error, ErrorKind, Result};
29
30pub(crate) fn azure_config_from_connection_string(
40 conn_str: &str,
41 storage: AzureStorageService,
42) -> Result<AzureStorageConfig> {
43 let key_values = parse_connection_string(conn_str)?;
44
45 if storage == AzureStorageService::Blob {
46 if let Some(development_config) = collect_blob_development_config(&key_values, &storage) {
48 return Ok(AzureStorageConfig {
49 account_name: Some(development_config.account_name),
50 account_key: Some(development_config.account_key),
51 endpoint: Some(development_config.endpoint),
52 ..Default::default()
53 });
54 }
55 }
56
57 let mut config = AzureStorageConfig {
58 account_name: key_values.get("AccountName").cloned(),
59 endpoint: collect_endpoint(&key_values, &storage)?,
60 ..Default::default()
61 };
62
63 if let Some(creds) = collect_credentials(&key_values) {
64 set_credentials(&mut config, creds);
65 };
66
67 Ok(config)
68}
69
/// The Azure Storage service a connection string is being resolved for.
///
/// It selects which explicit endpoint field is read and which service label is used when
/// building an endpoint from `AccountName` and `EndpointSuffix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AzureStorageService {
    /// Blob storage: `BlobEndpoint`, or `<account>.blob.<suffix>`.
    Blob,

    /// File storage: `FileEndpoint`, or `<account>.file.<suffix>`.
    File,

    /// Data Lake storage: the endpoint is always built as `<account>.dfs.<suffix>`.
    Adls,
}
84
/// Infers the storage account name from an Azure Storage endpoint URL.
///
/// Recognises `[http[s]://]<account>.<service>.<suffix>[/path]`, where `<suffix>` is one of
/// the known Azure endpoint suffixes. Any path after the host (including a bare trailing `/`)
/// is ignored.
///
/// Returns `None` for any other shape, such as custom domains, IP-based emulator endpoints,
/// unknown suffixes or an empty account label.
pub(crate) fn azure_account_name_from_endpoint(endpoint: &str) -> Option<String> {
    const KNOWN_ENDPOINT_SUFFIXES: &[&str] = &[
        "core.windows.net",
        "core.usgovcloudapi.net",
        "core.chinacloudapi.cn",
    ];

    let endpoint: &str = endpoint
        .strip_prefix("http://")
        .or_else(|| endpoint.strip_prefix("https://"))
        .unwrap_or(endpoint);

    // Only the host identifies the account. Drop any path so that e.g.
    // `https://acct.blob.core.windows.net/container` still resolves to `acct`.
    let host = endpoint.split('/').next().unwrap_or(endpoint);

    let (account_name, service_endpoint) = host.split_once('.')?;
    let (_storage_service, endpoint_suffix) = service_endpoint.split_once('.')?;

    // An empty first label (`https://.blob.core.windows.net`) is not a valid account name.
    if account_name.is_empty() || !KNOWN_ENDPOINT_SUFFIXES.contains(&endpoint_suffix) {
        return None;
    }

    Some(account_name.to_string())
}
107
108fn parse_connection_string(conn_str: &str) -> Result<HashMap<String, String>> {
111 conn_str
112 .trim()
113 .replace("\n", "")
114 .split(';')
115 .filter(|&field| !field.is_empty())
116 .map(|field| {
117 let (key, value) = field.trim().split_once('=').ok_or(Error::new(
118 ErrorKind::ConfigInvalid,
119 format!(
120 "Invalid connection string, expected '=' in field: {}",
121 field
122 ),
123 ))?;
124 Ok((key.to_string(), value.to_string()))
125 })
126 .collect()
127}
128
129fn collect_blob_development_config(
130 key_values: &HashMap<String, String>,
131 storage: &AzureStorageService,
132) -> Option<DevelopmentStorageConfig> {
133 debug_assert!(
134 storage == &AzureStorageService::Blob,
135 "Azurite Development Storage only supports Blob Storage"
136 );
137
138 const AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME: &str = "devstoreaccount1";
140 const AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY: &str =
141 "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
142
143 const AZURITE_DEFAULT_BLOB_URI: &str = "http://127.0.0.1:10000";
144
145 if key_values.get("UseDevelopmentStorage") != Some(&"true".to_string()) {
146 return None; }
148
149 let account_name = key_values
150 .get("AccountName")
151 .cloned()
152 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME.to_string());
153 let account_key = key_values
154 .get("AccountKey")
155 .cloned()
156 .unwrap_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY.to_string());
157 let development_proxy_uri = key_values
158 .get("DevelopmentStorageProxyUri")
159 .cloned()
160 .unwrap_or(AZURITE_DEFAULT_BLOB_URI.to_string());
161
162 Some(DevelopmentStorageConfig {
163 endpoint: format!("{development_proxy_uri}/{account_name}"),
164 account_name,
165 account_key,
166 })
167}
168
/// Settings resolved for the Azurite development storage emulator.
struct DevelopmentStorageConfig {
    /// Blob endpoint: the emulator/proxy base URI followed by `/<account_name>`.
    endpoint: String,
    /// Storage account name used against the emulator.
    account_name: String,
    /// Shared key for `account_name`.
    account_key: String,
}
175
176fn collect_endpoint(
181 key_values: &HashMap<String, String>,
182 storage: &AzureStorageService,
183) -> Result<Option<String>> {
184 match storage {
185 AzureStorageService::Blob => collect_or_build_endpoint(key_values, "BlobEndpoint", "blob"),
186 AzureStorageService::File => collect_or_build_endpoint(key_values, "FileEndpoint", "file"),
187 AzureStorageService::Adls => {
188 if let Some(dfs_endpoint) = collect_endpoint_from_parts(key_values, "dfs")? {
191 Ok(Some(dfs_endpoint.clone()))
192 } else {
193 Ok(None)
194 }
195 }
196 }
197}
198
199fn collect_credentials(key_values: &HashMap<String, String>) -> Option<AzureStorageCredential> {
200 if let Some(sas_token) = key_values.get("SharedAccessSignature") {
201 Some(AzureStorageCredential::SharedAccessSignature(
202 sas_token.clone(),
203 ))
204 } else if let (Some(account_name), Some(account_key)) =
205 (key_values.get("AccountName"), key_values.get("AccountKey"))
206 {
207 Some(AzureStorageCredential::SharedKey(
208 account_name.clone(),
209 account_key.clone(),
210 ))
211 } else {
212 None
217 }
218}
219
220fn set_credentials(config: &mut AzureStorageConfig, creds: AzureStorageCredential) {
221 match creds {
222 AzureStorageCredential::SharedAccessSignature(sas_token) => {
223 config.sas_token = Some(sas_token);
224 }
225 AzureStorageCredential::SharedKey(account_name, account_key) => {
226 config.account_name = Some(account_name);
227 config.account_key = Some(account_key);
228 }
229 AzureStorageCredential::BearerToken(_, _) => {
230 }
232 }
233}
234
235fn collect_or_build_endpoint(
236 key_values: &HashMap<String, String>,
237 endpoint_key: &str,
238 service_name: &str,
239) -> Result<Option<String>> {
240 if let Some(endpoint) = key_values.get(endpoint_key) {
241 Ok(Some(endpoint.clone()))
242 } else if let Some(built_endpoint) = collect_endpoint_from_parts(key_values, service_name)? {
243 Ok(Some(built_endpoint.clone()))
244 } else {
245 Ok(None)
246 }
247}
248
249fn collect_endpoint_from_parts(
250 key_values: &HashMap<String, String>,
251 storage_endpoint_name: &str,
252) -> Result<Option<String>> {
253 let (account_name, endpoint_suffix) = match (
254 key_values.get("AccountName"),
255 key_values.get("EndpointSuffix"),
256 ) {
257 (Some(name), Some(suffix)) => (name, suffix),
258 _ => return Ok(None), };
260
261 let protocol = key_values
262 .get("DefaultEndpointsProtocol")
263 .map(String::as_str)
264 .unwrap_or("https"); if protocol != "http" && protocol != "https" {
266 return Err(Error::new(
267 ErrorKind::ConfigInvalid,
268 format!("Invalid DefaultEndpointsProtocol: {}", protocol),
269 ));
270 }
271
272 Ok(Some(format!(
273 "{protocol}://{account_name}.{storage_endpoint_name}.{endpoint_suffix}"
274 )))
275}
276
// Unit tests for connection-string parsing and account-name inference.
#[cfg(test)]
mod tests {
    use reqsign::AzureStorageConfig;

    use super::{
        azure_account_name_from_endpoint, azure_config_from_connection_string, AzureStorageService,
    };

    /// Table-driven test for `azure_config_from_connection_string`.
    ///
    /// Each case is `(name, (storage, conn_str), expected)`. `Some(config)` means the call must
    /// succeed and match `config` on the fields compared by `assert_azure_storage_config_eq`.
    /// `None` means the call must return an error.
    #[test]
    fn test_azure_config_from_connection_string() {
        let test_cases = vec![
            ("minimal fields",
                (AzureStorageService::Blob, "BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("basic creds and blob endpoint",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("SAS token",
                (AzureStorageService::Blob, "SharedAccessSignature=blablabla"),
                Some(AzureStorageConfig{
                    sas_token: Some("blablabla".to_string()),
                    ..Default::default()
                }),
            ),
            ("endpoint from parts",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("endpoint from parts and no protocol",
                (AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
                Some(AzureStorageConfig{
                    // No DefaultEndpointsProtocol given: the builder falls back to https.
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            ("adls endpoint from parts",
                (AzureStorageService::Adls, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.dfs.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            ("file endpoint from field",
                (AzureStorageService::File, "FileEndpoint=https://testaccount.file.core.windows.net"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                })
            ),
            ("file endpoint from parts",
                (AzureStorageService::File, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
                Some(AzureStorageConfig{
                    account_name: Some("testaccount".to_string()),
                    endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            // When both a SAS token and an account key are present, only the SAS token is
            // used as the credential; account_key stays unset.
            ("prefers sas over key",
                (AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;SharedAccessSignature=sas_token"),
                Some(AzureStorageConfig{
                    sas_token: Some("sas_token".to_string()),
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // Azurite defaults: well-known account name/key and the default Blob address.
            ("development storage",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true",),
                Some(AzureStorageConfig{
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    endpoint: Some("http://127.0.0.1:10000/devstoreaccount1".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom account values",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;AccountName=myAccount;AccountKey=myKey"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:10000/myAccount".to_string()),
                    account_name: Some("myAccount".to_string()),
                    account_key: Some("myKey".to_string()),
                    ..Default::default()
                }),
            ),
            ("development storage with custom uri",
                (AzureStorageService::Blob, "UseDevelopmentStorage=true;DevelopmentStorageProxyUri=http://127.0.0.1:12345"),
                Some(AzureStorageConfig {
                    endpoint: Some("http://127.0.0.1:12345/devstoreaccount1".to_string()),
                    account_name: Some("devstoreaccount1".to_string()),
                    account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
                    ..Default::default()
                }),
            ),
            ("unknown key is ignored",
                (AzureStorageService::Blob, "SomeUnknownKey=123;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
                Some(AzureStorageConfig{
                    endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
                    ..Default::default()
                }),
            ),
            ("leading and trailing `;`",
                (AzureStorageService::Blob, ";AccountName=testaccount;"),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    ..Default::default()
                }),
            ),
            // Multi-line connection string: newlines and surrounding whitespace are ignored.
            ("line breaks",
                (AzureStorageService::Blob, r#"
                AccountName=testaccount;
                AccountKey=testkey;
                EndpointSuffix=core.windows.net;
                DefaultEndpointsProtocol=https"#),
                Some(AzureStorageConfig {
                    account_name: Some("testaccount".to_string()),
                    account_key: Some("testkey".to_string()),
                    endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
                    ..Default::default()
                }),
            ),
            // Error cases (expected == None).
            ("missing equals",
                (AzureStorageService::Blob, "AccountNameexample;AccountKey=example;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https",),
                None, ),
            ("with invalid protocol",
                (AzureStorageService::Blob, "DefaultEndpointsProtocol=ftp;AccountName=example;EndpointSuffix=core.windows.net",),
                None, ),
            // UseDevelopmentStorage only applies to Blob; for ADLS it yields an empty config.
            ("azdls development storage",
                (AzureStorageService::Adls, "UseDevelopmentStorage=true"),
                Some(AzureStorageConfig::default()), ),
        ];

        for (name, (storage, conn_str), expected) in test_cases {
            let actual = azure_config_from_connection_string(conn_str, storage);

            if let Some(expected) = expected {
                assert_azure_storage_config_eq(&actual.expect(name), &expected, name);
            } else {
                assert!(actual.is_err(), "Expected error for case: {}", name);
            }
        }
    }

    /// Checks account-name inference against the known Azure endpoint suffixes, a trailing
    /// slash, an unknown suffix, and an endpoint that lacks an account label.
    #[test]
    fn test_azure_account_name_from_endpoint() {
        let test_cases = vec![
            ("https://account.blob.core.windows.net", Some("account")),
            (
                "https://account.blob.core.usgovcloudapi.net",
                Some("account"),
            ),
            (
                "https://account.blob.core.chinacloudapi.cn",
                Some("account"),
            ),
            ("https://account.dfs.core.windows.net", Some("account")),
            ("https://account.blob.core.windows.net/", Some("account")),
            ("https://account.blob.unknown.suffix.com", None),
            ("http://blob.core.windows.net", None),
        ];
        for (endpoint, expected_account_name) in test_cases {
            let account_name = azure_account_name_from_endpoint(endpoint);
            assert_eq!(
                account_name,
                expected_account_name.map(|s| s.to_string()),
                "Endpoint: {}",
                endpoint
            );
        }
    }

    /// Compares only `account_name`, `account_key`, `endpoint` and `sas_token`. Other
    /// `AzureStorageConfig` fields are not checked; `name` labels the failing case.
    fn assert_azure_storage_config_eq(
        actual: &AzureStorageConfig,
        expected: &AzureStorageConfig,
        name: &str,
    ) {
        assert_eq!(
            actual.account_name, expected.account_name,
            "account_name mismatch: {}",
            name
        );
        assert_eq!(
            actual.account_key, expected.account_key,
            "account_key mismatch: {}",
            name
        );
        assert_eq!(
            actual.endpoint, expected.endpoint,
            "endpoint mismatch: {}",
            name
        );
        assert_eq!(
            actual.sas_token, expected.sas_token,
            "sas_token mismatch: {}",
            name
        );
    }
}