Skip to main content

opendal_core/raw/oio/read/
position_read.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use futures::Future;
21
22use crate::raw::*;
23use crate::*;
24
/// Default cap (2 MiB) on the `size` requested from a single
/// [`PositionRead::read_at`] call; used by [`PositionReader::new`].
const DEFAULT_POSITION_READ_MAX_BUF_SIZE: usize = 2 * 1024 * 1024;
26
/// PositionRead is used to implement [`oio::Read`] based on positioned reads.
///
/// Services that implement [`PositionRead`] must support position-independent
/// reads. `size` is the maximum number of bytes to read, and implementations may
/// return fewer bytes. Returning an empty buffer means EOF.
///
/// [`PositionReader`] checks this contract on every call: a buffer longer than
/// the requested `size` is rejected with `ErrorKind::Unexpected`, and an empty
/// buffer returned before a bounded range has been fully read is reported as
/// `ErrorKind::RangeNotSatisfied`.
pub trait PositionRead: Send + Sync + Unpin + 'static {
    /// Read up to `size` bytes from `offset`.
    ///
    /// [`PositionReader`] calls this with the range's starting offset plus the
    /// number of bytes it has already received, so a short read is followed by
    /// another call that continues where the previous one stopped.
    ///
    /// The receiver is `&self`: one implementation instance is shared (behind
    /// an `Arc`) between the reader and every stream it opens.
    fn read_at(&self, offset: u64, size: usize)
    -> impl Future<Output = Result<Buffer>> + MaybeSend;
}
37
/// PositionReader implements [`oio::Read`] based on [`PositionRead`].
///
/// The wrapped implementation is kept in an [`Arc`] so that each stream
/// returned by `open` can hold its own handle to it.
pub struct PositionReader<R: PositionRead> {
    /// Shared positioned-read implementation; cloned into every opened stream.
    inner: Arc<R>,
    /// Largest `size` passed to a single `read_at` call.
    max_buf_size: usize,
}
43
44impl<R: PositionRead> PositionReader<R> {
45    /// Create a new [`PositionReader`].
46    pub fn new(inner: R) -> Self {
47        Self {
48            inner: Arc::new(inner),
49            max_buf_size: DEFAULT_POSITION_READ_MAX_BUF_SIZE,
50        }
51    }
52
53    /// Set the maximum buffer size used by [`PositionReader`].
54    pub fn with_max_buf_size(mut self, buf_size: usize) -> Self {
55        assert!(
56            buf_size > 0,
57            "position read max buffer size must not be zero"
58        );
59
60        self.max_buf_size = buf_size;
61        self
62    }
63
64    /// Consume the reader and return the inner [`PositionRead`].
65    ///
66    /// # Panics
67    ///
68    /// Panics if there are active streams that still share the inner reader.
69    pub fn into_inner(self) -> R {
70        Arc::into_inner(self.inner).expect("position reader must not be shared")
71    }
72}
73
74impl<R: PositionRead> oio::Read for PositionReader<R> {
75    async fn open(&self, range: BytesRange) -> Result<(RpRead, Box<dyn oio::ReadStreamDyn>)> {
76        let stream = PositionReadStream::new(self.inner.clone(), range, self.max_buf_size);
77        Ok((RpRead::default(), Box::new(stream)))
78    }
79
80    async fn read(&self, range: BytesRange) -> Result<(RpRead, Buffer)> {
81        let size = range
82            .size()
83            .ok_or_else(|| Error::new(ErrorKind::Unsupported, "read requires a bounded range"))?;
84
85        let mut offset = range.offset();
86        let mut remaining = size;
87        let mut bufs = Vec::new();
88
89        while remaining > 0 {
90            let read_size = remaining.min(self.max_buf_size as u64) as usize;
91            let buf = self.inner.read_at(offset, read_size).await?;
92            check_position_read_size(read_size, buf.len())?;
93            if buf.is_empty() {
94                return Err(Error::new(
95                    ErrorKind::RangeNotSatisfied,
96                    "range exceeds content length",
97                )
98                .with_context("offset", offset)
99                .with_context("remaining", remaining));
100            }
101
102            let n = buf.len() as u64;
103            offset += n;
104            remaining -= n;
105            bufs.push(buf);
106        }
107
108        Ok((RpRead::default(), bufs.into_iter().flatten().collect()))
109    }
110}
111
/// Stream handed out by [`PositionReader`]'s `open`; each `read` call turns
/// into at most one `read_at` call on the shared inner reader.
struct PositionReadStream<R: PositionRead> {
    /// Positioned-read implementation shared with the reader that opened this
    /// stream.
    inner: Arc<R>,
    /// Offset passed to the next `read_at` call; advanced by the number of
    /// bytes each call returns.
    offset: u64,
    /// Bytes still owed for a bounded range, or `None` when the range has no
    /// size and the stream simply runs until EOF.
    remaining: Option<u64>,
    /// Largest `size` requested from a single `read_at` call.
    max_buf_size: usize,
    /// Set once `read_at` has returned an empty buffer; later reads return an
    /// empty buffer without touching the inner reader.
    done: bool,
}
119
120impl<R: PositionRead> PositionReadStream<R> {
121    fn new(inner: Arc<R>, range: BytesRange, max_buf_size: usize) -> Self {
122        Self {
123            inner,
124            offset: range.offset(),
125            remaining: range.size(),
126            max_buf_size,
127            done: false,
128        }
129    }
130}
131
132impl<R: PositionRead> oio::ReadStream for PositionReadStream<R> {
133    async fn read(&mut self) -> Result<Buffer> {
134        if self.done || self.remaining == Some(0) {
135            return Ok(Buffer::new());
136        }
137
138        let read_size = self
139            .remaining
140            .map(|remaining| remaining.min(self.max_buf_size as u64) as usize)
141            .unwrap_or(self.max_buf_size);
142
143        let buf = self.inner.read_at(self.offset, read_size).await?;
144        check_position_read_size(read_size, buf.len())?;
145        if buf.is_empty() {
146            self.done = true;
147            if let Some(remaining) = self.remaining {
148                return Err(Error::new(
149                    ErrorKind::RangeNotSatisfied,
150                    "range exceeds content length",
151                )
152                .with_context("offset", self.offset)
153                .with_context("remaining", remaining));
154            }
155            return Ok(Buffer::new());
156        }
157
158        let n = buf.len() as u64;
159        self.offset += n;
160        if let Some(remaining) = &mut self.remaining {
161            *remaining -= n;
162        }
163
164        Ok(buf)
165    }
166}
167
168fn check_position_read_size(expected: usize, actual: usize) -> Result<()> {
169    if actual > expected {
170        return Err(
171            Error::new(ErrorKind::Unexpected, "reader got unexpected data size")
172                .with_context("expect", expected)
173                .with_context("actual", actual),
174        );
175    }
176
177    Ok(())
178}
179
// Unit tests for `PositionReader` driven by an in-memory `PositionRead` fake.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use bytes::Bytes;

    use super::*;
    use crate::raw::oio::Read;
    use crate::raw::oio::ReadStream;

    /// In-memory [`PositionRead`] that serves slices of `content`, returns at
    /// most `max_read` bytes per call, and records every call it receives.
    struct TestPositionRead {
        /// Bytes served by `read_at`.
        content: Bytes,
        /// Upper bound on the bytes returned per call, used to force short
        /// reads.
        max_read: usize,
        /// `(offset, size)` of each `read_at` call, in call order.
        calls: Arc<Mutex<Vec<(u64, usize)>>>,
    }

    impl TestPositionRead {
        /// Create a fake over `content` that returns at most `max_read` bytes
        /// per call.
        fn new(content: &'static [u8], max_read: usize) -> Self {
            Self {
                content: Bytes::from_static(content),
                max_read,
                calls: Arc::default(),
            }
        }
    }

    impl PositionRead for TestPositionRead {
        async fn read_at(&self, offset: u64, size: usize) -> Result<Buffer> {
            // Record the request before serving it so tests can assert on the
            // exact call sequence.
            self.calls.lock().unwrap().push((offset, size));

            let offset = offset as usize;
            // Reading at or past the end yields an empty buffer (EOF).
            if offset >= self.content.len() {
                return Ok(Buffer::new());
            }

            // Clamp to the requested size, the per-call cap, and the bytes left.
            let end = offset + size.min(self.max_read).min(self.content.len() - offset);
            Ok(Buffer::from(self.content.slice(offset..end)))
        }
    }

    /// `read` keeps issuing `read_at` calls after short reads until the whole
    /// bounded range is assembled.
    #[tokio::test]
    async fn test_position_reader_read_handles_partial_reads() -> Result<()> {
        let inner = TestPositionRead::new(b"0123456789", 2);
        let calls = inner.calls.clone();
        let reader = PositionReader::new(inner).with_max_buf_size(4);

        let (_, buf) = reader.read(BytesRange::from(2..8)).await?;

        assert_eq!(buf.to_vec(), b"234567");
        // Each request asks for min(remaining, 4) bytes but only 2 come back.
        assert_eq!(calls.lock().unwrap().as_slice(), &[(2, 4), (4, 4), (6, 2)]);

        Ok(())
    }

    /// `read` fails with `RangeNotSatisfied` when the range extends past the
    /// end of the content.
    #[tokio::test]
    async fn test_position_reader_read_reports_early_eof() -> Result<()> {
        let reader =
            PositionReader::new(TestPositionRead::new(b"0123456789", 4)).with_max_buf_size(4);

        let err = reader.read(BytesRange::from(8..12)).await.unwrap_err();

        assert_eq!(err.kind(), ErrorKind::RangeNotSatisfied);
        Ok(())
    }

    /// A stream over an unbounded range ends cleanly once the inner reader
    /// reports EOF.
    #[tokio::test]
    async fn test_position_reader_open_stops_at_eof() -> Result<()> {
        let reader =
            PositionReader::new(TestPositionRead::new(b"0123456789", 2)).with_max_buf_size(4);
        let (_, mut stream) = reader.open(BytesRange::from(8..)).await?;

        let buf = stream.read_all().await?;

        assert_eq!(buf.to_vec(), b"89");
        Ok(())
    }
}