opendal/raw/oio/write/
position_write.rs1use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait PositionWrite: Send + Sync + Unpin + 'static {
47 fn write_all_at(
49 &self,
50 offset: u64,
51 buf: Buffer,
52 ) -> impl Future<Output = Result<()>> + MaybeSend;
53
54 fn close(&self, size: u64) -> impl Future<Output = Result<Metadata>> + MaybeSend;
56
57 fn abort(&self) -> impl Future<Output = Result<()>> + MaybeSend;
59}
60
61struct WriteInput<W: PositionWrite> {
62 w: Arc<W>,
63 executor: Executor,
64
65 offset: u64,
66 bytes: Buffer,
67}
68
69pub struct PositionWriter<W: PositionWrite> {
71 w: Arc<W>,
72 executor: Executor,
73
74 next_offset: u64,
75 cache: Option<Buffer>,
76 tasks: ConcurrentTasks<WriteInput<W>, ()>,
77}
78
79#[allow(dead_code)]
80impl<W: PositionWrite> PositionWriter<W> {
81 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
83 let executor = info.executor();
84
85 Self {
86 w: Arc::new(inner),
87 executor: executor.clone(),
88 next_offset: 0,
89 cache: None,
90
91 tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
92 Box::pin(async move {
93 let fut = input.w.write_all_at(input.offset, input.bytes.clone());
94 match input.executor.timeout() {
95 None => {
96 let result = fut.await;
97 (input, result)
98 }
99 Some(timeout) => {
100 let result = select! {
101 result = fut.fuse() => {
102 result
103 }
104 _ = timeout.fuse() => {
105 Err(Error::new(
106 ErrorKind::Unexpected, "write position timeout")
107 .with_context("offset", input.offset.to_string())
108 .set_temporary())
109 }
110 };
111 (input, result)
112 }
113 }
114 })
115 }),
116 }
117 }
118
119 fn fill_cache(&mut self, bs: Buffer) -> usize {
120 let size = bs.len();
121 assert!(self.cache.is_none());
122 self.cache = Some(bs);
123 size
124 }
125}
126
127impl<W: PositionWrite> oio::Write for PositionWriter<W> {
128 async fn write(&mut self, bs: Buffer) -> Result<()> {
129 if self.cache.is_none() {
130 let _ = self.fill_cache(bs);
131 return Ok(());
132 }
133
134 let bytes = self.cache.clone().expect("pending write must exist");
135 let length = bytes.len() as u64;
136 let offset = self.next_offset;
137
138 self.tasks
139 .execute(WriteInput {
140 w: self.w.clone(),
141 executor: self.executor.clone(),
142 offset,
143 bytes,
144 })
145 .await?;
146 self.cache = None;
147 self.next_offset += length;
148 let _ = self.fill_cache(bs);
149 Ok(())
150 }
151
152 async fn close(&mut self) -> Result<Metadata> {
153 while self.tasks.next().await.transpose()?.is_some() {}
155
156 if let Some(buffer) = self.cache.clone() {
157 let offset = self.next_offset;
158 self.w.write_all_at(offset, buffer.clone()).await?;
159 self.cache = None;
160 self.next_offset += buffer.len() as u64;
161 }
162 let final_size = self.next_offset;
163 self.w.close(final_size).await
164 }
165
166 async fn abort(&mut self) -> Result<()> {
167 self.tasks.clear();
168 self.cache = None;
169 self.w.abort().await?;
170 Ok(())
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Mutex;
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory sink that records which byte positions have been written.
    struct TestWrite {
        /// Sum of the lengths of all successful writes.
        length: u64,
        /// Every byte offset covered by a successful write.
        bytes: HashSet<u64>,
    }

    impl TestWrite {
        /// Creates an empty sink wrapped for shared, mutable access.
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self {
                length: 0,
                bytes: HashSet::new(),
            }))
        }
    }

    impl PositionWrite for Arc<Mutex<TestWrite>> {
        /// Sleeps 50ms, fails with probability 1/10, otherwise records the
        /// written range and panics if it overlaps an earlier write.
        async fn write_all_at(&self, offset: u64, buf: Buffer) -> Result<()> {
            sleep(Duration::from_millis(50)).await;

            // Inject a temporary failure so the retry path gets exercised.
            if thread_rng().gen_bool(1.0 / 10.0) {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let size = buf.len() as u64;
            let input: HashSet<u64> = (offset..offset + size).collect();

            // The lock is taken after the only await point in this function.
            let mut state = self.lock().unwrap();
            state.length += size;
            assert!(
                state.bytes.is_disjoint(&input),
                "input should not have overlap"
            );
            state.bytes.extend(input);

            Ok(())
        }

        async fn close(&self, _size: u64) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Writes 1000 random-sized chunks through a writer whose backend fails
    /// randomly, retrying every failed call, and checks that every byte
    /// position was written exactly once.
    #[tokio::test]
    async fn test_position_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let mut w = PositionWriter::new(Arc::default(), TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // Retry the same chunk until the writer accepts it.
            while let Err(e) = w.write(bs.clone().into()).await {
                println!("write error: {e:?}");
            }
        }

        // Retry close until it succeeds.
        let n = loop {
            match w.close().await {
                Ok(n) => break n,
                Err(e) => println!("close error: {e:?}"),
            }
        };
        println!("close: {n:?}");

        let sink = w.w.lock().unwrap();
        let expected_bytes: HashSet<_> = (0..total_size).collect();
        assert_eq!(sink.bytes, expected_bytes);
        assert_eq!(sink.length, total_size);
    }
}