opendal/raw/oio/write/multipart_write.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use futures::select;
use futures::Future;
use futures::FutureExt;

use crate::raw::*;
use crate::*;

/// MultipartWrite is used to implement [`oio::Write`] based on multipart
/// uploads. By implementing MultipartWrite, services don't need to
/// care about the details of uploading parts.
///
/// # Architecture
///
/// The architecture after adopting [`MultipartWrite`]:
///
/// - Services impl `MultipartWrite`
/// - `MultipartWriter` impl `Write`
/// - Expose `MultipartWriter` as `Accessor::Writer`
///
/// # Notes
///
/// `MultipartWrite` has a oneshot optimization when `write` has been called only once:
///
/// ```no_build
/// w.write(bs).await?;
/// w.close().await?;
/// ```
///
/// We will use `write_once` instead of starting a new multipart upload.
///
/// # Requirements
///
/// Services that implement `MultipartWrite` must fulfill the following requirements:
///
/// - Must be an HTTP service that can accept `AsyncBody`.
/// - Must support initiating a multipart upload and returning an upload id.
/// - Part numbers are generated by the caller of `MultipartWrite` instead of by the service.
/// - The upload is completed with an ordered list of parts.
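///
/// # Example
///
/// A minimal sketch of an implementation, assuming a hypothetical `ExampleCore`
/// HTTP client whose helpers (`put`, `initiate_multipart_upload`, `upload_part`,
/// `complete_multipart_upload`, `abort_multipart_upload`) stand in for the real
/// service calls:
///
/// ```ignore
/// struct ExampleWriter {
///     core: Arc<ExampleCore>, // hypothetical service client
///     path: String,
/// }
///
/// impl MultipartWrite for ExampleWriter {
///     async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
///         // A single request: no multipart bookkeeping needed.
///         self.core.put(&self.path, size, body).await
///     }
///
///     async fn initiate_part(&self) -> Result<String> {
///         // Ask the service for an upload id.
///         self.core.initiate_multipart_upload(&self.path).await
///     }
///
///     async fn write_part(
///         &self,
///         upload_id: &str,
///         part_number: usize,
///         size: u64,
///         body: Buffer,
///     ) -> Result<MultipartPart> {
///         // Many services number parts from 1, so shift the 0-based index here.
///         let etag = self
///             .core
///             .upload_part(&self.path, upload_id, part_number + 1, size, body)
///             .await?;
///         Ok(MultipartPart { part_number, etag, checksum: None })
///     }
///
///     async fn complete_part(&self, upload_id: &str, parts: &[MultipartPart]) -> Result<Metadata> {
///         self.core.complete_multipart_upload(&self.path, upload_id, parts).await
///     }
///
///     async fn abort_part(&self, upload_id: &str) -> Result<()> {
///         self.core.abort_multipart_upload(&self.path, upload_id).await
///     }
/// }
/// ```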
pub trait MultipartWrite: Send + Sync + Unpin + 'static {
    /// write_once is used to write the data to the underlying storage at once.
    ///
    /// MultipartWriter will call this API when:
    ///
    /// - All the data has been written to the buffer and we can perform the upload at once.
    fn write_once(
        &self,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// initiate_part will start a multipart upload and return the upload id.
    ///
    /// MultipartWriter will call this when:
    ///
    /// - the total size of data is unknown.
    /// - the total size of data is known, but the size of the current write
    ///   is less than the total size.
    fn initiate_part(&self) -> impl Future<Output = Result<String>> + MaybeSend;

    /// write_part will write a part of the data and return the resulting
    /// [`MultipartPart`].
    ///
    /// MultipartWriter will call this API and store the results in
    /// order.
    ///
    /// - part_number is the index of the part, starting from 0.
    fn write_part(
        &self,
        upload_id: &str,
        part_number: usize,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<MultipartPart>> + MaybeSend;

    /// complete_part will complete the multipart upload to build the final
    /// file.
    fn complete_part(
        &self,
        upload_id: &str,
        parts: &[MultipartPart],
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// abort_part will cancel the multipart upload and purge all data.
    fn abort_part(&self, upload_id: &str) -> impl Future<Output = Result<()>> + MaybeSend;
}

/// The result of [`MultipartWrite::write_part`].
///
/// Services should convert `MultipartPart` into their own representation.
///
/// - `part_number` is the index of the part, starting from 0.
/// - `etag` is the `ETag` of the part.
/// - `checksum` is the optional checksum of the part.
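///
/// As a hedged sketch, a service might map it into its own request type like
/// this; the `ServicePart` type and the 1-based numbering are assumptions, not
/// part of this crate:
///
/// ```ignore
/// // Hypothetical service-side representation of a completed part.
/// struct ServicePart {
///     part_number: usize,
///     etag: String,
/// }
///
/// fn to_service_part(p: &MultipartPart) -> ServicePart {
///     ServicePart {
///         // Shift to 1-based numbering if the service requires it.
///         part_number: p.part_number + 1,
///         etag: p.etag.clone(),
///     }
/// }
/// ```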
#[derive(Clone)]
pub struct MultipartPart {
    /// The number of the part, starting from 0.
    pub part_number: usize,
    /// The etag of the part.
    pub etag: String,
    /// The checksum of the part.
    pub checksum: Option<String>,
}

struct WriteInput<W: MultipartWrite> {
    w: Arc<W>,
    executor: Executor,
    upload_id: Arc<String>,
    part_number: usize,
    bytes: Buffer,
}

/// MultipartWriter will implement [`oio::Write`] based on multipart
/// uploads.
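///
/// A hedged sketch of how a service backend might construct it, where
/// `ExampleWriter` is a hypothetical `MultipartWrite` implementation and `8`
/// is an arbitrary concurrency level:
///
/// ```ignore
/// let writer = MultipartWriter::new(info.clone(), ExampleWriter { core, path }, 8);
/// ```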
pub struct MultipartWriter<W: MultipartWrite> {
    w: Arc<W>,
    executor: Executor,

    upload_id: Option<Arc<String>>,
    parts: Vec<MultipartPart>,
    cache: Option<Buffer>,
    next_part_number: usize,

    tasks: ConcurrentTasks<WriteInput<W>, MultipartPart>,
}

/// # Safety
///
/// wasm32 is a special target where we only have one event loop for this state.
impl<W: MultipartWrite> MultipartWriter<W> {
    /// Create a new MultipartWriter.
    pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
        let w = Arc::new(inner);
        let executor = info.executor();
        Self {
            w,
            executor: executor.clone(),
            upload_id: None,
            parts: Vec::new(),
            cache: None,
            next_part_number: 0,

            tasks: ConcurrentTasks::new(executor, concurrent, |input| {
                Box::pin({
                    async move {
                        let fut = input.w.write_part(
                            &input.upload_id,
                            input.part_number,
                            input.bytes.len() as u64,
                            input.bytes.clone(),
                        );
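                        // Race the part upload against the executor's timeout, if one is
                        // configured; a timeout becomes a temporary error so the task can
                        // be retried.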
                        match input.executor.timeout() {
                            None => {
                                let result = fut.await;
                                (input, result)
                            }
                            Some(timeout) => {
                                let result = select! {
                                    result = fut.fuse() => {
                                        result
                                    }
                                    _ = timeout.fuse() => {
                                        Err(Error::new(
                                            ErrorKind::Unexpected, "write part timeout")
                                                .with_context("upload_id", input.upload_id.to_string())
                                                .with_context("part_number", input.part_number.to_string())
                                                .set_temporary())
                                    }
                                };
                                (input, result)
                            }
                        }
                    }
                })
            }),
        }
    }

    fn fill_cache(&mut self, bs: Buffer) -> usize {
        let size = bs.len();
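        // Only one pending buffer may exist at a time: `write` and `close`
        // must consume the previous buffer before caching a new one.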
        assert!(self.cache.is_none());
        self.cache = Some(bs);
        size
    }
}

impl<W> oio::Write for MultipartWriter<W>
where
    W: MultipartWrite,
{
    async fn write(&mut self, bs: Buffer) -> Result<()> {
        let upload_id = match self.upload_id.clone() {
            Some(v) => v,
            None => {
                // Fill cache with the first write.
                if self.cache.is_none() {
                    self.fill_cache(bs);
                    return Ok(());
                }

                let upload_id = self.w.initiate_part().await?;
                let upload_id = Arc::new(upload_id);
                self.upload_id = Some(upload_id.clone());
                upload_id
            }
        };

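        // The cache always lags one write behind: the previously cached buffer is
        // uploaded as a part here, while the current buffer is cached below and
        // flushed either by the next `write` or by `close`.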
        let bytes = self.cache.clone().expect("pending write must exist");
        let part_number = self.next_part_number;

        self.tasks
            .execute(WriteInput {
                w: self.w.clone(),
                executor: self.executor.clone(),
                upload_id: upload_id.clone(),
                part_number,
                bytes,
            })
            .await?;
        self.cache = None;
        self.next_part_number += 1;
        self.fill_cache(bs);
        Ok(())
    }

    async fn close(&mut self) -> Result<Metadata> {
        let upload_id = match self.upload_id.clone() {
            Some(v) => v,
            None => {
                let (size, body) = match self.cache.clone() {
                    Some(cache) => (cache.len(), cache),
                    None => (0, Buffer::new()),
                };

                // Call write_once if there is no upload_id.
                let meta = self.w.write_once(size as u64, body).await?;
                // Make sure to clear the cache only after write_once succeeds; otherwise, retries may fail.
                self.cache = None;
                return Ok(meta);
            }
        };

        if let Some(cache) = self.cache.clone() {
            let part_number = self.next_part_number;

            self.tasks
                .execute(WriteInput {
                    w: self.w.clone(),
                    executor: self.executor.clone(),
                    upload_id: upload_id.clone(),
                    part_number,
                    bytes: cache,
                })
                .await?;
            self.cache = None;
            self.next_part_number += 1;
        }

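        // Drain all in-flight part uploads and collect their results before
        // completing the multipart upload.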
        loop {
            let Some(result) = self.tasks.next().await.transpose()? else {
                break;
            };
            self.parts.push(result)
        }

        if self.parts.len() != self.next_part_number {
            return Err(Error::new(
                ErrorKind::Unexpected,
                "multipart part numbers mismatch, please report bug to opendal",
            )
            .with_context("expected", self.next_part_number)
            .with_context("actual", self.parts.len())
            .with_context("upload_id", upload_id));
        }
        self.w.complete_part(&upload_id, &self.parts).await
    }

    async fn abort(&mut self) -> Result<()> {
        let Some(upload_id) = self.upload_id.clone() else {
            return Ok(());
        };

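        // Drop any queued or in-flight part uploads before asking the service
        // to abort the upload.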
        self.tasks.clear();
        self.cache = None;
        self.w.abort_part(&upload_id).await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::thread_rng;
    use rand::Rng;
    use rand::RngCore;
    use tokio::sync::Mutex;
    use tokio::time::sleep;
    use tokio::time::timeout;

    use super::*;
    use crate::raw::oio::Write;

    struct TestWrite {
        upload_id: String,
        part_numbers: Vec<usize>,
        length: u64,
        content: Option<Buffer>,
    }

    impl TestWrite {
        pub fn new() -> Arc<Mutex<Self>> {
            let v = Self {
                upload_id: uuid::Uuid::new_v4().to_string(),
                part_numbers: Vec::new(),
                length: 0,
                content: None,
            };

            Arc::new(Mutex::new(v))
        }
    }

    impl MultipartWrite for Arc<Mutex<TestWrite>> {
        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
            sleep(Duration::from_nanos(50)).await;

            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let mut this = self.lock().await;
            this.length = size;
            this.content = Some(body);
            Ok(Metadata::default().with_content_length(size))
        }

        async fn initiate_part(&self) -> Result<String> {
            let upload_id = self.lock().await.upload_id.clone();
            Ok(upload_id)
        }

        async fn write_part(
            &self,
            upload_id: &str,
            part_number: usize,
            size: u64,
            _: Buffer,
        ) -> Result<MultipartPart> {
            {
                let test = self.lock().await;
                assert_eq!(upload_id, test.upload_id);
            }

            // Add an async sleep here to force the future to stay pending for a while.
            sleep(Duration::from_nanos(50)).await;

            // There is a 10% chance for write_part to fail.
            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            {
                let mut test = self.lock().await;
                test.part_numbers.push(part_number);
                test.length += size;
            }

            Ok(MultipartPart {
                part_number,
                etag: "etag".to_string(),
                checksum: None,
            })
        }

        async fn complete_part(
            &self,
            upload_id: &str,
            parts: &[MultipartPart],
        ) -> Result<Metadata> {
            let test = self.lock().await;
            assert_eq!(upload_id, test.upload_id);
            assert_eq!(parts.len(), test.part_numbers.len());

            Ok(Metadata::default().with_content_length(test.length))
        }

        async fn abort_part(&self, upload_id: &str) -> Result<()> {
            let test = self.lock().await;
            assert_eq!(upload_id, test.upload_id);

            Ok(())
        }
    }

    struct TimeoutExecutor {
        exec: Arc<dyn Execute>,
    }

    impl TimeoutExecutor {
        pub fn new() -> Self {
            Self {
                exec: Executor::new().into_inner(),
            }
        }
    }

    impl Execute for TimeoutExecutor {
        fn execute(&self, f: BoxedStaticFuture<()>) {
            self.exec.execute(f)
        }

        fn timeout(&self) -> Option<BoxedStaticFuture<()>> {
            let time = thread_rng().gen_range(0..100);
            Some(Box::pin(tokio::time::sleep(Duration::from_nanos(time))))
        }
    }

    #[tokio::test]
    async fn test_multipart_upload_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let info = Arc::new(AccessorInfo::default());
        info.update_executor(|_| Executor::with(TimeoutExecutor::new()));

        let mut w = MultipartWriter::new(info, TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            loop {
                match timeout(Duration::from_nanos(10), w.write(bs.clone().into())).await {
                    Ok(Ok(_)) => break,
                    Ok(Err(_)) => continue,
                    Err(_) => {
                        continue;
                    }
                }
            }
        }

        loop {
            match timeout(Duration::from_nanos(10), w.close()).await {
                Ok(Ok(_)) => break,
                Ok(Err(_)) => continue,
                Err(_) => {
                    continue;
                }
            }
        }

        let actual_parts: Vec<_> = w.parts.into_iter().map(|v| v.part_number).collect();
        let expected_parts: Vec<_> = (0..1000).collect();
        assert_eq!(actual_parts, expected_parts);

        let actual_size = w.w.lock().await.length;
        assert_eq!(actual_size, total_size);
    }

    #[tokio::test]
    async fn test_multipart_writer_with_retry_when_write_once_error() {
        let mut rng = thread_rng();

        for _ in 0..100 {
            let mut w = MultipartWriter::new(Arc::default(), TestWrite::new(), 200);
            let size = rng.gen_range(1..1024);
            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            loop {
                match w.write(bs.clone().into()).await {
                    Ok(_) => break,
                    Err(_) => continue,
                }
            }

            loop {
                match w.close().await {
                    Ok(_) => break,
                    Err(_) => continue,
                }
            }

            let inner = w.w.lock().await;
            assert_eq!(inner.length, size as u64);
            assert!(inner.content.is_some());
            assert_eq!(inner.content.clone().unwrap().to_bytes(), bs);
        }
    }
}