opendal_core/raw/oio/write/
position_write.rs1use std::sync::Arc;
19
20use futures::Future;
21use futures::FutureExt;
22use futures::select;
23
24use crate::raw::*;
25use crate::*;
26
27pub trait PositionWrite: Send + Sync + Unpin + 'static {
47 fn write_all_at(
49 &self,
50 offset: u64,
51 buf: Buffer,
52 ) -> impl Future<Output = Result<()>> + MaybeSend;
53
54 fn close(&self, size: u64) -> impl Future<Output = Result<Metadata>> + MaybeSend;
56
57 fn abort(&self) -> impl Future<Output = Result<()>> + MaybeSend;
59}
60
61struct WriteInput<W: PositionWrite> {
62 w: Arc<W>,
63 executor: Executor,
64
65 offset: u64,
66 bytes: Buffer,
67}
68
69pub struct PositionWriter<W: PositionWrite> {
71 w: Arc<W>,
72 executor: Executor,
73
74 next_offset: u64,
75 cache: Option<Buffer>,
76 tasks: ConcurrentTasks<WriteInput<W>, ()>,
77}
78
79#[allow(dead_code)]
80impl<W: PositionWrite> PositionWriter<W> {
81 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
83 let executor = info.executor();
84
85 Self {
86 w: Arc::new(inner),
87 executor: executor.clone(),
88 next_offset: 0,
89 cache: None,
90
91 tasks: ConcurrentTasks::new(executor, concurrent, 8192, |input| {
92 Box::pin(async move {
93 let fut = input.w.write_all_at(input.offset, input.bytes.clone());
94 match input.executor.timeout() {
95 None => {
96 let result = fut.await;
97 (input, result)
98 }
99 Some(timeout) => {
100 let result = select! {
101 result = fut.fuse() => {
102 result
103 }
104 _ = timeout.fuse() => {
105 Err(Error::new(
106 ErrorKind::Unexpected, "write position timeout")
107 .with_context("offset", input.offset.to_string())
108 .set_temporary())
109 }
110 };
111 (input, result)
112 }
113 }
114 })
115 }),
116 }
117 }
118
119 fn fill_cache(&mut self, bs: Buffer) -> usize {
120 let size = bs.len();
121 assert!(self.cache.is_none());
122 self.cache = Some(bs);
123 size
124 }
125}
126
127impl<W: PositionWrite> oio::Write for PositionWriter<W> {
128 async fn write(&mut self, bs: Buffer) -> Result<()> {
129 if self.cache.is_none() {
130 let _ = self.fill_cache(bs);
131 return Ok(());
132 }
133
134 let bytes = self.cache.clone().expect("pending write must exist");
135 let length = bytes.len() as u64;
136 let offset = self.next_offset;
137
138 self.tasks
139 .execute(WriteInput {
140 w: self.w.clone(),
141 executor: self.executor.clone(),
142 offset,
143 bytes,
144 })
145 .await?;
146 self.cache = None;
147 self.next_offset += length;
148 let _ = self.fill_cache(bs);
149 Ok(())
150 }
151
152 async fn close(&mut self) -> Result<Metadata> {
153 while self.tasks.next().await.transpose()?.is_some() {}
155
156 if let Some(buffer) = self.cache.clone() {
157 let offset = self.next_offset;
158 self.w.write_all_at(offset, buffer.clone()).await?;
159 self.cache = None;
160 self.next_offset += buffer.len() as u64;
161 }
162 let final_size = self.next_offset;
163 self.w.close(final_size).await
164 }
165
166 async fn abort(&mut self) -> Result<()> {
167 self.tasks.clear();
168 self.cache = None;
169 self.w.abort().await?;
170 Ok(())
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::Mutex;

    use pretty_assertions::assert_eq;
    use rand::Rng;
    use rand::RngCore;
    use rand::thread_rng;
    use tokio::time::sleep;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory positional writer that records every byte offset it has
    /// accepted and the total number of bytes written.
    struct TestWrite {
        length: u64,
        bytes: HashSet<u64>,
    }

    impl TestWrite {
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self {
                bytes: HashSet::new(),
                length: 0,
            }))
        }
    }

    impl PositionWrite for Arc<Mutex<TestWrite>> {
        async fn write_all_at(&self, offset: u64, buf: Buffer) -> Result<()> {
            // Simulate I/O latency so that many writes overlap in time.
            sleep(Duration::from_millis(50)).await;

            // Inject a temporary failure in roughly one call out of ten.
            let should_fail = thread_rng().gen_bool(1.0 / 10.0);
            if should_fail {
                return Err(
                    Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
                );
            }

            let len = buf.len() as u64;
            let span: HashSet<u64> = (offset..offset + len).collect();

            let mut state = self.lock().unwrap();
            state.length += len;
            assert!(state.bytes.is_disjoint(&span), "input should not have overlap");
            state.bytes.extend(span);

            Ok(())
        }

        async fn close(&self, _size: u64) -> Result<Metadata> {
            Ok(Metadata::default())
        }

        async fn abort(&self) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_position_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let mut w = PositionWriter::new(Arc::default(), TestWrite::new(), 200);
        let mut total_size = 0u64;

        for _ in 0..1000 {
            let size = rng.gen_range(1..1024);
            total_size += size as u64;

            let mut bs = vec![0; size];
            rng.fill_bytes(&mut bs);

            // A failed write leaves the writer unchanged, so retry the same chunk.
            while let Err(e) = w.write(bs.clone().into()).await {
                println!("write error: {e:?}");
            }
        }

        let meta = loop {
            match w.close().await {
                Ok(meta) => break meta,
                Err(e) => println!("close error: {e:?}"),
            }
        };
        println!("close: {meta:?}");

        let written = w.w.lock().unwrap().bytes.clone();
        let expected: HashSet<_> = (0..total_size).collect();
        assert_eq!(written, expected);

        let length = w.w.lock().unwrap().length;
        assert_eq!(length, total_size);
    }
}