opendal/raw/oio/write/
position_write.rs1use std::sync::Arc;
19
20use futures::select;
21use futures::Future;
22use futures::FutureExt;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait PositionWrite: Send + Sync + Unpin + 'static {
47 fn write_all_at(
49 &self,
50 offset: u64,
51 buf: Buffer,
52 ) -> impl Future<Output = Result<()>> + MaybeSend;
53
54 fn close(&self) -> impl Future<Output = Result<Metadata>> + MaybeSend;
56
57 fn abort(&self) -> impl Future<Output = Result<()>> + MaybeSend;
59}
60
61struct WriteInput<W: PositionWrite> {
62 w: Arc<W>,
63 executor: Executor,
64
65 offset: u64,
66 bytes: Buffer,
67}
68
69pub struct PositionWriter<W: PositionWrite> {
71 w: Arc<W>,
72 executor: Executor,
73
74 next_offset: u64,
75 cache: Option<Buffer>,
76 tasks: ConcurrentTasks<WriteInput<W>, ()>,
77}
78
79#[allow(dead_code)]
80impl<W: PositionWrite> PositionWriter<W> {
81 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
83 let executor = info.executor();
84
85 Self {
86 w: Arc::new(inner),
87 executor: executor.clone(),
88 next_offset: 0,
89 cache: None,
90
91 tasks: ConcurrentTasks::new(executor, concurrent, |input| {
92 Box::pin(async move {
93 let fut = input.w.write_all_at(input.offset, input.bytes.clone());
94 match input.executor.timeout() {
95 None => {
96 let result = fut.await;
97 (input, result)
98 }
99 Some(timeout) => {
100 let result = select! {
101 result = fut.fuse() => {
102 result
103 }
104 _ = timeout.fuse() => {
105 Err(Error::new(
106 ErrorKind::Unexpected, "write position timeout")
107 .with_context("offset", input.offset.to_string())
108 .set_temporary())
109 }
110 };
111 (input, result)
112 }
113 }
114 })
115 }),
116 }
117 }
118
119 fn fill_cache(&mut self, bs: Buffer) -> usize {
120 let size = bs.len();
121 assert!(self.cache.is_none());
122 self.cache = Some(bs);
123 size
124 }
125}
126
127impl<W: PositionWrite> oio::Write for PositionWriter<W> {
128 async fn write(&mut self, bs: Buffer) -> Result<()> {
129 if self.cache.is_none() {
130 let _ = self.fill_cache(bs);
131 return Ok(());
132 }
133
134 let bytes = self.cache.clone().expect("pending write must exist");
135 let length = bytes.len() as u64;
136 let offset = self.next_offset;
137
138 self.tasks
139 .execute(WriteInput {
140 w: self.w.clone(),
141 executor: self.executor.clone(),
142 offset,
143 bytes,
144 })
145 .await?;
146 self.cache = None;
147 self.next_offset += length;
148 let _ = self.fill_cache(bs);
149 Ok(())
150 }
151
152 async fn close(&mut self) -> Result<Metadata> {
153 while self.tasks.next().await.transpose()?.is_some() {}
155
156 if let Some(buffer) = self.cache.clone() {
157 let offset = self.next_offset;
158 self.w.write_all_at(offset, buffer).await?;
159 self.cache = None;
160 }
161 self.w.close().await
162 }
163
164 async fn abort(&mut self) -> Result<()> {
165 self.tasks.clear();
166 self.cache = None;
167 self.w.abort().await?;
168 Ok(())
169 }
170}
171
172#[cfg(test)]
173mod tests {
174 use std::collections::HashSet;
175 use std::sync::Mutex;
176 use std::time::Duration;
177
178 use pretty_assertions::assert_eq;
179 use rand::thread_rng;
180 use rand::Rng;
181 use rand::RngCore;
182 use tokio::time::sleep;
183
184 use super::*;
185 use crate::raw::oio::Write;
186
187 struct TestWrite {
188 length: u64,
189 bytes: HashSet<u64>,
190 }
191
192 impl TestWrite {
193 pub fn new() -> Arc<Mutex<Self>> {
194 let v = Self {
195 bytes: HashSet::new(),
196 length: 0,
197 };
198
199 Arc::new(Mutex::new(v))
200 }
201 }
202
203 impl PositionWrite for Arc<Mutex<TestWrite>> {
204 async fn write_all_at(&self, offset: u64, buf: Buffer) -> Result<()> {
205 sleep(Duration::from_millis(50)).await;
207
208 if thread_rng().gen_bool(1.0 / 10.0) {
210 return Err(
211 Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
212 );
213 }
214
215 let mut test = self.lock().unwrap();
216 let size = buf.len() as u64;
217 test.length += size;
218
219 let input = (offset..offset + size).collect::<HashSet<_>>();
220
221 assert!(
222 test.bytes.is_disjoint(&input),
223 "input should not have overlap"
224 );
225 test.bytes.extend(input);
226
227 Ok(())
228 }
229
230 async fn close(&self) -> Result<Metadata> {
231 Ok(Metadata::default())
232 }
233
234 async fn abort(&self) -> Result<()> {
235 Ok(())
236 }
237 }
238
239 #[tokio::test]
240 async fn test_position_writer_with_concurrent_errors() {
241 let mut rng = thread_rng();
242
243 let mut w = PositionWriter::new(Arc::default(), TestWrite::new(), 200);
244 let mut total_size = 0u64;
245
246 for _ in 0..1000 {
247 let size = rng.gen_range(1..1024);
248 total_size += size as u64;
249
250 let mut bs = vec![0; size];
251 rng.fill_bytes(&mut bs);
252
253 loop {
254 match w.write(bs.clone().into()).await {
255 Ok(_) => break,
256 Err(e) => {
257 println!("write error: {:?}", e);
258 continue;
259 }
260 }
261 }
262 }
263
264 loop {
265 match w.close().await {
266 Ok(n) => {
267 println!("close: {:?}", n);
268 break;
269 }
270 Err(e) => {
271 println!("close error: {:?}", e);
272 continue;
273 }
274 }
275 }
276
277 let actual_bytes = w.w.lock().unwrap().bytes.clone();
278 let expected_bytes: HashSet<_> = (0..total_size).collect();
279 assert_eq!(actual_bytes, expected_bytes);
280
281 let actual_size = w.w.lock().unwrap().length;
282 assert_eq!(actual_size, total_size);
283 }
284}