opendal/types/context/
write.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::raw::oio::Write;
21use crate::raw::*;
22use crate::*;
23
24/// WriteContext holds the immutable context for give write operation.
25pub struct WriteContext {
26    /// The accessor to the storage services.
27    acc: Accessor,
28    /// Path to the file.
29    path: String,
30    /// Arguments for the write operation.
31    args: OpWrite,
32    /// Options for the writer.
33    options: OpWriter,
34}
35
36impl WriteContext {
37    /// Create a new WriteContext.
38    #[inline]
39    pub fn new(acc: Accessor, path: String, args: OpWrite, options: OpWriter) -> Self {
40        Self {
41            acc,
42            path,
43            args,
44            options,
45        }
46    }
47
48    /// Get the accessor.
49    #[inline]
50    pub fn accessor(&self) -> &Accessor {
51        &self.acc
52    }
53
54    /// Get the path.
55    #[inline]
56    pub fn path(&self) -> &str {
57        &self.path
58    }
59
60    /// Get the arguments.
61    #[inline]
62    pub fn args(&self) -> &OpWrite {
63        &self.args
64    }
65
66    /// Get the options.
67    #[inline]
68    pub fn options(&self) -> &OpWriter {
69        &self.options
70    }
71
72    /// Calculate the chunk size for this write process.
73    ///
74    /// Returns the chunk size and if the chunk size is exact.
75    fn calculate_chunk_size(&self) -> (Option<usize>, bool) {
76        let cap = self.accessor().info().full_capability();
77
78        let exact = self.options().chunk().is_some();
79        let chunk_size = self
80            .options()
81            .chunk()
82            .or(cap.write_multi_min_size)
83            .map(|mut size| {
84                if let Some(v) = cap.write_multi_max_size {
85                    size = size.min(v);
86                }
87                if let Some(v) = cap.write_multi_min_size {
88                    size = size.max(v);
89                }
90
91                size
92            });
93
94        (chunk_size, exact)
95    }
96}
97
98pub struct WriteGenerator<W> {
99    w: W,
100
101    /// The size for buffer, we will flush the underlying storage at the size of this buffer.
102    chunk_size: Option<usize>,
103    /// If `exact` is true, the size of the data written to the underlying storage is
104    /// exactly `chunk_size` bytes.
105    exact: bool,
106    buffer: oio::QueueBuf,
107}
108
109impl WriteGenerator<oio::Writer> {
110    /// Create a new exact buf writer.
111    pub async fn create(ctx: Arc<WriteContext>) -> Result<Self> {
112        let (chunk_size, exact) = ctx.calculate_chunk_size();
113        let (_, w) = ctx.acc.write(ctx.path(), ctx.args().clone()).await?;
114
115        Ok(Self {
116            w,
117            chunk_size,
118            exact,
119            buffer: oio::QueueBuf::new(),
120        })
121    }
122
123    /// Allow building from existing oio::Writer for easier testing.
124    #[cfg(test)]
125    fn new(w: oio::Writer, chunk_size: Option<usize>, exact: bool) -> Self {
126        Self {
127            w,
128            chunk_size,
129            exact,
130            buffer: oio::QueueBuf::new(),
131        }
132    }
133}
134
135impl WriteGenerator<oio::Writer> {
136    /// Write the entire buffer into writer.
137    pub async fn write(&mut self, mut bs: Buffer) -> Result<usize> {
138        let Some(chunk_size) = self.chunk_size else {
139            let size = bs.len();
140            self.w.write_dyn(bs).await?;
141            return Ok(size);
142        };
143
144        if self.buffer.len() + bs.len() < chunk_size {
145            let size = bs.len();
146            self.buffer.push(bs);
147            return Ok(size);
148        }
149
150        // Condition:
151        // - exact is false
152        // - buffer + bs is larger than chunk_size.
153        // Action:
154        // - write buffer + bs directly.
155        if !self.exact {
156            let fill_size = bs.len();
157            self.buffer.push(bs);
158            let buf = self.buffer.take().collect();
159            self.w.write_dyn(buf).await?;
160            return Ok(fill_size);
161        }
162
163        // Condition:
164        // - exact is true: we need write buffer in exact chunk size.
165        // - buffer is larger than chunk_size
166        //   - in exact mode, the size must be chunk_size, use `>=` just for safe coding.
167        // Action:
168        // - write existing buffer in chunk_size to make more rooms for writing data.
169        if self.buffer.len() >= chunk_size {
170            let buf = self.buffer.take().collect();
171            self.w.write_dyn(buf).await?;
172        }
173
174        // Condition
175        // - exact is true.
176        // - buffer size must lower than chunk_size.
177        // Action:
178        // - write bs to buffer with remaining size.
179        let remaining = chunk_size - self.buffer.len();
180        bs.truncate(remaining);
181        let n = bs.len();
182        self.buffer.push(bs);
183        Ok(n)
184    }
185
186    /// Finish the write process.
187    pub async fn close(&mut self) -> Result<Metadata> {
188        loop {
189            if self.buffer.is_empty() {
190                break;
191            }
192
193            let buf = self.buffer.take().collect();
194            self.w.write_dyn(buf).await?;
195        }
196
197        self.w.close().await
198    }
199
200    /// Abort the write process.
201    pub async fn abort(&mut self) -> Result<()> {
202        self.buffer.clear();
203        self.w.abort().await
204    }
205}
206
207impl WriteGenerator<oio::BlockingWriter> {
208    /// Create a new exact buf writer.
209    pub fn blocking_create(ctx: Arc<WriteContext>) -> Result<Self> {
210        let (chunk_size, exact) = ctx.calculate_chunk_size();
211        let (_, w) = ctx.acc.blocking_write(ctx.path(), ctx.args().clone())?;
212
213        Ok(Self {
214            w,
215            chunk_size,
216            exact,
217            buffer: oio::QueueBuf::new(),
218        })
219    }
220}
221
222impl WriteGenerator<oio::BlockingWriter> {
223    /// Write the entire buffer into writer.
224    pub fn write(&mut self, mut bs: Buffer) -> Result<usize> {
225        let Some(chunk_size) = self.chunk_size else {
226            let size = bs.len();
227            self.w.write(bs)?;
228            return Ok(size);
229        };
230
231        if self.buffer.len() + bs.len() < chunk_size {
232            let size = bs.len();
233            self.buffer.push(bs);
234            return Ok(size);
235        }
236
237        // Condition:
238        // - exact is false
239        // - buffer + bs is larger than chunk_size.
240        // Action:
241        // - write buffer + bs directly.
242        if !self.exact {
243            let fill_size = bs.len();
244            self.buffer.push(bs);
245            let buf = self.buffer.take().collect();
246            self.w.write(buf)?;
247            return Ok(fill_size);
248        }
249
250        // Condition:
251        // - exact is true: we need write buffer in exact chunk size.
252        // - buffer is larger than chunk_size
253        //   - in exact mode, the size must be chunk_size, use `>=` just for safe coding.
254        // Action:
255        // - write existing buffer in chunk_size to make more rooms for writing data.
256        if self.buffer.len() >= chunk_size {
257            let buf = self.buffer.take().collect();
258            self.w.write(buf)?;
259        }
260
261        // Condition
262        // - exact is true.
263        // - buffer size must lower than chunk_size.
264        // Action:
265        // - write bs to buffer with remaining size.
266        let remaining = chunk_size - self.buffer.len();
267        bs.truncate(remaining);
268        let n = bs.len();
269        self.buffer.push(bs);
270        Ok(n)
271    }
272
273    /// Finish the write process.
274    pub fn close(&mut self) -> Result<Metadata> {
275        loop {
276            if self.buffer.is_empty() {
277                break;
278            }
279
280            let buf = self.buffer.take().collect();
281            self.w.write(buf)?;
282        }
283
284        self.w.close()
285    }
286}
287
#[cfg(test)]
mod tests {
    use bytes::Buf;
    use bytes::BufMut;
    use bytes::Bytes;
    use log::debug;
    use pretty_assertions::assert_eq;
    use rand::thread_rng;
    use rand::Rng;
    use rand::RngCore;
    use sha2::Digest;
    use sha2::Sha256;
    use tokio::sync::Mutex;

    use super::*;
    use crate::raw::oio::Write;

    /// A writer that accepts every buffer it is given and appends it to a shared
    /// in-memory `Vec<u8>`, so tests can inspect what was flushed.
    struct MockWriter {
        buf: Arc<Mutex<Vec<u8>>>,
    }

    impl Write for MockWriter {
        async fn write(&mut self, bs: Buffer) -> Result<()> {
            debug!("test_fuzz_exact_buf_writer: flush size: {}", &bs.len());

            let mut buf = self.buf.lock().await;
            buf.put(bs);
            Ok(())
        }

        async fn close(&mut self) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&mut self) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_exact_buf_writer_short_write() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut rng = thread_rng();
        let mut expected = vec![0; 5];
        rng.fill_bytes(&mut expected);

        let buf = Arc::new(Mutex::new(vec![]));
        let mut w = WriteGenerator::new(Box::new(MockWriter { buf: buf.clone() }), Some(10), true);

        let mut bs = Bytes::from(expected.clone());
        while !bs.is_empty() {
            let n = w.write(bs.clone().into()).await?;
            bs.advance(n);
        }

        w.close().await?;

        let buf = buf.lock().await;
        assert_eq!(buf.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*buf)),
            format!("{:x}", Sha256::digest(&expected))
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_inexact_buf_writer_large_write() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let buf = Arc::new(Mutex::new(vec![]));
        let mut w = WriteGenerator::new(Box::new(MockWriter { buf: buf.clone() }), Some(10), false);

        let mut rng = thread_rng();
        let mut expected = vec![0; 15];
        rng.fill_bytes(&mut expected);

        let bs = Bytes::from(expected.clone());
        // In inexact mode the whole input is accepted in one call, even though it is
        // larger than the chunk size.
        let n = w.write(bs.into()).await?;
        assert_eq!(expected.len(), n);

        w.close().await?;

        let buf = buf.lock().await;
        assert_eq!(buf.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*buf)),
            format!("{:x}", Sha256::digest(&expected))
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_inexact_buf_writer_combine_small() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let buf = Arc::new(Mutex::new(vec![]));
        let mut w = WriteGenerator::new(Box::new(MockWriter { buf: buf.clone() }), Some(10), false);

        let mut rng = thread_rng();
        let mut expected = vec![];

        let mut new_content = |size| {
            let mut content = vec![0; size];
            rng.fill_bytes(&mut content);
            expected.extend_from_slice(&content);
            Bytes::from(content)
        };

        // content > chunk size.
        let content = new_content(15);
        assert_eq!(15, w.write(content.into()).await?);
        // content < chunk size.
        let content = new_content(5);
        assert_eq!(5, w.write(content.into()).await?);
        // content > chunk size, but 5 bytes in queue.
        let content = new_content(15);
        // In inexact mode the 5 queued bytes and all 15 new bytes are flushed together,
        // so the whole 15-byte input is accepted.
        assert_eq!(15, w.write(content.clone().into()).await?);

        w.close().await?;

        let buf = buf.lock().await;
        assert_eq!(buf.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*buf)),
            format!("{:x}", Sha256::digest(&expected))
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_inexact_buf_writer_queue_remaining() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let buf = Arc::new(Mutex::new(vec![]));
        let mut w = WriteGenerator::new(Box::new(MockWriter { buf: buf.clone() }), Some(10), false);

        let mut rng = thread_rng();
        let mut expected = vec![];

        let mut new_content = |size| {
            let mut content = vec![0; size];
            rng.fill_bytes(&mut content);
            expected.extend_from_slice(&content);
            Bytes::from(content)
        };

        // content > chunk size.
        let content = new_content(15);
        assert_eq!(15, w.write(content.into()).await?);
        // content < chunk size.
        let content = new_content(5);
        assert_eq!(5, w.write(content.into()).await?);
        // content < chunk size.
        let content = new_content(3);
        assert_eq!(3, w.write(content.into()).await?);
        // content > chunk size, but can send all chunks in the queue.
        let content = new_content(15);
        assert_eq!(15, w.write(content.clone().into()).await?);

        w.close().await?;

        let buf = buf.lock().await;
        assert_eq!(buf.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*buf)),
            format!("{:x}", Sha256::digest(&expected))
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_fuzz_exact_buf_writer() -> Result<()> {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut rng = thread_rng();
        let mut expected = vec![];

        let buf = Arc::new(Mutex::new(vec![]));
        let buffer_size = rng.gen_range(1..10);
        let mut writer = WriteGenerator::new(
            Box::new(MockWriter { buf: buf.clone() }),
            Some(buffer_size),
            true,
        );
        debug!("test_fuzz_exact_buf_writer: buffer size: {buffer_size}");

        for _ in 0..1000 {
            let size = rng.gen_range(1..20);
            debug!("test_fuzz_exact_buf_writer: write size: {size}");
            let mut content = vec![0; size];
            rng.fill_bytes(&mut content);

            expected.extend_from_slice(&content);

            let mut bs = Bytes::from(content.clone());
            while !bs.is_empty() {
                let n = writer.write(bs.clone().into()).await?;
                bs.advance(n);
            }
        }
        writer.close().await?;

        let buf = buf.lock().await;
        assert_eq!(buf.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*buf)),
            format!("{:x}", Sha256::digest(&expected))
        );
        Ok(())
    }
}