opendal_core/raw/oio/write/
multipart_write.rs1use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait MultipartWrite: Send + Sync + Unpin + 'static {
59 fn write_once(
65 &self,
66 size: u64,
67 body: Buffer,
68 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
69
70 fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;
78
79 fn write_part(
87 &self,
88 upload_id: &str,
89 part_number: usize,
90 size: u64,
91 body: Buffer,
92 ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;
93
94 fn complete_part(
97 &self,
98 upload_id: &str,
99 parts: &[MultipartPart],
100 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
101
102 fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
104}
105
/// Receipt for one uploaded part: returned by `MultipartWrite::write_part` and handed back,
/// together with all other parts, to `MultipartWrite::complete_part`.
///
/// Derives `Debug` so parts can be logged, and `PartialEq`/`Eq` so they can be compared
/// (e.g. in tests).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MultipartPart {
    /// Index of the part; `MultipartWriter` numbers parts starting from 0.
    pub part_number: usize,
    /// ETag of the uploaded part, as returned by the backend's `write_part`.
    pub etag: String,
    /// Checksum of the part, when the backend provides one.
    pub checksum: Option<String>,
    /// Size of the part in bytes, when known.
    pub size: Option<u64>,
}
124
125struct WriteInput<W: MultipartWrite> {
126 w: Arc<W>,
127 executor: Executor,
128 upload_id: Arc<String>,
129 part_number: usize,
130 bytes: Buffer,
131}
132
133pub struct MultipartWriter<W: MultipartWrite> {
136 w: Arc<W>,
137 executor: Executor,
138
139 upload_id: Option<Arc<String>>,
140 parts: Vec<MultipartPart>,
141 cache: Option<Buffer>,
142 next_part_number: usize,
143
144 tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
145}
146
147impl<W: MultipartWrite> MultipartWriter<W> {
151 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
153 let w = Arc::new(inner);
154 let executor = info.executor();
155 Self {
156 w,
157 executor: executor.clone(),
158 upload_id: None,
159 parts: Vec::new(),
160 cache: None,
161 next_part_number: 0,
162
163 tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
164 Box::pin({
165 async move {
166 let fut = input.w.write_part(
167 &input.upload_id,
168 input.part_number,
169 input.bytes.len() as u64,
170 input.bytes.clone(),
171 );
172 match input.executor.timeout() {
173 None => {
174 let result = fut.await;
175 (input, result)
176 }
177 Some(timeout) => {
178 let result = select! {
179 result = fut.fuse() => {
180 result
181 }
182 _ = timeout.fuse() => {
183 Err(Error::new(
184 ErrorKind::Unexpected, "write part timeout")
185 .with_context("upload_id", input.upload_id.to_string())
186 .with_context("part_number", input.part_number.to_string())
187 .set_temporary())
188 }
189 };
190 (input, result)
191 }
192 }
193 }
194 })
195 }),
196 }
197 }
198
199 fn fill_cache(&mut self, bs: Buffer) -> usize {
200 let size = bs.len();
201 assert!(self.cache.is_none());
202 self.cache = Some(bs);
203 size
204 }
205}
206
207impl<W> oio::Write for MultipartWriter<W>
208where
209 W: MultipartWrite,
210{
211 async fn write(&mut self, bs: Buffer) -> Result<()> {
212 let upload_id = match self.upload_id.clone() {
213 Some(v) => v,
214 None => {
215 if self.cache.is_none() {
217 self.fill_cache(bs);
218 return Ok(());
219 }
220
221 let upload_id = self.w.initiate_part().await?;
222 let upload_id = Arc::new(upload_id);
223 self.upload_id = Some(upload_id.clone());
224 upload_id
225 }
226 };
227
228 let bytes = self.cache.clone().expect("pending write must exist");
229 let part_number = self.next_part_number;
230
231 self.tasks
232 .execute(WriteInput {
233 w: self.w.clone(),
234 executor: self.executor.clone(),
235 upload_id: upload_id.clone(),
236 part_number,
237 bytes,
238 })
239 .await?;
240 self.cache = None;
241 self.next_part_number += 1;
242 self.fill_cache(bs);
243 Ok(())
244 }
245
246 async fn close(&mut self) -> Result<Metadata> {
247 let upload_id = match self.upload_id.clone() {
248 Some(v) => v,
249 None => {
250 let (size, body) = match self.cache.clone() {
251 Some(cache) => (cache.len(), cache),
252 None => (0, Buffer::new()),
253 };
254
255 let meta = self.w.write_once(size as u64, body).await?;
257 self.cache = None;
259 return Ok(meta);
260 }
261 };
262
263 if let Some(cache) = self.cache.clone() {
264 let part_number = self.next_part_number;
265
266 self.tasks
267 .execute(WriteInput {
268 w: self.w.clone(),
269 executor: self.executor.clone(),
270 upload_id: upload_id.clone(),
271 part_number,
272 bytes: cache,
273 })
274 .await?;
275 self.cache = None;
276 self.next_part_number += 1;
277 }
278
279 loop {
280 let Some(result) = self.tasks.next().await.transpose()? else {
281 break;
282 };
283 self.parts.push(result)
284 }
285
286 if self.parts.len() != self.next_part_number {
287 return Err(Error::new(
288 ErrorKind::Unexpected,
289 "multipart part numbers mismatch, please report bug to opendal",
290 )
291 .with_context("expected", self.next_part_number)
292 .with_context("actual", self.parts.len())
293 .with_context("upload_id", upload_id));
294 }
295 self.w.complete_part(&upload_id, &self.parts).await
296 }
297
298 async fn abort(&mut self) -> Result<()> {
299 let Some(upload_id) = self.upload_id.clone() else {
300 return Ok(());
301 };
302
303 self.tasks.clear();
304 self.cache = None;
305 self.w.abort_part(&upload_id).await?;
306 Ok(())
307 }
308}
309
#[cfg(test)]
mod tests {
    use mea::mutex::Mutex;
    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;
    use tokio::time::timeout;

    use super::*;
    use crate::raw::oio::Write;

    /// Fake backend that records which parts arrived and how many bytes it received.
    struct TestWrite {
        upload_id: String,
        part_numbers: Vec<usize>,
        length: u64,
        content: Option<Buffer>,
    }

    impl TestWrite {
        /// Creates a fresh backend with a random upload id, ready to be shared.
        pub fn new() -> Arc<Mutex<Self>> {
            let state = Self {
                upload_id: uuid::Uuid::new_v4().to_string(),
                part_numbers: Vec::new(),
                length: 0,
                content: None,
            };
            Arc::new(Mutex::new(state))
        }
    }

    /// Fails roughly one call in ten with a temporary error, simulating a flaky service.
    fn crazy_monkey() -> Result<()> {
        if thread_rng().gen_bool(1.0 / 10.0) {
            return Err(
                Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
            );
        }
        Ok(())
    }

    impl MultipartWrite for Arc<Mutex<TestWrite>> {
        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
            sleep(Duration::from_nanos(50)).await;
            crazy_monkey()?;

            let mut state = self.lock().await;
            state.length = size;
            state.content = Some(body);
            Ok(Metadata::default().with_content_length(size))
        }

        async fn initiate_part(&self) -> Result<String> {
            let id = self.lock().await.upload_id.clone();
            Ok(id)
        }

        async fn write_part(
            &self,
            upload_id: &str,
            part_number: usize,
            size: u64,
            _: Buffer,
        ) -> Result<MultipartPart> {
            {
                let state = self.lock().await;
                assert_eq!(upload_id, state.upload_id);
            }

            sleep(Duration::from_nanos(50)).await;
            crazy_monkey()?;

            {
                let mut state = self.lock().await;
                state.part_numbers.push(part_number);
                state.length += size;
            }

            Ok(MultipartPart {
                part_number,
                etag: "etag".to_string(),
                checksum: None,
                size: None,
            })
        }

        async fn complete_part(
            &self,
            upload_id: &str,
            parts: &[MultipartPart],
        ) -> Result<Metadata> {
            let state = self.lock().await;
            assert_eq!(upload_id, state.upload_id);
            assert_eq!(parts.len(), state.part_numbers.len());

            Ok(Metadata::default().with_content_length(state.length))
        }

        async fn abort_part(&self, upload_id: &str) -> Result<()> {
            let state = self.lock().await;
            assert_eq!(upload_id, state.upload_id);

            Ok(())
        }
    }

    /// Executor that spawns through `Executor::new()`'s executor but hands out a random
    /// 0–99 ns timeout, so part uploads time out frequently.
    struct TimeoutExecutor {
        exec: Arc<dyn Execute>,
    }

    impl TimeoutExecutor {
        pub fn new() -> Self {
            let exec = Executor::new().into_inner();
            Self { exec }
        }
    }

    impl Execute for TimeoutExecutor {
        fn execute(&self, f: BoxedStaticFuture<()>) {
            self.exec.execute(f)
        }

        fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
            let nanos = thread_rng().gen_range(0..100);
            Some(Box::pin(sleep(Duration::from_nanos(nanos))))
        }
    }

    #[tokio::test]
    async fn test_multipart_upload_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let info = Arc::new(AccessorInfo::default());
        info.update_executor(|_| Executor::with(TimeoutExecutor::new()));

        let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // Retry until one attempt both beats the tiny deadline and succeeds; failed or
            // cancelled attempts must leave the writer in a retryable state.
            while !matches!(
                timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await,
                Ok(Ok(_))
            ) {}
        }

        while !matches!(
            timeout(Duration::from_nanos(10), w.close()).await,
            Ok(Ok(_))
        ) {}

        let actual_parts: Vec<usize> = w.parts.iter().map(|part| part.part_number).collect();
        assert_eq!(actual_parts, (0..1000).collect::<Vec<usize>>());

        let actual_size = w.w.lock().await.length;
        assert_eq!(actual_size, total_size);
    }

    #[tokio::test]
    async fn test_multipart_writer_with_retry_when_write_once_error() {
        let mut rng = thread_rng();

        for _ in 0..100 {
            let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
            let size = rng.gen_range(1..1024);
            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // A single write only fills the cache; `close` then goes through `write_once`,
            // which fails at random and must be retryable.
            while w.write(bs.clone().into()).await.is_err() {}
            while w.close().await.is_err() {}

            let state = w.w.lock().await;
            assert_eq!(state.length, size as u64);
            assert!(state.content.is_some());
            assert_eq!(state.content.clone().unwrap().to_bytes(), bs);
        }
    }
}