opendal_core/raw/oio/write/multipart_write.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27/// MultipartWrite is used to implement [`oio::Write`] based on multipart
28/// uploads. By implementing MultipartWrite, services don't need to
29/// care about the details of uploading parts.
30///
31/// # Architecture
32///
33/// The architecture after adopting [`MultipartWrite`]:
34///
35/// - Services impl `MultipartWrite`
36/// - `MultipartWriter` impl `Write`
37/// - Expose `MultipartWriter` as `Accessor::Writer`
38///
39/// # Notes
40///
41/// `MultipartWrite` has an oneshot optimization when `write` has been called only once:
42///
43/// ```no_build
44/// w.write(bs).await?;
45/// w.close().await?;
46/// ```
47///
48/// We will use `write_once` instead of starting a new multipart upload.
49///
50/// # Requirements
51///
52/// Services that implement `MultipartWrite` must fulfill the following requirements:
53///
54/// - Must be a http service that could accept `AsyncBody`.
55/// - Don't need initialization before writing.
56/// - Upload ID is generated by the service via `initiate_part`.
57/// - Complete upload by an ordered part list.
58pub trait MultipartWrite: Send + Sync + Unpin + 'static {
59    /// write_once is used to write the data to underlying storage at once.
60    ///
61    /// MultipartWriter will call this API when:
62    ///
63    /// - All the data has been written to the buffer and we can perform the upload at once.
64    fn write_once(
65        &self,
66        size: u64,
67        body: Buffer,
68    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
69
70    /// initiate_part will call start a multipart upload and return the upload id.
71    ///
72    /// MultipartWriter will call this when:
73    ///
74    /// - the total size of data is unknown.
75    /// - the total size of data is known, but the size of current write
76    ///   is less than the total size.
77    fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;
78
79    /// write_part will write a part of the data and returns the result
80    /// [`MultipartPart`].
81    ///
82    /// MultipartWriter will call this API and stores the result in
83    /// order.
84    ///
85    /// - part_number is the index of the part, starting from 0.
86    fn write_part(
87        &self,
88        upload_id: &str,
89        part_number: usize,
90        size: u64,
91        body: Buffer,
92    ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;
93
94    /// complete_part will complete the multipart upload to build the final
95    /// file.
96    fn complete_part(
97        &self,
98        upload_id: &str,
99        parts: &[MultipartPart],
100    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;
101
102    /// abort_part will cancel the multipart upload and purge all data.
103    fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
104}
105
/// The result of [`MultipartWrite::write_part`].
///
/// Service implementations should convert `MultipartPart` into their own
/// representation when completing the upload.
///
/// - `part_number` is the index of the part, starting from 0.
/// - `etag` is the `ETag` of the part.
/// - `checksum` is the optional checksum of the part.
/// - `size` is the optional size of the part in bytes.
#[derive(Clone, Debug)]
pub struct MultipartPart {
    /// The number of the part, starting from 0.
    pub part_number: usize,
    /// The etag of the part.
    pub etag: String,
    /// The checksum of the part.
    pub checksum: Option<String>,
    /// The size of the part in bytes.
    pub size: Option<u64>,
}
124
125struct WriteInput<W: MultipartWrite> {
126    w: Arc<W>,
127    executor: Executor,
128    upload_id: Arc<String>,
129    part_number: usize,
130    bytes: Buffer,
131}
132
133/// MultipartWriter will implement [`oio::Write`] based on multipart
134/// uploads.
135pub struct MultipartWriter<W: MultipartWrite> {
136    w: Arc<W>,
137    executor: Executor,
138
139    upload_id: Option<Arc<String>>,
140    parts: Vec<MultipartPart>,
141    cache: Option<Buffer>,
142    next_part_number: usize,
143
144    tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
145}
146
147/// # Safety
148///
149/// wasm32 is a special target that we only have one event-loop for this state.
150impl<W: MultipartWrite> MultipartWriter<W> {
151    /// Create a new MultipartWriter.
152    pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
153        let w = Arc::new(inner);
154        let executor = info.executor();
155        Self {
156            w,
157            executor: executor.clone(),
158            upload_id: None,
159            parts: Vec::new(),
160            cache: None,
161            next_part_number: 0,
162
163            tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
164                Box::pin({
165                    async move {
166                        let fut = input.w.write_part(
167                            &input.upload_id,
168                            input.part_number,
169                            input.bytes.len() as u64,
170                            input.bytes.clone(),
171                        );
172                        match input.executor.timeout() {
173                            None => {
174                                let result = fut.await;
175                                (input, result)
176                            }
177                            Some(timeout) => {
178                                let result = select! {
179                                    result = fut.fuse() => {
180                                        result
181                                    }
182                                    _ = timeout.fuse() => {
183                                        Err(Error::new(
184                                            ErrorKind::Unexpected, "write part timeout")
185                                                .with_context("upload_id", input.upload_id.to_string())
186                                                .with_context("part_number", input.part_number.to_string())
187                                                .set_temporary())
188                                    }
189                                };
190                                (input, result)
191                            }
192                        }
193                    }
194                })
195            }),
196        }
197    }
198
199    fn fill_cache(&mut self, bs: Buffer) -> usize {
200        let size = bs.len();
201        assert!(self.cache.is_none());
202        self.cache = Some(bs);
203        size
204    }
205}
206
207impl<W> oio::Write for MultipartWriter<W>
208where
209    W: MultipartWrite,
210{
211    async fn write(&mut self, bs: Buffer) -> Result<()> {
212        let upload_id = match self.upload_id.clone() {
213            Some(v) => v,
214            None => {
215                // Fill cache with the first write.
216                if self.cache.is_none() {
217                    self.fill_cache(bs);
218                    return Ok(());
219                }
220
221                let upload_id = self.w.initiate_part().await?;
222                let upload_id = Arc::new(upload_id);
223                self.upload_id = Some(upload_id.clone());
224                upload_id
225            }
226        };
227
228        let bytes = self.cache.clone().expect("pending write must exist");
229        let part_number = self.next_part_number;
230
231        self.tasks
232            .execute(WriteInput {
233                w: self.w.clone(),
234                executor: self.executor.clone(),
235                upload_id: upload_id.clone(),
236                part_number,
237                bytes,
238            })
239            .await?;
240        self.cache = None;
241        self.next_part_number += 1;
242        self.fill_cache(bs);
243        Ok(())
244    }
245
246    async fn close(&mut self) -> Result<Metadata> {
247        let upload_id = match self.upload_id.clone() {
248            Some(v) => v,
249            None => {
250                let (size, body) = match self.cache.clone() {
251                    Some(cache) => (cache.len(), cache),
252                    None => (0, Buffer::new()),
253                };
254
255                // Call write_once if there is no upload_id.
256                let meta = self.w.write_once(size as u64, body).await?;
257                // make sure to clear the cache only after write_once succeeds; otherwise, retries may fail.
258                self.cache = None;
259                return Ok(meta);
260            }
261        };
262
263        if let Some(cache) = self.cache.clone() {
264            let part_number = self.next_part_number;
265
266            self.tasks
267                .execute(WriteInput {
268                    w: self.w.clone(),
269                    executor: self.executor.clone(),
270                    upload_id: upload_id.clone(),
271                    part_number,
272                    bytes: cache,
273                })
274                .await?;
275            self.cache = None;
276            self.next_part_number += 1;
277        }
278
279        loop {
280            let Some(result) = self.tasks.next().await.transpose()? else {
281                break;
282            };
283            self.parts.push(result)
284        }
285
286        if self.parts.len() != self.next_part_number {
287            return Err(Error::new(
288                ErrorKind::Unexpected,
289                "multipart part numbers mismatch, please report bug to opendal",
290            )
291            .with_context("expected", self.next_part_number)
292            .with_context("actual", self.parts.len())
293            .with_context("upload_id", upload_id));
294        }
295        self.w.complete_part(&upload_id, &self.parts).await
296    }
297
298    async fn abort(&mut self) -> Result<()> {
299        let Some(upload_id) = self.upload_id.clone() else {
300            return Ok(());
301        };
302
303        self.tasks.clear();
304        self.cache = None;
305        self.w.abort_part(&upload_id).await?;
306        Ok(())
307    }
308}
309
#[cfg(test)]
mod tests {
    //! Tests for `MultipartWriter` against an in-memory `MultipartWrite` mock
    //! that fails randomly with temporary errors.

    use mea::mutex::Mutex;
    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;
    use tokio::time::timeout;

    use super::*;
    use crate::raw::oio::Write;

    /// Shared state recorded by the mock service.
    struct TestWrite {
        // Upload id handed out by `initiate_part` and checked by later calls.
        upload_id: String,
        // Part numbers of every successful `write_part`, in completion order.
        part_numbers: Vec<usize>,
        // Total bytes accepted (multipart), or the size given to `write_once`.
        length: u64,
        // Body received by `write_once`, if it succeeded.
        content: Option<Buffer>,
    }

    impl TestWrite {
        /// Create a fresh mock with a random upload id, wrapped for sharing.
        pub fn new() -> Arc<Mutex<Self>> {
            let v = Self {
                upload_id: uuid::Uuid::new_v4().to_string(),
                part_numbers: Vec::new(),
                length: 0,
                content: None,
            };

            Arc::new(Mutex::new(v))
        }
    }

    impl MultipartWrite for Arc<Mutex<TestWrite>> {
        /// Fails about 10% of the time with a temporary error; otherwise records the body.
        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
            sleep(Duration::from_nanos(50)).await;

            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let mut this = self.lock().await;
            this.length = size;
            this.content = Some(body);
            Ok(Metadata::default().with_content_length(size))
        }

        /// Always returns the mock's fixed upload id.
        async fn initiate_part(&self) -> Result<String> {
            let upload_id = self.lock().await.upload_id.clone();
            Ok(upload_id)
        }

        /// Records a part only on success. There is no `.await` between updating
        /// the state and returning `Ok`, so a cancelled call never records a part.
        async fn write_part(
            &self,
            upload_id: &str,
            part_number: usize,
            size: u64,
            _: Buffer,
        ) -> Result<MultipartPart> {
            {
                let test = self.lock().await;
                assert_eq!(upload_id, test.upload_id);
            }

            // Add an async sleep here to enforce some pending.
            sleep(Duration::from_nanos(50)).await;

            // Fail about 10% of the time with a temporary error.
            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            {
                let mut test = self.lock().await;
                test.part_numbers.push(part_number);
                test.length += size;
            }

            Ok(MultipartPart {
                part_number,
                etag: "etag".to_string(),
                checksum: None,
                size: None,
            })
        }

        /// Checks that exactly the successfully written parts are being completed.
        async fn complete_part(
            &self,
            upload_id: &str,
            parts: &[MultipartPart],
        ) -> Result<Metadata> {
            let test = self.lock().await;
            assert_eq!(upload_id, test.upload_id);
            assert_eq!(parts.len(), test.part_numbers.len());

            Ok(Metadata::default().with_content_length(test.length))
        }

        async fn abort_part(&self, upload_id: &str) -> Result<()> {
            let test = self.lock().await;
            assert_eq!(upload_id, test.upload_id);

            Ok(())
        }
    }

    /// Executor that delegates spawning to the default executor but hands out a
    /// random 0-99ns timeout, so part uploads time out frequently.
    struct TimeoutExecutor {
        exec: Arc<dyn Execute>,
    }

    impl TimeoutExecutor {
        pub fn new() -> Self {
            Self {
                exec: Executor::new().into_inner(),
            }
        }
    }

    impl Execute for TimeoutExecutor {
        fn execute(&self, f: BoxedStaticFuture<()>) {
            self.exec.execute(f)
        }

        fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
            let time = thread_rng().gen_range(0..100);
            Some(Box::pin(tokio::time::sleep(Duration::from_nanos(time))))
        }
    }

    /// 1000 writes with random part failures, random part timeouts and every
    /// `write`/`close` attempt cancelled after 10ns. Retrying until success must
    /// still yield parts 0..1000 in order and the exact total length.
    #[tokio::test]
    async fn test_multipart_upload_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let info = Arc::new(AccessorInfo::default());
        info.update_executor(|_| Executor::with(TimeoutExecutor::new()));

        let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // Retry the same data on error or cancellation.
            loop {
                match timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await {
                    Ok(Ok(_)) => break,
                    Ok(Err(_)) => continue,
                    Err(_) => {
                        continue;
                    }
                }
            }
        }

        loop {
            match timeout(Duration::from_nanos(10), w.close()).await {
                Ok(Ok(_)) => break,
                Ok(Err(_)) => continue,
                Err(_) => {
                    continue;
                }
            }
        }

        let actual_parts: Vec<_> = w.parts.into_iter().map(|v| v.part_number).collect();
        let expected_parts: Vec<_> = (0..1000).collect();
        assert_eq!(actual_parts, expected_parts);

        let actual_size = w.w.lock().await.length;
        assert_eq!(actual_size, total_size);
    }

    /// A single write followed by `close` goes through `write_once`. Retrying
    /// `close` after random `write_once` failures must still upload the data.
    #[tokio::test]
    async fn test_multipart_writer_with_retry_when_write_once_error() {
        let mut rng = thread_rng();

        for _ in 0..100 {
            let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
            let size = rng.gen_range(1..1024);
            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            loop {
                match w.write(bs.clone().into()).await {
                    Ok(_) => break,
                    Err(_) => continue,
                }
            }

            loop {
                match w.close().await {
                    Ok(_) => break,
                    Err(_) => continue,
                }
            }

            let inner = w.w.lock().await;
            assert_eq!(inner.length, size as u64);
            assert!(inner.content.is_some());
            assert_eq!(inner.content.clone().unwrap().to_bytes(), bs);
        }
    }
}
517        }
518    }
519}