Skip to main content

opendal_core/raw/oio/read/
position_read.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use futures::Future;
21use mea::once::OnceCell;
22
23use crate::raw::*;
24use crate::*;
25
/// Default cap, in bytes, on the size requested from one positioned read (2 MiB).
const DEFAULT_POSITION_READ_MAX_BUF_SIZE: usize = 2 << 20;
27
28/// PositionRead is used to implement [`oio::Read`] based on positioned reads.
29///
30/// Services that implement [`PositionRead`] create a positioned read handle lazily
31/// and must support position-independent reads on that handle. `size` is the
32/// maximum number of bytes to read, and implementations may return fewer bytes.
33/// Returning an empty buffer means EOF.
34pub trait PositionRead: Send + Sync + Unpin + 'static {
35    /// The opened positioned read handle.
36    type Handle: Send + Sync + Unpin + 'static;
37
38    /// Open the positioned read handle.
39    fn open(&self) -> impl Future<Output = Result<Self::Handle>> + MaybeSend;
40
41    /// Read up to `size` bytes from `offset`.
42    fn read_at(
43        handle: &Self::Handle,
44        offset: u64,
45        size: usize,
46    ) -> impl Future<Output = Result<Buffer>> + MaybeSend;
47}
48
49/// PositionReader implements [`oio::Read`] based on [`PositionRead`].
50pub struct PositionReader<R: PositionRead> {
51    reader: Arc<R>,
52    handle: Arc<OnceCell<R::Handle>>,
53    max_buf_size: usize,
54}
55
56impl<R: PositionRead> PositionReader<R> {
57    /// Create a new [`PositionReader`].
58    pub fn new(reader: R) -> Self {
59        Self {
60            reader: Arc::new(reader),
61            handle: Arc::new(OnceCell::new()),
62            max_buf_size: DEFAULT_POSITION_READ_MAX_BUF_SIZE,
63        }
64    }
65
66    /// Set the maximum buffer size used by [`PositionReader`].
67    pub fn with_max_buf_size(mut self, buf_size: usize) -> Self {
68        assert!(
69            buf_size > 0,
70            "position read max buffer size must not be zero"
71        );
72
73        self.max_buf_size = buf_size;
74        self
75    }
76
77    /// Consume the reader and return the inner [`PositionRead`].
78    ///
79    /// # Panics
80    ///
81    /// Panics if there are active streams that still share the inner reader.
82    pub fn into_inner(self) -> R {
83        Arc::into_inner(self.reader).expect("position reader must not be shared")
84    }
85
86    async fn handle(&self) -> Result<&R::Handle> {
87        self.handle.get_or_try_init(|| self.reader.open()).await
88    }
89}
90
91impl<R: PositionRead> oio::Read for PositionReader<R> {
92    async fn open(&self, range: BytesRange) -> Result<(RpRead, Box<dyn oio::ReadStreamDyn>)> {
93        let stream = PositionReadStream::new(
94            self.reader.clone(),
95            self.handle.clone(),
96            range,
97            self.max_buf_size,
98        );
99        Ok((
100            RpRead::default(),
101            Box::new(stream) as Box<dyn oio::ReadStreamDyn>,
102        ))
103    }
104
105    async fn read(&self, range: BytesRange) -> Result<(RpRead, Buffer)> {
106        let size = range
107            .size()
108            .ok_or_else(|| Error::new(ErrorKind::Unsupported, "read requires a bounded range"))?;
109
110        let mut offset = range.offset();
111        let mut remaining = size;
112        let mut bufs = Vec::new();
113
114        while remaining > 0 {
115            let read_size = remaining.min(self.max_buf_size as u64) as usize;
116            let handle = self.handle().await?;
117            let buf = R::read_at(handle, offset, read_size).await?;
118            check_position_read_size(read_size, buf.len())?;
119            if buf.is_empty() {
120                return Err(Error::new(
121                    ErrorKind::RangeNotSatisfied,
122                    "range exceeds content length",
123                )
124                .with_context("offset", offset)
125                .with_context("remaining", remaining));
126            }
127
128            let n = buf.len() as u64;
129            offset += n;
130            remaining -= n;
131            bufs.push(buf);
132        }
133
134        Ok((RpRead::default(), bufs.into_iter().flatten().collect()))
135    }
136}
137
138struct PositionReadStream<R: PositionRead> {
139    reader: Arc<R>,
140    handle: Arc<OnceCell<R::Handle>>,
141    offset: u64,
142    remaining: Option<u64>,
143    max_buf_size: usize,
144    done: bool,
145}
146
147impl<R: PositionRead> PositionReadStream<R> {
148    fn new(
149        reader: Arc<R>,
150        handle: Arc<OnceCell<R::Handle>>,
151        range: BytesRange,
152        max_buf_size: usize,
153    ) -> Self {
154        Self {
155            reader,
156            handle,
157            offset: range.offset(),
158            remaining: range.size(),
159            max_buf_size,
160            done: false,
161        }
162    }
163}
164
165impl<R: PositionRead> oio::ReadStream for PositionReadStream<R> {
166    async fn read(&mut self) -> Result<Buffer> {
167        if self.done || self.remaining == Some(0) {
168            return Ok(Buffer::new());
169        }
170
171        let read_size = self
172            .remaining
173            .map(|remaining| remaining.min(self.max_buf_size as u64) as usize)
174            .unwrap_or(self.max_buf_size);
175
176        let handle = self.handle.get_or_try_init(|| self.reader.open()).await?;
177        let buf = R::read_at(handle, self.offset, read_size).await?;
178        check_position_read_size(read_size, buf.len())?;
179        if buf.is_empty() {
180            self.done = true;
181            if let Some(remaining) = self.remaining {
182                return Err(Error::new(
183                    ErrorKind::RangeNotSatisfied,
184                    "range exceeds content length",
185                )
186                .with_context("offset", self.offset)
187                .with_context("remaining", remaining));
188            }
189            return Ok(Buffer::new());
190        }
191
192        let n = buf.len() as u64;
193        self.offset += n;
194        if let Some(remaining) = &mut self.remaining {
195            *remaining -= n;
196        }
197
198        Ok(buf)
199    }
200}
201
202fn check_position_read_size(expected: usize, actual: usize) -> Result<()> {
203    if actual > expected {
204        return Err(
205            Error::new(ErrorKind::Unexpected, "reader got unexpected data size")
206                .with_context("expect", expected)
207                .with_context("actual", actual),
208        );
209    }
210
211    Ok(())
212}
213
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use bytes::Bytes;

    use super::*;
    use crate::raw::oio::Read;
    use crate::raw::oio::ReadStream;

    /// In-memory [`PositionRead`] that records every open and `read_at` call.
    struct TestPositionRead {
        /// Bytes served by `read_at`.
        content: Bytes,
        /// Largest number of bytes a single `read_at` returns.
        max_read: usize,
        /// `(offset, size)` of every `read_at` call, in call order.
        calls: Arc<Mutex<Vec<(u64, usize)>>>,
        /// Number of times `open` has been called.
        opens: Arc<Mutex<usize>>,
    }

    impl TestPositionRead {
        fn new(content: &'static [u8], max_read: usize) -> Self {
            let content = Bytes::from_static(content);

            Self {
                content,
                max_read,
                calls: Arc::default(),
                opens: Arc::default(),
            }
        }
    }

    impl PositionRead for TestPositionRead {
        type Handle = Self;

        async fn open(&self) -> Result<Self::Handle> {
            {
                let mut opens = self.opens.lock().unwrap();
                *opens += 1;
            }

            // The handle shares the counters so tests can observe reads made through it.
            Ok(Self {
                content: self.content.clone(),
                max_read: self.max_read,
                calls: Arc::clone(&self.calls),
                opens: Arc::clone(&self.opens),
            })
        }

        async fn read_at(handle: &Self::Handle, offset: u64, size: usize) -> Result<Buffer> {
            handle.calls.lock().unwrap().push((offset, size));

            let start = offset as usize;
            let total = handle.content.len();
            if start >= total {
                return Ok(Buffer::new());
            }

            // Deliberately short reads: never return more than `max_read` bytes.
            let len = size.min(handle.max_read).min(total - start);
            Ok(Buffer::from(handle.content.slice(start..start + len)))
        }
    }

    #[tokio::test]
    async fn test_position_reader_read_handles_partial_reads() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 2);
        let calls = Arc::clone(&source.calls);
        let opens = Arc::clone(&source.opens);
        let reader = PositionReader::new(source).with_max_buf_size(4);

        let (_, buf) = reader.read(BytesRange::from(2..8)).await?;

        assert_eq!(buf.to_vec(), b"234567");

        let recorded = calls.lock().unwrap().clone();
        assert_eq!(recorded, vec![(2, 4), (4, 4), (6, 2)]);
        assert_eq!(*opens.lock().unwrap(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_position_reader_read_reports_early_eof() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 4);
        let reader = PositionReader::new(source).with_max_buf_size(4);

        let err = reader.read(BytesRange::from(8..12)).await.unwrap_err();

        assert_eq!(err.kind(), ErrorKind::RangeNotSatisfied);
        Ok(())
    }

    #[tokio::test]
    async fn test_position_reader_open_stops_at_eof() -> Result<()> {
        let source = TestPositionRead::new(b"0123456789", 2);
        let reader = PositionReader::new(source).with_max_buf_size(4);
        let (_, mut stream) = reader.open(BytesRange::from(8..)).await?;

        let buf = stream.read_all().await?;

        assert_eq!(buf.to_vec(), b"89");
        Ok(())
    }
}