opendal_core/raw/oio/write/
multipart_write.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
/// MultipartWrite is used to implement [`oio::Write`] based on multipart
/// uploads. By implementing MultipartWrite, services don't need to
/// care about the details of uploading parts.
///
/// # Architecture
///
/// The architecture after adopting [`MultipartWrite`]:
///
/// - Services impl `MultipartWrite`
/// - `MultipartWriter` impl `Write`
/// - Expose `MultipartWriter` as `Accessor::Writer`
///
/// # Notes
///
/// `MultipartWrite` has a one-shot optimization when `write` has been called only once:
///
/// ```no_build
/// w.write(bs).await?;
/// w.close().await?;
/// ```
///
/// We will use `write_once` instead of starting a new multipart upload.
///
/// # Requirements
///
/// Services that implement `MultipartWrite` must fulfill the following requirements:
///
/// - Must be able to upload a [`Buffer`] either as a whole object (`write_once`)
///   or as a single part of an upload (`write_part`).
/// - Must start a multipart upload in `initiate_part` and identify it by the
///   returned upload id, which is passed back to every later call.
/// - Part numbers are assigned by the caller `MultipartWriter`, starting from 0,
///   instead of by services.
/// - Complete the upload from the list of [`MultipartPart`] passed to `complete_part`.
pub trait MultipartWrite: Send + Sync + Unpin + 'static {
    /// write_once is used to write the data to underlying storage at once.
    ///
    /// MultipartWriter will call this API when:
    ///
    /// - All the data has been written to the buffer and we can perform the upload at once.
    ///
    /// `size` is the length of `body` in bytes.
    fn write_once(
        &self,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// initiate_part will start a multipart upload and return the upload id.
    ///
    /// MultipartWriter will call this when:
    ///
    /// - the total size of data is unknown.
    /// - the total size of data is known, but the size of current write
    ///   is less than the total size.
    fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;

    /// write_part will write a part of the data and returns the result
    /// [`MultipartPart`].
    ///
    /// MultipartWriter will call this API and stores the result in
    /// order.
    ///
    /// - upload_id is the id returned by `initiate_part`.
    /// - part_number is the index of the part, starting from 0.
    /// - size is the length of `body` in bytes.
    fn write_part(
        &self,
        upload_id: &str,
        part_number: usize,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;

    /// complete_part will complete the multipart upload to build the final
    /// file.
    ///
    /// `parts` contains the results of all `write_part` calls made for
    /// `upload_id`.
    fn complete_part(
        &self,
        upload_id: &str,
        parts: &[MultipartPart],
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// abort_part will cancel the multipart upload and purge all data.
    fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
}
105
/// The result of [`MultipartWrite::write_part`].
///
/// Service implementations should convert `MultipartPart` into their own
/// representation.
///
/// - `part_number` is the index of the part, starting from 0.
/// - `etag` is the `ETag` of the part.
/// - `checksum` is the optional checksum of the part.
///
/// The type derives `Debug`, `PartialEq` and `Eq` so parts can be logged and
/// compared directly (for example in `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultipartPart {
    /// The number of the part, starting from 0.
    pub part_number: usize,
    /// The etag of the part.
    pub etag: String,
    /// The checksum of the part.
    pub checksum: Option<String>,
}
122
/// Input of a single part-upload task submitted to `ConcurrentTasks`.
///
/// The task factory built in `MultipartWriter::new` returns this value
/// together with the result of the `write_part` call.
struct WriteInput<W: MultipartWrite> {
    /// Shared handle to the service implementation that uploads the part.
    w: Arc<W>,
    /// Executor that is asked for an optional timeout future for the task.
    executor: Executor,
    /// Upload id of the multipart upload this part belongs to.
    upload_id: Arc<String>,
    /// Index of this part, assigned by the writer.
    part_number: usize,
    /// Payload of this part.
    bytes: Buffer,
}
130
/// MultipartWriter will implement [`oio::Write`] based on multipart
/// uploads.
///
/// The writer always keeps the most recent write in `cache`. The cached
/// buffer is submitted as a part when the next write arrives or when the
/// writer is closed; if no upload was ever started, `close` sends it with
/// `write_once` instead.
pub struct MultipartWriter<W: MultipartWrite> {
    /// Shared handle to the service implementation.
    w: Arc<W>,
    /// Executor cloned into every submitted task.
    executor: Executor,

    /// Upload id returned by `initiate_part`; `None` until an upload is started.
    upload_id: Option<Arc<String>>,
    /// Parts collected from finished tasks while closing.
    parts: Vec<MultipartPart>,
    /// The latest written buffer that has not been submitted yet.
    cache: Option<Buffer>,
    /// Part number that will be assigned to the next submitted part.
    next_part_number: usize,

    /// Concurrent part-upload tasks.
    tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
}
144
145/// # Safety
146///
147/// wasm32 is a special target that we only have one event-loop for this state.
148impl<W: MultipartWrite> MultipartWriter<W> {
149    /// Create a new MultipartWriter.
150    pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
151        let w = Arc::new(inner);
152        let executor = info.executor();
153        Self {
154            w,
155            executor: executor.clone(),
156            upload_id: None,
157            parts: Vec::new(),
158            cache: None,
159            next_part_number: 0,
160
161            tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
162                Box::pin({
163                    async move {
164                        let fut = input.w.write_part(
165                            &input.upload_id,
166                            input.part_number,
167                            input.bytes.len() as u64,
168                            input.bytes.clone(),
169                        );
170                        match input.executor.timeout() {
171                            None => {
172                                let result = fut.await;
173                                (input, result)
174                            }
175                            Some(timeout) => {
176                                let result = select! {
177                                    result = fut.fuse() => {
178                                        result
179                                    }
180                                    _ = timeout.fuse() => {
181                                        Err(Error::new(
182                                            ErrorKind::Unexpected, "write part timeout")
183                                                .with_context("upload_id", input.upload_id.to_string())
184                                                .with_context("part_number", input.part_number.to_string())
185                                                .set_temporary())
186                                    }
187                                };
188                                (input, result)
189                            }
190                        }
191                    }
192                })
193            }),
194        }
195    }
196
197    fn fill_cache(&mut self, bs: Buffer) -> usize {
198        let size = bs.len();
199        assert!(self.cache.is_none());
200        self.cache = Some(bs);
201        size
202    }
203}
204
205impl<W> oio::Write for MultipartWriter<W>
206where
207    W: MultipartWrite,
208{
209    async fn write(&mut self, bs: Buffer) -> Result<()> {
210        let upload_id = match self.upload_id.clone() {
211            Some(v) => v,
212            None => {
213                // Fill cache with the first write.
214                if self.cache.is_none() {
215                    self.fill_cache(bs);
216                    return Ok(());
217                }
218
219                let upload_id = self.w.initiate_part().await?;
220                let upload_id = Arc::new(upload_id);
221                self.upload_id = Some(upload_id.clone());
222                upload_id
223            }
224        };
225
226        let bytes = self.cache.clone().expect("pending write must exist");
227        let part_number = self.next_part_number;
228
229        self.tasks
230            .execute(WriteInput {
231                w: self.w.clone(),
232                executor: self.executor.clone(),
233                upload_id: upload_id.clone(),
234                part_number,
235                bytes,
236            })
237            .await?;
238        self.cache = None;
239        self.next_part_number += 1;
240        self.fill_cache(bs);
241        Ok(())
242    }
243
244    async fn close(&mut self) -> Result<Metadata> {
245        let upload_id = match self.upload_id.clone() {
246            Some(v) => v,
247            None => {
248                let (size, body) = match self.cache.clone() {
249                    Some(cache) => (cache.len(), cache),
250                    None => (0, Buffer::new()),
251                };
252
253                // Call write_once if there is no upload_id.
254                let meta = self.w.write_once(size as u64, body).await?;
255                // make sure to clear the cache only after write_once succeeds; otherwise, retries may fail.
256                self.cache = None;
257                return Ok(meta);
258            }
259        };
260
261        if let Some(cache) = self.cache.clone() {
262            let part_number = self.next_part_number;
263
264            self.tasks
265                .execute(WriteInput {
266                    w: self.w.clone(),
267                    executor: self.executor.clone(),
268                    upload_id: upload_id.clone(),
269                    part_number,
270                    bytes: cache,
271                })
272                .await?;
273            self.cache = None;
274            self.next_part_number += 1;
275        }
276
277        loop {
278            let Some(result) = self.tasks.next().await.transpose()? else {
279                break;
280            };
281            self.parts.push(result)
282        }
283
284        if self.parts.len() != self.next_part_number {
285            return Err(Error::new(
286                ErrorKind::Unexpected,
287                "multipart part numbers mismatch, please report bug to opendal",
288            )
289            .with_context("expected", self.next_part_number)
290            .with_context("actual", self.parts.len())
291            .with_context("upload_id", upload_id));
292        }
293        self.w.complete_part(&upload_id, &self.parts).await
294    }
295
296    async fn abort(&mut self) -> Result<()> {
297        let Some(upload_id) = self.upload_id.clone() else {
298            return Ok(());
299        };
300
301        self.tasks.clear();
302        self.cache = None;
303        self.w.abort_part(&upload_id).await?;
304        Ok(())
305    }
306}
307
308#[cfg(test)]
309mod tests {
310    use mea::mutex::Mutex;
311    use pretty_assertions::assert_eq;
312    use rand::Rng;
313    use rand::RngCore;
314    use rand::thread_rng;
315    use tokio::time::sleep;
316    use tokio::time::timeout;
317
318    use super::*;
319    use crate::raw::oio::Write;
320
    /// In-memory `MultipartWrite` mock used by the tests below.
    struct TestWrite {
        /// Upload id returned by `initiate_part` and checked by the other calls.
        upload_id: String,
        /// Part numbers recorded by successful `write_part` calls.
        part_numbers: Vec<usize>,
        /// Number of bytes reported as written so far.
        length: u64,
        /// Body stored by a successful `write_once` call.
        content: Option<Buffer>,
    }
327
328    impl TestWrite {
329        pub fn new() -> Arc<Mutex<Self>> {
330            let v = Self {
331                upload_id: uuid::Uuid::new_v4().to_string(),
332                part_numbers: Vec::new(),
333                length: 0,
334                content: None,
335            };
336
337            Arc::new(Mutex::new(v))
338        }
339    }
340
341    impl MultipartWrite for Arc<Mutex<TestWrite>> {
342        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
343            sleep(Duration::from_nanos(50)).await;
344
345            if thread_rng().gen_bool(1.0 / 10.0) {
346                return Err(
347                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
348                );
349            }
350
351            let mut this = self.lock().await;
352            this.length = size;
353            this.content = Some(body);
354            Ok(Metadata::default().with_content_length(size))
355        }
356
357        async fn initiate_part(&self) -> Result<String> {
358            let upload_id = self.lock().await.upload_id.clone();
359            Ok(upload_id)
360        }
361
362        async fn write_part(
363            &self,
364            upload_id: &str,
365            part_number: usize,
366            size: u64,
367            _: Buffer,
368        ) -> Result<MultipartPart> {
369            {
370                let test = self.lock().await;
371                assert_eq!(upload_id, test.upload_id);
372            }
373
374            // Add an async sleep here to enforce some pending.
375            sleep(Duration::from_nanos(50)).await;
376
377            // We will have 10% percent rate for write part to fail.
378            if thread_rng().gen_bool(1.0 / 10.0) {
379                return Err(
380                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
381                );
382            }
383
384            {
385                let mut test = self.lock().await;
386                test.part_numbers.push(part_number);
387                test.length += size;
388            }
389
390            Ok(MultipartPart {
391                part_number,
392                etag: "etag".to_string(),
393                checksum: None,
394            })
395        }
396
397        async fn complete_part(
398            &self,
399            upload_id: &str,
400            parts: &[MultipartPart],
401        ) -> Result<Metadata> {
402            let test = self.lock().await;
403            assert_eq!(upload_id, test.upload_id);
404            assert_eq!(parts.len(), test.part_numbers.len());
405
406            Ok(Metadata::default().with_content_length(test.length))
407        }
408
409        async fn abort_part(&self, upload_id: &str) -> Result<()> {
410            let test = self.lock().await;
411            assert_eq!(upload_id, test.upload_id);
412
413            Ok(())
414        }
415    }
416
    /// Test executor that wraps another executor and adds a timeout to tasks.
    struct TimeoutExecutor {
        /// Executor that actually runs the submitted futures.
        exec: Arc<dyn Execute>,
    }
420
421    impl TimeoutExecutor {
422        pub fn new() -> Self {
423            Self {
424                exec: Executor::new().into_inner(),
425            }
426        }
427    }
428
429    impl Execute for TimeoutExecutor {
430        fn execute(&self, f: BoxedStaticFuture<()>) {
431            self.exec.execute(f)
432        }
433
434        fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
435            let time = thread_rng().gen_range(0..100);
436            Some(Box::pin(tokio::time::sleep(Duration::from_nanos(time))))
437        }
438    }
439
440    #[tokio::test]
441    async fn test_multipart_upload_writer_with_concurrent_errors() {
442        let mut rng = thread_rng();
443
444        let info = Arc::new(AccessorInfo::default());
445        info.update_executor(|_| Executor::with(TimeoutExecutor::new()));
446
447        let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
448        let mut total_size = 0u64;
449
450        for _ in 0..1000 {
451            let size = rng.gen_range(1..1024);
452            total_size += size as u64;
453
454            let mut bs = vec![0; size];
455            rng.fill_bytes(&mut bs);
456
457            loop {
458                match timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await {
459                    Ok(Ok(_)) => break,
460                    Ok(Err(_)) => continue,
461                    Err(_) => {
462                        continue;
463                    }
464                }
465            }
466        }
467
468        loop {
469            match timeout(Duration::from_nanos(10), w.close()).await {
470                Ok(Ok(_)) => break,
471                Ok(Err(_)) => continue,
472                Err(_) => {
473                    continue;
474                }
475            }
476        }
477
478        let actual_parts: Vec<_> = w.parts.into_iter().map(|v| v.part_number).collect();
479        let expected_parts: Vec<_> = (0..1000).collect();
480        assert_eq!(actual_parts, expected_parts);
481
482        let actual_size = w.w.lock().await.length;
483        assert_eq!(actual_size, total_size);
484    }
485
486    #[tokio::test]
487    async fn test_multipart_writer_with_retry_when_write_once_error() {
488        let mut rng = thread_rng();
489
490        for _ in 0..100 {
491            let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
492            let size = rng.gen_range(1..1024);
493            let mut bs = vec![0; size];
494            rng.fill_bytes(&mut bs);
495
496            loop {
497                match w.write(bs.clone().into()).await {
498                    Ok(_) => break,
499                    Err(_) => continue,
500                }
501            }
502
503            loop {
504                match w.close().await {
505                    Ok(_) => break,
506                    Err(_) => continue,
507                }
508            }
509
510            let inner = w.w.lock().await;
511            assert_eq!(inner.length, size as u64);
512            assert!(inner.content.is_some());
513            assert_eq!(inner.content.clone().unwrap().to_bytes(), bs);
514        }
515    }
516}