1use std::sync::Arc;
19
20use futures::select;
21use futures::Future;
22use futures::FutureExt;
23use futures::TryFutureExt;
24use uuid::Uuid;
25
26use crate::raw::*;
27use crate::*;
28
/// Backend interface used by [`BlockWriter`] to upload data either in a
/// single request or as a set of independently identified blocks that are
/// committed at the end.
///
/// [`BlockWriter`] shares one instance (behind an `Arc`) with all of its
/// concurrent block-upload tasks, hence the `Send + Sync + Unpin + 'static`
/// bounds.
pub trait BlockWrite: Send + Sync + Unpin + 'static {
    /// Write the whole object in a single request.
    ///
    /// [`BlockWriter`] calls this from `close` only when no block upload was
    /// started, i.e. when at most one chunk was written. `size` is the length
    /// of `body`; if nothing was written at all, `body` is an empty buffer and
    /// `size` is 0.
    fn write_once(
        &self,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// Upload a single block identified by `block_id`.
    ///
    /// `size` is the length of `body`. [`BlockWriter`] runs these calls as
    /// tasks in a concurrent task pool, so several may be in flight at once.
    /// If the executor has a timeout configured and it elapses first, the call
    /// is dropped and reported as a temporary `ErrorKind::Unexpected` error.
    ///
    /// NOTE(review): each task hands its input back together with the result,
    /// presumably so a failed block can be re-submitted with the same
    /// `block_id`; implementations should tolerate that — TODO confirm.
    fn write_block(
        &self,
        block_id: Uuid,
        size: u64,
        body: Buffer,
    ) -> impl Future<Output = Result<()>> + MaybeSend;

    /// Commit the blocks listed in `block_ids`, finishing the write.
    ///
    /// [`BlockWriter`] calls this at the end of `close`, after draining its
    /// task pool, with the ids in the order the pool yielded them.
    fn complete_block(
        &self,
        block_ids: Vec<Uuid>,
    ) -> impl Future<Output = Result<Metadata>> + MaybeSend;

    /// Abort an in-progress block upload.
    ///
    /// [`BlockWriter`] calls this from `abort` after clearing its task pool.
    /// NOTE(review): the ids it passes are only collected inside `close`, so
    /// an abort issued before `close` passes an empty list.
    fn abort_block(&self, block_ids: Vec<Uuid>) -> impl Future<Output = Result<()>> + MaybeSend;
}
95
/// Everything one upload task needs to send a single block through
/// [`BlockWrite::write_block`].
///
/// The task returns this value alongside its result — presumably so the task
/// pool can re-submit a failed block; TODO confirm against `ConcurrentTasks`.
struct WriteInput<W: BlockWrite> {
    /// Shared handle to the backend block writer.
    w: Arc<W>,
    /// Executor whose optional timeout bounds the `write_block` call.
    executor: Executor,
    /// Identifier of this block; yielded by the task on success.
    block_id: Uuid,
    /// Block payload.
    bytes: Buffer,
}
102
/// Writer that turns a sequence of `write` calls into either one
/// [`BlockWrite::write_once`] request or a series of concurrently uploaded
/// blocks finalized with [`BlockWrite::complete_block`].
///
/// The most recently written chunk is always held back in `cache`. If `close`
/// is reached while only one chunk has ever been written, that chunk is sent
/// with `write_once`; otherwise every chunk becomes its own block.
pub struct BlockWriter<W: BlockWrite> {
    /// Backend writer, shared with every upload task.
    w: Arc<W>,
    /// Executor handed to the task pool and to each `WriteInput` (for its timeout).
    executor: Executor,

    /// Set once the writer has switched to multi-block upload; `close` then
    /// skips `write_once`.
    started: bool,
    /// Ids of finished blocks, collected while draining the task pool in
    /// `close` and passed to `complete_block` / `abort_block`.
    block_ids: Vec<Uuid>,
    /// Chunk held back until the next `write` or `close`.
    cache: Option<Buffer>,
    /// Pool of block-upload tasks; each yields its block id on success.
    tasks: ConcurrentTasks<WriteInput<W>, Uuid>,
}
114
115impl<W: BlockWrite> BlockWriter<W> {
116 pub fn new(info: Arc<AccessorInfo>, inner: W, concurrent: usize) -> Self {
118 let executor = info.executor();
119
120 Self {
121 w: Arc::new(inner),
122 executor: executor.clone(),
123 started: false,
124 block_ids: Vec::new(),
125 cache: None,
126
127 tasks: ConcurrentTasks::new(executor, concurrent, |input| {
128 Box::pin(async move {
129 let fut = input
130 .w
131 .write_block(
132 input.block_id,
133 input.bytes.len() as u64,
134 input.bytes.clone(),
135 )
136 .map_ok(|_| input.block_id);
137 match input.executor.timeout() {
138 None => {
139 let result = fut.await;
140 (input, result)
141 }
142 Some(timeout) => {
143 let result = select! {
144 result = fut.fuse() => {
145 result
146 }
147 _ = timeout.fuse() => {
148 Err(Error::new(
149 ErrorKind::Unexpected, "write block timeout")
150 .with_context("block_id", input.block_id.to_string())
151 .set_temporary())
152 }
153 };
154 (input, result)
155 }
156 }
157 })
158 }),
159 }
160 }
161
162 fn fill_cache(&mut self, bs: Buffer) -> usize {
163 let size = bs.len();
164 assert!(self.cache.is_none());
165 self.cache = Some(bs);
166 size
167 }
168}
169
170impl<W> oio::Write for BlockWriter<W>
171where
172 W: BlockWrite,
173{
174 async fn write(&mut self, bs: Buffer) -> Result<()> {
175 if !self.started && self.cache.is_none() {
176 self.fill_cache(bs);
177 return Ok(());
178 }
179
180 self.started = true;
182
183 let bytes = self.cache.clone().expect("pending write must exist");
184 self.tasks
185 .execute(WriteInput {
186 w: self.w.clone(),
187 executor: self.executor.clone(),
188 block_id: Uuid::new_v4(),
189 bytes,
190 })
191 .await?;
192 self.cache = None;
193 self.fill_cache(bs);
194 Ok(())
195 }
196
197 async fn close(&mut self) -> Result<Metadata> {
198 if !self.started {
199 let (size, body) = match self.cache.clone() {
200 Some(cache) => (cache.len(), cache),
201 None => (0, Buffer::new()),
202 };
203
204 let meta = self.w.write_once(size as u64, body).await?;
205 self.cache = None;
206 return Ok(meta);
207 }
208
209 if let Some(cache) = self.cache.clone() {
210 self.tasks
211 .execute(WriteInput {
212 w: self.w.clone(),
213 executor: self.executor.clone(),
214 block_id: Uuid::new_v4(),
215 bytes: cache,
216 })
217 .await?;
218 self.cache = None;
219 }
220
221 loop {
222 let Some(result) = self.tasks.next().await.transpose()? else {
223 break;
224 };
225 self.block_ids.push(result);
226 }
227
228 let block_ids = self.block_ids.clone();
229 self.w.complete_block(block_ids).await
230 }
231
232 async fn abort(&mut self) -> Result<()> {
233 if !self.started {
234 return Ok(());
235 }
236
237 self.tasks.clear();
238 self.cache = None;
239 self.w.abort_block(self.block_ids.clone()).await?;
240 Ok(())
241 }
242}
243
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Mutex;
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::thread_rng;
    use rand::Rng;
    use rand::RngCore;
    use tokio::time::sleep;

    use super::*;
    use crate::raw::oio::Write;

    /// In-memory `BlockWrite` backend that records what it receives.
    ///
    /// `write_once` and `write_block` randomly fail (about 1 in 10 calls) with
    /// a temporary error so the writer's retry paths get exercised.
    struct TestWrite {
        length: u64,
        bytes: HashMap<Uuid, Buffer>,
        content: Option<Buffer>,
    }

    impl TestWrite {
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self {
                length: 0,
                bytes: HashMap::new(),
                content: None,
            }))
        }
    }

    /// Returns `true` roughly one time in ten; used to inject failures.
    fn monkey_strikes() -> bool {
        thread_rng().gen_bool(1.0 / 10.0)
    }

    /// The temporary error returned by an injected failure.
    fn monkey_error() -> Error {
        Error::new(ErrorKind::Unexpected, "I'm a crazy monkey!").set_temporary()
    }

    impl BlockWrite for Arc<Mutex<TestWrite>> {
        async fn write_once(&self, size: u64, body: Buffer) -> Result<Metadata> {
            sleep(Duration::from_nanos(50)).await;
            if monkey_strikes() {
                return Err(monkey_error());
            }

            let mut state = self.lock().unwrap();
            state.length = size;
            state.content = Some(body);
            Ok(Metadata::default())
        }

        async fn write_block(&self, block_id: Uuid, size: u64, body: Buffer) -> Result<()> {
            sleep(Duration::from_millis(50)).await;
            if monkey_strikes() {
                return Err(monkey_error());
            }

            let mut state = self.lock().unwrap();
            state.length += size;
            state.bytes.insert(block_id, body);
            Ok(())
        }

        async fn complete_block(&self, block_ids: Vec<Uuid>) -> Result<Metadata> {
            let mut state = self.lock().unwrap();
            // Concatenate the stored blocks in the order they are committed.
            let merged: Buffer = block_ids
                .iter()
                .flat_map(|id| state.bytes[id].clone())
                .collect();
            state.content = Some(merged);
            Ok(Metadata::default())
        }

        async fn abort_block(&self, _: Vec<Uuid>) -> Result<()> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_block_writer_with_concurrent_errors() {
        let mut rng = thread_rng();

        let mut writer = BlockWriter::new(Arc::default(), TestWrite::new(), 8);
        let mut expected_len = 0u64;
        let mut expected = Vec::new();

        for _ in 0..1000 {
            let len = rng.gen_range(1..1024);
            expected_len += len as u64;

            let mut chunk = vec![0; len];
            rng.fill_bytes(&mut chunk);
            expected.extend_from_slice(&chunk);

            // Injected failures are temporary: retry until accepted.
            while writer.write(chunk.clone().into()).await.is_err() {}
        }

        while writer.close().await.is_err() {}

        let inner = writer.w.lock().unwrap();

        assert_eq!(expected_len, inner.length, "length must be the same");
        assert!(inner.content.is_some());
        assert_eq!(
            expected,
            inner.content.clone().unwrap().to_bytes(),
            "content must be the same"
        );
    }

    #[tokio::test]
    async fn test_block_writer_with_retry_when_write_once_error() {
        let mut rng = thread_rng();

        for _ in 1..100 {
            let mut writer = BlockWriter::new(Arc::default(), TestWrite::new(), 8);

            let len = rng.gen_range(1..1024);
            let mut data = vec![0; len];
            rng.fill_bytes(&mut data);

            while writer.write(data.clone().into()).await.is_err() {}
            while writer.close().await.is_err() {}

            let inner = writer.w.lock().unwrap();
            assert_eq!(len as u64, inner.length, "length must be the same");
            assert!(inner.content.is_some());
            assert_eq!(
                data,
                inner.content.clone().unwrap().to_bytes(),
                "content must be the same"
            );
        }
    }
}