opendal_core/raw/oio/write/
multipart_write.rs1use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait MultipartWrite: Send + Sync + Unpin + 'static {
59 fn write_once(
65 &self,
66 size: u64,
67 body: Buffer,
68 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
69
70 fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;
78
79 fn write_part(
87 &self,
88 upload_id: &str,
89 part_number: usize,
90 size: u64,
91 body: Buffer,
92 ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;
93
94 fn complete_part(
97 &self,
98 upload_id: &str,
99 parts: &[MultipartPart],
100 ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
101
102 fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
104}
105
/// Result of uploading a single part of a multipart upload.
///
/// Values are produced by `MultipartWrite::write_part` and handed back,
/// collected in a slice, to `MultipartWrite::complete_part`.
///
/// Besides `Clone`, the type derives `Debug`, `PartialEq` and `Eq` so parts
/// can be logged and compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultipartPart {
    /// Number of this part, as passed to `write_part`.
    pub part_number: usize,
    /// ETag string reported for this part; this module never inspects it.
    pub etag: String,
    /// Optional checksum reported for this part; this module never inspects it.
    pub checksum: Option<String>,
}
122
123struct WriteInput<W: MultipartWrite> {
124 w: Arc<W>,
125 executor: Executor,
126 upload_id: Arc<String>,
127 part_number: usize,
128 bytes: Buffer,
129}
130
131pub struct MultipartWriter<W: MultipartWrite> {
134 w: Arc<W>,
135 executor: Executor,
136
137 upload_id: Option<Arc<String>>,
138 parts: Vec<MultipartPart>,
139 cache: Option<Buffer>,
140 next_part_number: usize,
141
142 tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
143}
144
145impl<W: MultipartWrite> MultipartWriter<W> {
149 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
151 let w = Arc::new(inner);
152 let executor = info.executor();
153 Self {
154 w,
155 executor: executor.clone(),
156 upload_id: None,
157 parts: Vec::new(),
158 cache: None,
159 next_part_number: 0,
160
161 tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
162 Box::pin({
163 async move {
164 let fut = input.w.write_part(
165 &input.upload_id,
166 input.part_number,
167 input.bytes.len() as u64,
168 input.bytes.clone(),
169 );
170 match input.executor.timeout() {
171 None => {
172 let result = fut.await;
173 (input, result)
174 }
175 Some(timeout) => {
176 let result = select! {
177 result = fut.fuse() => {
178 result
179 }
180 _ = timeout.fuse() => {
181 Err(Error::new(
182 ErrorKind::Unexpected, "write part timeout")
183 .with_context("upload_id", input.upload_id.to_string())
184 .with_context("part_number", input.part_number.to_string())
185 .set_temporary())
186 }
187 };
188 (input, result)
189 }
190 }
191 }
192 })
193 }),
194 }
195 }
196
197 fn fill_cache(&mut self, bs: Buffer) -> usize {
198 let size = bs.len();
199 assert!(self.cache.is_none());
200 self.cache = Some(bs);
201 size
202 }
203}
204
205impl<W> oio::Write for MultipartWriter<W>
206where
207 W: MultipartWrite,
208{
209 async fn write(&mut self, bs: Buffer) -> Result<()> {
210 let upload_id = match self.upload_id.clone() {
211 Some(v) => v,
212 None => {
213 if self.cache.is_none() {
215 self.fill_cache(bs);
216 return Ok(());
217 }
218
219 let upload_id = self.w.initiate_part().await?;
220 let upload_id = Arc::new(upload_id);
221 self.upload_id = Some(upload_id.clone());
222 upload_id
223 }
224 };
225
226 let bytes = self.cache.clone().expect("pending write must exist");
227 let part_number = self.next_part_number;
228
229 self.tasks
230 .execute(WriteInput {
231 w: self.w.clone(),
232 executor: self.executor.clone(),
233 upload_id: upload_id.clone(),
234 part_number,
235 bytes,
236 })
237 .await?;
238 self.cache = None;
239 self.next_part_number += 1;
240 self.fill_cache(bs);
241 Ok(())
242 }
243
244 async fn close(&mut self) -> Result<Metadata> {
245 let upload_id = match self.upload_id.clone() {
246 Some(v) => v,
247 None => {
248 let (size, body) = match self.cache.clone() {
249 Some(cache) => (cache.len(), cache),
250 None => (0, Buffer::new()),
251 };
252
253 let meta = self.w.write_once(size as u64, body).await?;
255 self.cache = None;
257 return Ok(meta);
258 }
259 };
260
261 if let Some(cache) = self.cache.clone() {
262 let part_number = self.next_part_number;
263
264 self.tasks
265 .execute(WriteInput {
266 w: self.w.clone(),
267 executor: self.executor.clone(),
268 upload_id: upload_id.clone(),
269 part_number,
270 bytes: cache,
271 })
272 .await?;
273 self.cache = None;
274 self.next_part_number += 1;
275 }
276
277 loop {
278 let Some(result) = self.tasks.next().await.transpose()? else {
279 break;
280 };
281 self.parts.push(result)
282 }
283
284 if self.parts.len() != self.next_part_number {
285 return Err(Error::new(
286 ErrorKind::Unexpected,
287 "multipart part numbers mismatch, please report bug to opendal",
288 )
289 .with_context("expected", self.next_part_number)
290 .with_context("actual", self.parts.len())
291 .with_context("upload_id", upload_id));
292 }
293 self.w.complete_part(&upload_id, &self.parts).await
294 }
295
296 async fn abort(&mut self) -> Result<()> {
297 let Some(upload_id) = self.upload_id.clone() else {
298 return Ok(());
299 };
300
301 self.tasks.clear();
302 self.cache = None;
303 self.w.abort_part(&upload_id).await?;
304 Ok(())
305 }
306}
307
#[cfg(test)]
mod tests {
    use mea::mutex::Mutex;
    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;
    use tokio::time::timeout;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory backend that records what the writer sends to it.
    struct TestWrite {
        /// Upload id handed out by `initiate_part` and checked on later calls.
        upload_id: String,
        /// Part numbers of all successfully written parts.
        part_numbers: Vec<usize>,
        /// Total number of bytes accepted so far.
        length: u64,
        /// Body received through `write_once`, if any.
        content: Option<Buffer>,
    }

    impl TestWrite {
        /// Creates a shared backend with a random upload id and no data.
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self {
                upload_id: uuid::Uuid::new_v4().to_string(),
                part_numbers: Vec::new(),
                length: 0,
                content: None,
            }))
        }
    }

    impl MultipartWrite for Arc<Mutex<TestWrite>> {
        /// Stores the body, failing with a temporary error one time in ten.
        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
            sleep(Duration::from_nanos(50)).await;

            let inject_failure = thread_rng().gen_bool(1.0 / 10.0);
            if inject_failure {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let mut state = self.lock().await;
            state.length = size;
            state.content = Some(body);
            Ok(Metadata::default().with_content_length(size))
        }

        /// Returns the fixed upload id of this backend.
        async fn initiate_part(&self) -> Result<String> {
            let id = self.lock().await.upload_id.clone();
            Ok(id)
        }

        /// Records the part, failing with a temporary error one time in ten.
        async fn write_part(
            &self,
            upload_id: &str,
            part_number: usize,
            size: u64,
            _: Buffer,
        ) -> Result<MultipartPart> {
            {
                let state = self.lock().await;
                assert_eq!(upload_id, state.upload_id);
            }

            // The lock is released while "uploading".
            sleep(Duration::from_nanos(50)).await;

            let inject_failure = thread_rng().gen_bool(1.0 / 10.0);
            if inject_failure {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            {
                let mut state = self.lock().await;
                state.part_numbers.push(part_number);
                state.length += size;
            }

            let part = MultipartPart {
                part_number,
                etag: "etag".to_string(),
                checksum: None,
            };
            Ok(part)
        }

        /// Checks the upload id and part count, then reports the total length.
        async fn complete_part(
            &self,
            upload_id: &str,
            parts: &[MultipartPart],
        ) -> Result<Metadata> {
            let state = self.lock().await;
            assert_eq!(upload_id, state.upload_id);
            assert_eq!(parts.len(), state.part_numbers.len());

            Ok(Metadata::default().with_content_length(state.length))
        }

        /// Checks the upload id; nothing else is cleaned up.
        async fn abort_part(&self, upload_id: &str) -> Result<()> {
            let state = self.lock().await;
            assert_eq!(upload_id, state.upload_id);

            Ok(())
        }
    }

    /// Executor wrapper whose `timeout` always yields a short random delay.
    struct TimeoutExecutor {
        exec: Arc<dyn Execute>,
    }

    impl TimeoutExecutor {
        /// Wraps the inner executor of `Executor::new()`.
        pub fn new() -> Self {
            let exec = Executor::new().into_inner();
            Self { exec }
        }
    }

    impl Execute for TimeoutExecutor {
        fn execute(&self, f: BoxedStaticFuture<()>) {
            self.exec.execute(f)
        }

        /// Returns a sleep of 0 to 99 nanoseconds.
        fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
            let nanos = thread_rng().gen_range(0..100);
            let delay = tokio::time::sleep(Duration::from_nanos(nanos));
            Some(Box::pin(delay))
        }
    }

    /// Writes 1000 random chunks while parts randomly fail or time out and
    /// every call is retried until it succeeds.
    #[tokio::test]
    async fn test_multipart_upload_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let info = Arc::new(AccessorInfo::default());
        info.update_executor(|_| Executor::with(TimeoutExecutor::new()));

        let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // Retry on both writer errors and outer timeouts.
            loop {
                let attempt =
                    timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await;
                if matches!(attempt, Ok(Ok(_))) {
                    break;
                }
            }
        }

        // Same retry policy for close.
        loop {
            let attempt = timeout(Duration::from_nanos(10), w.close()).await;
            if matches!(attempt, Ok(Ok(_))) {
                break;
            }
        }

        let actual_parts: Vec<_> = w.parts.iter().map(|part| part.part_number).collect();
        let expected_parts: Vec<_> = (0..1000).collect();
        assert_eq!(actual_parts, expected_parts);

        let actual_size = w.w.lock().await.length;
        assert_eq!(actual_size, total_size);
    }

    /// A single small write must end up in `write_once`, even when that call
    /// has to be retried.
    #[tokio::test]
    async fn test_multipart_writer_with_retry_when_write_once_error() {
        let mut rng = thread_rng();

        for _ in 0..100 {
            let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
            let size = rng.gen_range(1..1024);
            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            while w.write(bs.clone().into()).await.is_err() {}

            while w.close().await.is_err() {}

            let inner = w.w.lock().await;
            assert_eq!(inner.length, size as u64);
            assert!(inner.content.is_some());
            assert_eq!(inner.content.clone().unwrap().to_bytes(), bs);
        }
    }
}