opendal/types/context/
write.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::raw::oio::Write;
21use crate::raw::*;
22use crate::*;
23
/// WriteContext holds the immutable context for a given write operation.
///
/// It bundles everything needed to start writing: the accessor, the target
/// path, the write arguments and the writer options. All fields are private
/// and only exposed through read-only getters.
pub struct WriteContext {
    /// The accessor to the storage services.
    acc: Accessor,
    /// Path to the file.
    path: String,
    /// Arguments for the write operation.
    args: OpWrite,
    /// Options for the writer.
    options: OpWriter,
}
35
36impl WriteContext {
37    /// Create a new WriteContext.
38    #[inline]
39    pub fn new(acc: Accessor, path: String, args: OpWrite, options: OpWriter) -> Self {
40        Self {
41            acc,
42            path,
43            args,
44            options,
45        }
46    }
47
48    /// Get the accessor.
49    #[inline]
50    pub fn accessor(&self) -> &Accessor {
51        &self.acc
52    }
53
54    /// Get the path.
55    #[inline]
56    pub fn path(&self) -> &str {
57        &self.path
58    }
59
60    /// Get the arguments.
61    #[inline]
62    pub fn args(&self) -> &OpWrite {
63        &self.args
64    }
65
66    /// Get the options.
67    #[inline]
68    pub fn options(&self) -> &OpWriter {
69        &self.options
70    }
71
72    /// Calculate the chunk size for this write process.
73    ///
74    /// Returns the chunk size and if the chunk size is exact.
75    fn calculate_chunk_size(&self) -> (Option<usize>, bool) {
76        let cap = self.accessor().info().full_capability();
77
78        let exact = self.options().chunk().is_some();
79        let chunk_size = self
80            .options()
81            .chunk()
82            .or(cap.write_multi_min_size)
83            .map(|mut size| {
84                if let Some(v) = cap.write_multi_max_size {
85                    size = size.min(v);
86                }
87                if let Some(v) = cap.write_multi_min_size {
88                    size = size.max(v);
89                }
90
91                size
92            });
93
94        (chunk_size, exact)
95    }
96}
97
/// WriteGenerator sits between the caller and an underlying writer `W`,
/// optionally queueing incoming data so that it is flushed in chunks.
pub struct WriteGenerator<W> {
    /// The underlying writer that receives the flushed data.
    w: W,

    /// The size for buffer, we will flush the underlying storage at the size of this buffer.
    ///
    /// `None` disables buffering: every write is forwarded to `w` directly.
    chunk_size: Option<usize>,
    /// If `exact` is true, the size of the data written to the underlying storage is
    /// exactly `chunk_size` bytes.
    exact: bool,
    /// Data accepted from the caller but not yet written to `w`.
    buffer: oio::QueueBuf,
}
108
109impl WriteGenerator<oio::Writer> {
110    /// Create a new exact buf writer.
111    pub async fn create(ctx: Arc<WriteContext>) -> Result<Self> {
112        let (chunk_size, exact) = ctx.calculate_chunk_size();
113        let (_, w) = ctx.acc.write(ctx.path(), ctx.args().clone()).await?;
114
115        Ok(Self {
116            w,
117            chunk_size,
118            exact,
119            buffer: oio::QueueBuf::new(),
120        })
121    }
122
123    /// Allow building from existing oio::Writer for easier testing.
124    #[cfg(test)]
125    fn new(w: oio::Writer, chunk_size: Option<usize>, exact: bool) -> Self {
126        Self {
127            w,
128            chunk_size,
129            exact,
130            buffer: oio::QueueBuf::new(),
131        }
132    }
133}
134
135impl WriteGenerator<oio::Writer> {
136    /// Write the entire buffer into writer.
137    pub async fn write(&mut self, mut bs: Buffer) -> Result<usize> {
138        let Some(chunk_size) = self.chunk_size else {
139            let size = bs.len();
140            self.w.write_dyn(bs).await?;
141            return Ok(size);
142        };
143
144        if self.buffer.len() + bs.len() < chunk_size {
145            let size = bs.len();
146            self.buffer.push(bs);
147            return Ok(size);
148        }
149
150        // Condition:
151        // - exact is false
152        // - buffer + bs is larger than chunk_size.
153        // Action:
154        // - write buffer + bs directly.
155        if !self.exact {
156            let fill_size = bs.len();
157            self.buffer.push(bs);
158            let buf = self.buffer.take().collect();
159            self.w.write_dyn(buf).await?;
160            return Ok(fill_size);
161        }
162
163        // Condition:
164        // - exact is true: we need write buffer in exact chunk size.
165        // - buffer is larger than chunk_size
166        //   - in exact mode, the size must be chunk_size, use `>=` just for safe coding.
167        // Action:
168        // - write existing buffer in chunk_size to make more rooms for writing data.
169        if self.buffer.len() >= chunk_size {
170            let buf = self.buffer.take().collect();
171            self.w.write_dyn(buf).await?;
172        }
173
174        // Condition
175        // - exact is true.
176        // - buffer size must lower than chunk_size.
177        // Action:
178        // - write bs to buffer with remaining size.
179        let remaining = chunk_size - self.buffer.len();
180        bs.truncate(remaining);
181        let n = bs.len();
182        self.buffer.push(bs);
183        Ok(n)
184    }
185
186    /// Finish the write process.
187    pub async fn close(&mut self) -> Result<Metadata> {
188        loop {
189            if self.buffer.is_empty() {
190                break;
191            }
192
193            let buf = self.buffer.take().collect();
194            self.w.write_dyn(buf).await?;
195        }
196
197        self.w.close().await
198    }
199
200    /// Abort the write process.
201    pub async fn abort(&mut self) -> Result<()> {
202        self.buffer.clear();
203        self.w.abort().await
204    }
205}
206
#[cfg(test)]
mod tests {
    use bytes::Buf;
    use bytes::BufMut;
    use bytes::Bytes;
    use log::debug;
    use pretty_assertions::assert_eq;
    use rand::thread_rng;
    use rand::Rng;
    use rand::RngCore;
    use sha2::Digest;
    use sha2::Sha256;
    use tokio::sync::Mutex;

    use super::*;
    use crate::raw::oio::Write;

    /// Bytes received by a [`MockWriter`], shared with the test body.
    type Sink = Arc<Mutex<Vec<u8>>>;

    /// In-memory writer that appends every write to a shared sink.
    struct MockWriter {
        buf: Sink,
    }

    impl Write for MockWriter {
        async fn write(&mut self, bs: Buffer) -> Result<()> {
            debug!("test_fuzz_exact_buf_writer: flush size: {}", &bs.len());

            let mut sink = self.buf.lock().await;
            sink.put(bs);
            Ok(())
        }

        async fn close(&mut self) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Install the tracing subscriber; failures (e.g. already installed) are ignored.
    fn init_tracing() {
        let _ = tracing_subscriber::fmt()
            .pretty()
            .with_test_writer()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
    }

    /// Build a generator on top of a [`MockWriter`] and return it with the sink.
    fn new_writer(chunk_size: usize, exact: bool) -> (WriteGenerator<oio::Writer>, Sink) {
        let sink: Sink = Arc::new(Mutex::new(vec![]));
        let writer = WriteGenerator::new(
            Box::new(MockWriter { buf: sink.clone() }),
            Some(chunk_size),
            exact,
        );
        (writer, sink)
    }

    /// Produce `size` random bytes.
    fn random_bytes(size: usize) -> Vec<u8> {
        let mut content = vec![0; size];
        thread_rng().fill_bytes(&mut content);
        content
    }

    /// Generate `size` random bytes, record them in `expected` and write them once.
    async fn write_random(
        w: &mut WriteGenerator<oio::Writer>,
        expected: &mut Vec<u8>,
        size: usize,
    ) -> Result<usize> {
        let content = random_bytes(size);
        expected.extend_from_slice(&content);
        w.write(Bytes::from(content).into()).await
    }

    /// Keep calling `write` until every byte of `content` has been accepted.
    async fn write_all(w: &mut WriteGenerator<oio::Writer>, content: Vec<u8>) -> Result<()> {
        let mut bs = Bytes::from(content);
        while !bs.is_empty() {
            let n = w.write(bs.clone().into()).await?;
            bs.advance(n);
        }
        Ok(())
    }

    /// Assert that the sink holds exactly `expected` (same length and SHA-256).
    async fn assert_sink_matches(sink: &Sink, expected: &[u8]) {
        let written = sink.lock().await;
        assert_eq!(written.len(), expected.len());
        assert_eq!(
            format!("{:x}", Sha256::digest(&*written)),
            format!("{:x}", Sha256::digest(expected))
        );
    }

    /// Exact mode: data shorter than the chunk size only reaches the sink on close.
    #[tokio::test]
    async fn test_exact_buf_writer_short_write() -> Result<()> {
        init_tracing();

        let expected = random_bytes(5);
        let (mut w, sink) = new_writer(10, true);

        write_all(&mut w, expected.clone()).await?;
        w.close().await?;

        assert_sink_matches(&sink, &expected).await;
        Ok(())
    }

    /// Inexact mode: a single write larger than the chunk size is accepted whole.
    #[tokio::test]
    async fn test_inexact_buf_writer_large_write() -> Result<()> {
        init_tracing();

        let (mut w, sink) = new_writer(10, false);
        let expected = random_bytes(15);

        // 15 bytes reach the chunk size of 10, so the whole input is flushed
        // at once and reported as written.
        let n = w.write(Bytes::from(expected.clone()).into()).await?;
        assert_eq!(expected.len(), n);

        w.close().await?;

        assert_sink_matches(&sink, &expected).await;
        Ok(())
    }

    /// Inexact mode: a queued small write is flushed together with the next large one.
    #[tokio::test]
    async fn test_inexact_buf_writer_combine_small() -> Result<()> {
        init_tracing();

        let (mut w, sink) = new_writer(10, false);
        let mut expected = vec![];

        // content > chunk size: flushed directly.
        assert_eq!(15, write_random(&mut w, &mut expected, 15).await?);
        // content < chunk size: queued.
        assert_eq!(5, write_random(&mut w, &mut expected, 5).await?);
        // content > chunk size with 5 bytes queued: the queue and the new
        // content are flushed together, and all 15 new bytes are accepted.
        assert_eq!(15, write_random(&mut w, &mut expected, 15).await?);

        w.close().await?;

        assert_sink_matches(&sink, &expected).await;
        Ok(())
    }

    /// Inexact mode: several queued small writes are flushed with the next large one.
    #[tokio::test]
    async fn test_inexact_buf_writer_queue_remaining() -> Result<()> {
        init_tracing();

        let (mut w, sink) = new_writer(10, false);
        let mut expected = vec![];

        // content > chunk size: flushed directly.
        assert_eq!(15, write_random(&mut w, &mut expected, 15).await?);
        // content < chunk size: queued (5 bytes in queue).
        assert_eq!(5, write_random(&mut w, &mut expected, 5).await?);
        // content < chunk size: still below the chunk size, queued (8 bytes in queue).
        assert_eq!(3, write_random(&mut w, &mut expected, 3).await?);
        // content > chunk size: everything queued is flushed along with it.
        assert_eq!(15, write_random(&mut w, &mut expected, 15).await?);

        w.close().await?;

        assert_sink_matches(&sink, &expected).await;
        Ok(())
    }

    /// Exact mode: random chunk size and random write sizes must round-trip.
    #[tokio::test]
    async fn test_fuzz_exact_buf_writer() -> Result<()> {
        init_tracing();

        let mut rng = thread_rng();
        let buffer_size = rng.gen_range(1..10);
        let (mut writer, sink) = new_writer(buffer_size, true);
        debug!("test_fuzz_exact_buf_writer: buffer size: {buffer_size}");

        let mut expected = vec![];
        for _ in 0..1000 {
            let size = rng.gen_range(1..20);
            debug!("test_fuzz_exact_buf_writer: write size: {size}");

            let content = random_bytes(size);
            expected.extend_from_slice(&content);
            write_all(&mut writer, content).await?;
        }
        writer.close().await?;

        assert_sink_matches(&sink, &expected).await;
        Ok(())
    }
}