opendal/raw/
futures_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::VecDeque;
19use std::sync::atomic::AtomicUsize;
20use std::sync::atomic::Ordering;
21use std::sync::Arc;
22
23use futures::FutureExt;
24
25use crate::*;
26
27/// BoxedFuture is the type alias of [`futures::future::BoxFuture`].
28///
29/// We will switch to [`futures::future::LocalBoxFuture`] on wasm32 target.
30#[cfg(not(target_arch = "wasm32"))]
31pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
32#[cfg(target_arch = "wasm32")]
33pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;
34
35/// BoxedStaticFuture is the type alias of [`futures::future::BoxFuture`].
36///
37/// We will switch to [`futures::future::LocalBoxFuture`] on wasm32 target.
38#[cfg(not(target_arch = "wasm32"))]
39pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
40#[cfg(target_arch = "wasm32")]
41pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;
42
/// MaybeSend is a marker to determine whether a type is `Send` or not.
/// We use this trait to wrap the `Send` requirement for wasm32 target.
///
/// # Platform behavior
///
/// - On non-wasm32 targets, [`MaybeSend`] is equivalent to `Send`: it has `Send` as a supertrait
///   and is implemented for every `Send` type.
/// - On wasm32 targets, it is an empty trait implemented for every type, so a `MaybeSend` bound
///   imposes no `Send` requirement there.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}

#[cfg(not(target_arch = "wasm32"))]
impl<T: Send> MaybeSend for T {}
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}
59
60/// ConcurrentTasks is used to execute tasks concurrently.
61///
62/// ConcurrentTasks has two generic types:
63///
64/// - `I` represents the input type of the task.
65/// - `O` represents the output type of the task.
66///
67/// # Implementation Notes
68///
69/// The code patterns below are intentional; please do not modify them unless you fully understand these notes.
70///
71/// ```skip
72///  let (i, o) = self
73///     .tasks
74///     .front_mut()                                        // Use `front_mut` instead of `pop_front`
75///     .expect("tasks must be available")
76///     .await;
77/// ...
78/// match o {
79///     Ok(o) => {
80///         let _ = self.tasks.pop_front();                 // `pop_front` after got `Ok(o)`
81///         self.results.push_back(o)
82///     }
83///     Err(err) => {
84///         if err.is_temporary() {
85///             let task = self.create_task(i);
86///             self.tasks
87///                 .front_mut()
88///                 .expect("tasks must be available")
89///                 .replace(task)                          // Use replace here to instead of `push_front`
90///         } else {
91///             self.clear();
92///             self.errored = true;
93///         }
94///         return Err(err);
95///     }
96/// }
97/// ```
98///
99/// Please keep in mind that there is no guarantee the task will be `await`ed until completion. It's possible
100/// the task may be dropped before it resolves. Therefore, we should keep the `Task` in the `tasks` queue until
101/// it is resolved.
102///
103/// For example, users may have a timeout for the task, and the task will be dropped if it exceeds the timeout.
104/// If we `pop_front` the task before it resolves, the task will be canceled and the result will be lost.
105pub struct ConcurrentTasks<I, O> {
106    /// The executor to execute the tasks.
107    ///
108    /// If user doesn't provide an executor, the tasks will be executed with the default executor.
109    executor: Executor,
110    /// The factory to create the task.
111    ///
112    /// Caller of ConcurrentTasks must provides a factory to create the task for executing.
113    ///
114    /// The factory must accept an input and return a future that resolves to a tuple of input and
115    /// output result. If the given result is error, the error will be returned to users and the
116    /// task will be retried.
117    factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
118
119    /// `tasks` holds the ongoing tasks.
120    ///
121    /// Please keep in mind that all tasks are running in the background by `Executor`. We only need
122    /// to poll the tasks to see if they are ready.
123    ///
124    /// Dropping task without `await` it will cancel the task.
125    tasks: VecDeque<Task<(I, Result<O>)>>,
126    /// `results` stores the successful results.
127    results: VecDeque<O>,
128
129    /// The maximum number of concurrent tasks.
130    concurrent: usize,
131    /// Tracks the number of tasks that have finished execution but have not yet been collected.
132    /// This count is subtracted from the total concurrency capacity, ensuring that the system
133    /// always schedules new tasks to maintain the user's desired concurrency level.
134    ///
135    /// Example: If `concurrency = 10` and `completed_but_unretrieved = 3`,
136    ///          the system can still spawn 7 new tasks (since 3 slots are "logically occupied"
137    ///          by uncollected results).
138    completed_but_unretrieved: Arc<AtomicUsize>,
139    /// hitting the last unrecoverable error.
140    ///
141    /// If concurrent tasks hit an unrecoverable error, it will stop executing new tasks and return
142    /// an unrecoverable error to users.
143    errored: bool,
144}
145
146impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
147    /// Create a new concurrent tasks with given executor, concurrent and factory.
148    ///
149    /// The factory is a function pointer that shouldn't capture any context.
150    pub fn new(
151        executor: Executor,
152        concurrent: usize,
153        factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
154    ) -> Self {
155        Self {
156            executor,
157            factory,
158
159            tasks: VecDeque::with_capacity(concurrent),
160            results: VecDeque::with_capacity(concurrent),
161            concurrent,
162            completed_but_unretrieved: Arc::default(),
163            errored: false,
164        }
165    }
166
167    /// Return true if the tasks are running concurrently.
168    #[inline]
169    fn is_concurrent(&self) -> bool {
170        self.concurrent > 1
171    }
172
173    /// Clear all tasks and results.
174    ///
175    /// All ongoing tasks will be canceled.
176    pub fn clear(&mut self) {
177        self.tasks.clear();
178        self.results.clear();
179    }
180
181    /// Check if there are remaining space to push new tasks.
182    #[inline]
183    pub fn has_remaining(&self) -> bool {
184        self.tasks.len() < self.concurrent + self.completed_but_unretrieved.load(Ordering::Relaxed)
185    }
186
187    /// Chunk if there are remaining results to fetch.
188    #[inline]
189    pub fn has_result(&self) -> bool {
190        !self.results.is_empty()
191    }
192
193    /// Create a task with given input.
194    pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
195        let completed = self.completed_but_unretrieved.clone();
196
197        let fut = (self.factory)(input).inspect(move |_| {
198            completed.fetch_add(1, Ordering::Relaxed);
199        });
200
201        self.executor.execute(fut)
202    }
203
204    /// Execute the task with given input.
205    ///
206    /// - Execute the task in the current thread if is not concurrent.
207    /// - Execute the task in the background if there are available slots.
208    /// - Await the first task in the queue if there is no available slots.
209    pub async fn execute(&mut self, input: I) -> Result<()> {
210        if self.errored {
211            return Err(Error::new(
212                ErrorKind::Unexpected,
213                "concurrent tasks met an unrecoverable error",
214            ));
215        }
216
217        // Short path for non-concurrent case.
218        if !self.is_concurrent() {
219            let (_, o) = (self.factory)(input).await;
220            return match o {
221                Ok(o) => {
222                    self.results.push_back(o);
223                    Ok(())
224                }
225                // We don't need to rebuild the future if it's not concurrent.
226                Err(err) => Err(err),
227            };
228        }
229
230        if !self.has_remaining() {
231            let (i, o) = self
232                .tasks
233                .front_mut()
234                .expect("tasks must be available")
235                .await;
236            self.completed_but_unretrieved
237                .fetch_sub(1, Ordering::Relaxed);
238            match o {
239                Ok(o) => {
240                    let _ = self.tasks.pop_front();
241                    self.results.push_back(o)
242                }
243                Err(err) => {
244                    // Retry this task if the error is temporary
245                    if err.is_temporary() {
246                        let task = self.create_task(i);
247                        self.tasks
248                            .front_mut()
249                            .expect("tasks must be available")
250                            .replace(task)
251                    } else {
252                        self.clear();
253                        self.errored = true;
254                    }
255                    return Err(err);
256                }
257            }
258        }
259
260        self.tasks.push_back(self.create_task(input));
261        Ok(())
262    }
263
264    /// Fetch the successful result from the result queue.
265    pub async fn next(&mut self) -> Option<Result<O>> {
266        if self.errored {
267            return Some(Err(Error::new(
268                ErrorKind::Unexpected,
269                "concurrent tasks met an unrecoverable error",
270            )));
271        }
272
273        if let Some(result) = self.results.pop_front() {
274            return Some(Ok(result));
275        }
276
277        if let Some(task) = self.tasks.front_mut() {
278            let (i, o) = task.await;
279            self.completed_but_unretrieved
280                .fetch_sub(1, Ordering::Relaxed);
281            return match o {
282                Ok(o) => {
283                    let _ = self.tasks.pop_front();
284                    Some(Ok(o))
285                }
286                Err(err) => {
287                    // Retry this task if the error is temporary
288                    if err.is_temporary() {
289                        let task = self.create_task(i);
290                        self.tasks
291                            .front_mut()
292                            .expect("tasks must be available")
293                            .replace(task)
294                    } else {
295                        self.clear();
296                        self.errored = true;
297                    }
298                    Some(Err(err))
299                }
300            };
301        }
302
303        None
304    }
305}
306
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::Rng;
    use tokio::time::sleep;

    use super::*;

    #[tokio::test]
    async fn test_concurrent_tasks() {
        let executor = Executor::new();

        let mut tasks = ConcurrentTasks::new(executor, 16, |(i, dur)| {
            Box::pin(async move {
                sleep(dur).await;

                // 5% rate to fail with a temporary (retryable) error.
                if rand::thread_rng().gen_range(0..100) < 5 {
                    return (
                        (i, dur),
                        Err(Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary()),
                    );
                }
                ((i, dur), Ok(i))
            })
        });

        for i in 0..10240 {
            // Sleep for less than 10ms.
            let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
            // When `execute` reports a failed task it does not queue this input, so keep
            // resubmitting it until it's accepted.
            loop {
                let res = tasks.execute((i, dur)).await;
                if res.is_ok() {
                    break;
                }
            }
        }

        let mut ans = vec![];
        loop {
            match tasks.next().await.transpose() {
                Ok(Some(i)) => ans.push(i),
                Ok(None) => break,
                // Temporary failures are retried in place; keep draining.
                Err(_) => continue,
            }
        }

        assert_eq!(ans, (0..10240).collect::<Vec<_>>())
    }
}