opendal_core/raw/oio/write/
position_write.rs1use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait PositionWrite: Send + Sync + Unpin + 'static {
47 fn write_all_at(
49 &self,
50 offset: u64,
51 buf: Buffer,
52 ) -> impl Future<Output = Result<()>> + MaybeSend;
53
54 fn close(&self, size: u64) -> impl Future<Output = Result<Metadata>> + MaybeSend;
56
57 fn abort(&self) -> impl Future<Output = Result<()>> + MaybeSend;
59}
60
/// One unit of work for the writer's task pool: a chunk to be written at a fixed offset.
///
/// The task factory in `PositionWriter::new` hands this value back together with the write
/// result. Presumably this lets `ConcurrentTasks` resubmit a failed chunk unchanged (TODO
/// confirm against `ConcurrentTasks`).
struct WriteInput<W: PositionWrite> {
    /// Shared handle to the underlying positional writer.
    w: Arc<W>,
    /// Executor whose `timeout()`, when present, bounds this write.
    executor: Executor,

    /// Byte offset at which `bytes` is written.
    offset: u64,
    /// The chunk's data; cloned for each attempt because the input is handed back afterwards.
    bytes: Buffer,
}
68
/// Adapts a [`PositionWrite`] into an `oio::Write`.
///
/// Incoming chunks get consecutive offsets and are written concurrently through a task pool.
pub struct PositionWriter<W: PositionWrite> {
    /// Shared handle to the underlying positional writer; cloned into every submitted task.
    w: Arc<W>,
    /// Executor used to run the task pool and to obtain per-write timeouts.
    executor: Executor,

    /// Offset for the next chunk to be submitted, i.e. total bytes accepted by the pool so far.
    next_offset: u64,
    /// The most recently received chunk, held back until the next `write` or `close`.
    cache: Option<Buffer>,
    /// Pool of in-flight positional writes.
    tasks: ConcurrentTasks<WriteInput<W>, ()>,
}
78
79#[allow(dead_code)]
80impl<W: PositionWrite> PositionWriter<W> {
81 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
83 let executor = info.executor();
84
85 Self {
86 w: Arc::new(inner),
87 executor: executor.clone(),
88 next_offset: 0,
89 cache: None,
90
91 tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
92 Box::pin(async move {
93 let fut = input.w.write_all_at(input.offset, input.bytes.clone());
94 match input.executor.timeout() {
95 None => {
96 let result = fut.await;
97 (input, result)
98 }
99 Some(timeout) => {
100 let result = select! {
101 result = fut.fuse() => {
102 result
103 }
104 _ = timeout.fuse() => {
105 Err(Error::new(
106 ErrorKind::Unexpected, "write position timeout")
107 .with_context("offset", input.offset.to_string())
108 .set_temporary())
109 }
110 };
111 (input, result)
112 }
113 }
114 })
115 }),
116 }
117 }
118
119 fn fill_cache(&mut self, bs: Buffer) -> usize {
120 let size = bs.len();
121 assert!(self.cache.is_none());
122 self.cache = Some(bs);
123 size
124 }
125}
126
127impl<W: PositionWrite> oio::Write for PositionWriter<W> {
128 async fn write(&mut self, bs: Buffer) -> Result<()> {
129 if self.cache.is_none() {
130 let _ = self.fill_cache(bs);
131 return Ok(());
132 }
133
134 let bytes = self.cache.clone().expect("pending write must exist");
135 let length = bytes.len() as u64;
136 let offset = self.next_offset;
137
138 self.tasks
139 .execute(WriteInput {
140 w: self.w.clone(),
141 executor: self.executor.clone(),
142 offset,
143 bytes,
144 })
145 .await?;
146 self.cache = None;
147 self.next_offset += length;
148 let _ = self.fill_cache(bs);
149 Ok(())
150 }
151
152 async fn close(&mut self) -> Result<Metadata> {
153 while self.tasks.next().await.transpose()?.is_some() {}
155
156 if let Some(buffer) = self.cache.clone() {
157 let offset = self.next_offset;
158 self.w.write_all_at(offset, buffer.clone()).await?;
159 self.cache = None;
160 self.next_offset += buffer.len() as u64;
161 }
162 let final_size = self.next_offset;
163 self.w.close(final_size).await
164 }
165
166 async fn abort(&mut self) -> Result<()> {
167 self.tasks.clear();
168 self.cache = None;
169 self.w.abort().await?;
170 Ok(())
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Mutex;

    use pretty_assertions::assert_eq;
    use rand::{Rng, RngExt, rng};
    use tokio::time::sleep;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory positional target that records which byte offsets have been written.
    struct TestWrite {
        /// Total number of bytes accepted across all successful writes.
        length: u64,
        /// Every byte offset written so far; used to detect overlapping writes.
        bytes: HashSet<u64>,
    }

    impl TestWrite {
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self {
                bytes: HashSet::new(),
                length: 0,
            }))
        }
    }

    impl PositionWrite for Arc<Mutex<TestWrite>> {
        async fn write_all_at(&self, offset: u64, buf: Buffer) -> Result<()> {
            // Keep each write pending long enough for many of them to overlap.
            sleep(Duration::from_millis(50)).await;

            // Roughly one write in ten fails with a retryable error.
            if rng().random_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let size = buf.len() as u64;
            let span: HashSet<u64> = (offset..offset + size).collect();

            let mut state = self.lock().unwrap();
            state.length += size;
            assert!(
                state.bytes.is_disjoint(&span),
                "input should not have overlap"
            );
            state.bytes.extend(span);

            Ok(())
        }

        async fn close(&self, _size: u64) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&self) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_position_writer_with_concurrent_errors() {
        let mut source = rng();

        let mut w = PositionWriter::new(Arc::default(), TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = source.random_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            source.fill_bytes(&mut bs);

            // A failed write leaves the chunk unaccepted, so keep resubmitting the same bytes.
            while let Err(e) = w.write(bs.clone().into()).await {
                println!("write error: {e:?}");
            }
        }

        let n = loop {
            match w.close().await {
                Ok(n) => break n,
                Err(e) => println!("close error: {e:?}"),
            }
        };
        println!("close: {n:?}");

        let target = w.w.lock().unwrap();
        let expected_bytes: HashSet<_> = (0..total_size).collect();
        assert_eq!(target.bytes, expected_bytes);
        assert_eq!(target.length, total_size);
    }
}