1use std::sync::Arc;
19
20use futures::select;
21use futures::Future;
22use futures::FutureExt;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait MultipartWrite: Send + Sync + Unpin + 'static {
59 fn write_once(
65 &self,
66 size: u64,
67 body: Buffer,
68 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
69
70 fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;
78
79 fn write_part(
87 &self,
88 upload_id: &str,
89 part_number: usize,
90 size: u64,
91 body: Buffer,
92 ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;
93
94 fn complete_part(
97 &self,
98 upload_id: &str,
99 parts: &[MultipartPart],
100 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
101
102 fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
104}
105
/// Description of a single uploaded part, as returned by
/// `MultipartWrite::write_part` and later passed to
/// `MultipartWrite::complete_part`.
///
/// `Debug` is derived in addition to `Clone` so that parts can be logged and
/// embedded in the `Debug` output of types that contain them.
#[derive(Clone, Debug)]
pub struct MultipartPart {
    /// Index of the part within its upload. The writer in this file assigns
    /// these sequentially, starting at 0.
    pub part_number: usize,
    /// Entity tag reported for the part.
    pub etag: String,
    /// Optional checksum reported for the part.
    pub checksum: Option<String>,
}
122
123struct WriteInput<W: MultipartWrite> {
124 w: Arc<W>,
125 executor: Executor,
126 upload_id: Arc<String>,
127 part_number: usize,
128 bytes: Buffer,
129}
130
131pub struct MultipartWriter<W: MultipartWrite> {
134 w: Arc<W>,
135 executor: Executor,
136
137 upload_id: Option<Arc<String>>,
138 parts: Vec<MultipartPart>,
139 cache: Option<Buffer>,
140 next_part_number: usize,
141
142 tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
143}
144
145impl<W: MultipartWrite> MultipartWriter<W> {
149 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
151 let w = Arc::new(inner);
152 let executor = info.executor();
153 Self {
154 w,
155 executor: executor.clone(),
156 upload_id: None,
157 parts: Vec::new(),
158 cache: None,
159 next_part_number: 0,
160
161 tasks: ConcurrentTasks::new(executor, concurrent, |input| {
162 Box::pin({
163 async move {
164 let fut = input.w.write_part(
165 &input.upload_id,
166 input.part_number,
167 input.bytes.len() as u64,
168 input.bytes.clone(),
169 );
170 match input.executor.timeout() {
171 None => {
172 let result = fut.await;
173 (input, result)
174 }
175 Some(timeout) => {
176 let result = select! {
177 result = fut.fuse() => {
178 result
179 }
180 _ = timeout.fuse() => {
181 Err(Error::new(
182 ErrorKind::Unexpected, "write part timeout")
183 .with_context("upload_id", input.upload_id.to_string())
184 .with_context("part_number", input.part_number.to_string())
185 .set_temporary())
186 }
187 };
188 (input, result)
189 }
190 }
191 }
192 })
193 }),
194 }
195 }
196
197 fn fill_cache(&mut self, bs: Buffer) -> usize {
198 let size = bs.len();
199 assert!(self.cache.is_none());
200 self.cache = Some(bs);
201 size
202 }
203}
204
205impl<W> oio::Write for MultipartWriter<W>
206where
207 W: MultipartWrite,
208{
209 async fn write(&mut self, bs: Buffer) -> Result<()> {
210 let upload_id = match self.upload_id.clone() {
211 Some(v) => v,
212 None => {
213 if self.cache.is_none() {
215 self.fill_cache(bs);
216 return Ok(());
217 }
218
219 let upload_id = self.w.initiate_part().await?;
220 let upload_id = Arc::new(upload_id);
221 self.upload_id = Some(upload_id.clone());
222 upload_id
223 }
224 };
225
226 let bytes = self.cache.clone().expect("pending write must exist");
227 let part_number = self.next_part_number;
228
229 self.tasks
230 .execute(WriteInput {
231 w: self.w.clone(),
232 executor: self.executor.clone(),
233 upload_id: upload_id.clone(),
234 part_number,
235 bytes,
236 })
237 .await?;
238 self.cache = None;
239 self.next_part_number += 1;
240 self.fill_cache(bs);
241 Ok(())
242 }
243
244 async fn close(&mut self) -> Result<Metadata> {
245 let upload_id = match self.upload_id.clone() {
246 Some(v) => v,
247 None => {
248 let (size, body) = match self.cache.clone() {
249 Some(cache) => (cache.len(), cache),
250 None => (0, Buffer::new()),
251 };
252
253 let meta = self.w.write_once(size as u64, body).await?;
255 self.cache = None;
257 return Ok(meta);
258 }
259 };
260
261 if let Some(cache) = self.cache.clone() {
262 let part_number = self.next_part_number;
263
264 self.tasks
265 .execute(WriteInput {
266 w: self.w.clone(),
267 executor: self.executor.clone(),
268 upload_id: upload_id.clone(),
269 part_number,
270 bytes: cache,
271 })
272 .await?;
273 self.cache = None;
274 self.next_part_number += 1;
275 }
276
277 loop {
278 let Some(result) = self.tasks.next().await.transpose()? else {
279 break;
280 };
281 self.parts.push(result)
282 }
283
284 if self.parts.len() != self.next_part_number {
285 return Err(Error::new(
286 ErrorKind::Unexpected,
287 "multipart part numbers mismatch, please report bug to opendal",
288 )
289 .with_context("expected", self.next_part_number)
290 .with_context("actual", self.parts.len())
291 .with_context("upload_id", upload_id));
292 }
293 self.w.complete_part(&upload_id, &self.parts).await
294 }
295
296 async fn abort(&mut self) -> Result<()> {
297 let Some(upload_id) = self.upload_id.clone() else {
298 return Ok(());
299 };
300
301 self.tasks.clear();
302 self.cache = None;
303 self.w.abort_part(&upload_id).await?;
304 Ok(())
305 }
306}
307
308#[cfg(test)]
309mod tests {
310 use std::time::Duration;
311
312 use pretty_assertions::assert_eq;
313 use rand::thread_rng;
314 use rand::Rng;
315 use rand::RngCore;
316 use tokio::sync::Mutex;
317 use tokio::time::sleep;
318 use tokio::time::timeout;
319
320 use super::*;
321 use crate::raw::oio::Write;
322
323 struct TestWrite {
324 upload_id: String,
325 part_numbers: Vec<usize>,
326 length: u64,
327 content: Option<Buffer>,
328 }
329
330 impl TestWrite {
331 pub fn new() -> Arc<Mutex<Self>> {
332 let v = Self {
333 upload_id: uuid::Uuid::new_v4().to_string(),
334 part_numbers: Vec::new(),
335 length: 0,
336 content: None,
337 };
338
339 Arc::new(Mutex::new(v))
340 }
341 }
342
343 impl MultipartWrite for Arc<Mutex<TestWrite>> {
344 async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
345 sleep(Duration::from_nanos(50)).await;
346
347 if thread_rng().gen_bool(1.0 / 10.0) {
348 return Err(
349 Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
350 );
351 }
352
353 let mut this = self.lock().await;
354 this.length = size;
355 this.content = Some(body);
356 Ok(Metadata::default().with_content_length(size))
357 }
358
359 async fn initiate_part(&self) -> Result<String> {
360 let upload_id = self.lock().await.upload_id.clone();
361 Ok(upload_id)
362 }
363
364 async fn write_part(
365 &self,
366 upload_id: &str,
367 part_number: usize,
368 size: u64,
369 _: Buffer,
370 ) -> Result<MultipartPart> {
371 {
372 let test = self.lock().await;
373 assert_eq!(upload_id, test.upload_id);
374 }
375
376 sleep(Duration::from_nanos(50)).await;
378
379 if thread_rng().gen_bool(1.0 / 10.0) {
381 return Err(
382 Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
383 );
384 }
385
386 {
387 let mut test = self.lock().await;
388 test.part_numbers.push(part_number);
389 test.length += size;
390 }
391
392 Ok(MultipartPart {
393 part_number,
394 etag: "etag".to_string(),
395 checksum: None,
396 })
397 }
398
399 async fn complete_part(
400 &self,
401 upload_id: &str,
402 parts: &[MultipartPart],
403 ) -> Result<Metadata> {
404 let test = self.lock().await;
405 assert_eq!(upload_id, test.upload_id);
406 assert_eq!(parts.len(), test.part_numbers.len());
407
408 Ok(Metadata::default().with_content_length(test.length))
409 }
410
411 async fn abort_part(&self, upload_id: &str) -> Result<()> {
412 let test = self.lock().await;
413 assert_eq!(upload_id, test.upload_id);
414
415 Ok(())
416 }
417 }
418
419 struct TimeoutExecutor {
420 exec: Arc<dyn Execute>,
421 }
422
423 impl TimeoutExecutor {
424 pub fn new() -> Self {
425 Self {
426 exec: Executor::new().into_inner(),
427 }
428 }
429 }
430
431 impl Execute for TimeoutExecutor {
432 fn execute(&self, f: BoxedStaticFuture<()>) {
433 self.exec.execute(f)
434 }
435
436 fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
437 let time = thread_rng().gen_range(0..100);
438 Some(Box::pin(tokio::time::sleep(Duration::from_nanos(time))))
439 }
440 }
441
442 #[tokio::test]
443 async fn test_multipart_upload_writer_with_concurrent_errors() {
444 let mut rng = thread_rng();
445
446 let info = Arc::new(AccessorInfo::default());
447 info.update_executor(|_| Executor::with(TimeoutExecutor::new()));
448
449 let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
450 let mut total_size = 0u64;
451
452 for _ in 0..1000 {
453 let size = rng.gen_range(1..1024);
454 total_size += size as u64;
455
456 let mut bs = vec![0; size];
457 rng.fill_bytes(&mut bs);
458
459 loop {
460 match timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await {
461 Ok(Ok(_)) => break,
462 Ok(Err(_)) => continue,
463 Err(_) => {
464 continue;
465 }
466 }
467 }
468 }
469
470 loop {
471 match timeout(Duration::from_nanos(10), w.close()).await {
472 Ok(Ok(_)) => break,
473 Ok(Err(_)) => continue,
474 Err(_) => {
475 continue;
476 }
477 }
478 }
479
480 let actual_parts: Vec<_> = w.parts.into_iter().map(|v| v.part_number).collect();
481 let expected_parts: Vec<_> = (0..1000).collect();
482 assert_eq!(actual_parts, expected_parts);
483
484 let actual_size = w.w.lock().await.length;
485 assert_eq!(actual_size, total_size);
486 }
487
488 #[tokio::test]
489 async fn test_multipart_writer_with_retry_when_write_once_error() {
490 let mut rng = thread_rng();
491
492 for _ in 0..100 {
493 let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
494 let size = rng.gen_range(1..1024);
495 let mut bs = vec![0; size];
496 rng.fill_bytes(&mut bs);
497
498 loop {
499 match w.write(bs.clone().into()).await {
500 Ok(_) => break,
501 Err(_) => continue,
502 }
503 }
504
505 loop {
506 match w.close().await {
507 Ok(_) => break,
508 Err(_) => continue,
509 }
510 }
511
512 let inner = w.w.lock().await;
513 assert_eq!(inner.length, size as u64);
514 assert!(inner.content.is_some());
515 assert_eq!(inner.content.clone().unwrap().to_bytes(), bs);
516 }
517 }
518}