opendal/raw/
futures_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{
19    collections::VecDeque,
20    sync::atomic::Ordering,
21    sync::{atomic::AtomicUsize, Arc},
22};
23
24use futures::FutureExt;
25
26use crate::*;
27
/// BoxedFuture is the boxed future type used by this module, parameterized by the lifetime `'a`
/// of the data it borrows and by its output type `T`.
///
/// - On every target except wasm32 it is an alias of [`futures::future::BoxFuture`].
/// - On wasm32 it is an alias of [`futures::future::LocalBoxFuture`] instead.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
#[cfg(target_arch = "wasm32")]
pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;
35
/// BoxedStaticFuture is the `'static` flavour of the boxed future type, parameterized only by
/// its output type `T`.
///
/// - On every target except wasm32 it is an alias of [`futures::future::BoxFuture`] with the
///   `'static` lifetime.
/// - On wasm32 it is an alias of [`futures::future::LocalBoxFuture`] with the `'static` lifetime.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
#[cfg(target_arch = "wasm32")]
pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;
43
/// MaybeSend is a marker trait that stands for "`Send`, where the target requires it".
///
/// It lets generic bounds be written once and still compile for wasm32.
///
/// # Target behavior
///
/// - wasm32: [`MaybeSend`] is an empty trait implemented for every type, so it adds no
///   `Send` requirement at all.
/// - any other target: [`MaybeSend`] has `Send` as a supertrait and is implemented exactly for
///   the types that are `Send`, which makes it equivalent to `Send`.
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}

/// MaybeSend is a marker trait that stands for "`Send`, where the target requires it".
///
/// On this (non-wasm32) target it is equivalent to `Send`: it has `Send` as a supertrait and
/// is blanket-implemented for every `Send` type.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
#[cfg(not(target_arch = "wasm32"))]
impl<T> MaybeSend for T where T: Send {}
60
/// ConcurrentTasks is used to execute tasks concurrently.
///
/// ConcurrentTasks has two generic types:
///
/// - `I` represents the input type of the task.
/// - `O` represents the output type of the task.
///
/// # Implementation Notes
///
/// The code pattern below (used wherever the front task is awaited) is intentional; please do
/// not modify it unless you fully understand these notes.
///
/// ```skip
///  let (i, o) = self
///     .tasks
///     .front_mut()                                        // Use `front_mut` instead of `pop_front`
///     .expect("tasks must be available")
///     .await;
/// ...
/// match o {
///     Ok(o) => {
///         let _ = self.tasks.pop_front();                 // `pop_front` only after we got `Ok(o)`
///         self.results.push_back(o)
///     }
///     Err(err) => {
///         if err.is_temporary() {
///             let task = self.create_task(i);
///             self.tasks
///                 .front_mut()
///                 .expect("tasks must be available")
///                 .replace(task)                          // Use `replace` here instead of `push_front`
///         } else {
///             self.clear();
///             self.errored = true;
///         }
///         return Err(err);
///     }
/// }
/// ```
///
/// Please keep in mind that there is no guarantee the task will be `await`ed until completion.
/// It's possible that the future awaiting the task is dropped before the task resolves.
/// Therefore, we must keep the `Task` in the `tasks` queue until it is resolved.
///
/// For example, users may have a timeout for the task, and the awaiting future will be dropped
/// if it exceeds the timeout. If we `pop_front` the task before it resolves, the task will be
/// canceled and the result will be lost.
pub struct ConcurrentTasks<I, O> {
    /// The executor to execute the tasks.
    ///
    /// If user doesn't provide an executor, the tasks will be executed with the default executor.
    executor: Executor,
    /// The factory to create the task.
    ///
    /// The caller of ConcurrentTasks must provide a factory to create the task for executing.
    ///
    /// The factory must accept an input and return a future that resolves to a tuple of the
    /// input and the output result. If the result is an error, the error is returned to the
    /// user; if that error is temporary, the task is also rebuilt from the returned input and
    /// retried.
    factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,

    /// `tasks` holds the ongoing tasks, in submission order.
    ///
    /// Please keep in mind that all tasks are running in the background by `Executor`. We only
    /// need to poll the tasks to see if they are ready.
    ///
    /// Dropping a task without awaiting it cancels the task.
    tasks: VecDeque<Task<(I, Result<O>)>>,
    /// `results` stores the successful results that have not been fetched yet.
    results: VecDeque<O>,

    /// The maximum number of concurrent tasks.
    concurrent: usize,
    /// Tracks the number of tasks that have finished execution but have not yet been collected
    /// (awaited) from the `tasks` queue.
    ///
    /// Every task increments this counter when its future completes, and the counter is
    /// decremented each time a finished task is awaited. `has_remaining` adds this count to
    /// `concurrent` before comparing against `tasks.len()`, so tasks that already finished but
    /// still sit in the queue do not occupy a concurrency slot.
    ///
    /// Example: with `concurrent = 10` and 10 queued tasks of which 3 have already finished
    ///          (`completed_but_unretrieved = 3`), 3 more tasks may still be pushed, keeping
    ///          10 tasks actually running.
    completed_but_unretrieved: Arc<AtomicUsize>,
    /// Whether an unrecoverable (non-temporary) error has been hit.
    ///
    /// If concurrent tasks hit an unrecoverable error, it will stop executing new tasks and
    /// return an unrecoverable error to users.
    errored: bool,
}
146
147impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
148    /// Create a new concurrent tasks with given executor, concurrent and factory.
149    ///
150    /// The factory is a function pointer that shouldn't capture any context.
151    pub fn new(
152        executor: Executor,
153        concurrent: usize,
154        factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
155    ) -> Self {
156        Self {
157            executor,
158            factory,
159
160            tasks: VecDeque::with_capacity(concurrent),
161            results: VecDeque::with_capacity(concurrent),
162            concurrent,
163            completed_but_unretrieved: Arc::default(),
164            errored: false,
165        }
166    }
167
168    /// Return true if the tasks are running concurrently.
169    #[inline]
170    fn is_concurrent(&self) -> bool {
171        self.concurrent > 1
172    }
173
174    /// Clear all tasks and results.
175    ///
176    /// All ongoing tasks will be canceled.
177    pub fn clear(&mut self) {
178        self.tasks.clear();
179        self.results.clear();
180    }
181
182    /// Check if there are remaining space to push new tasks.
183    #[inline]
184    pub fn has_remaining(&self) -> bool {
185        self.tasks.len() < self.concurrent + self.completed_but_unretrieved.load(Ordering::Relaxed)
186    }
187
188    /// Chunk if there are remaining results to fetch.
189    #[inline]
190    pub fn has_result(&self) -> bool {
191        !self.results.is_empty()
192    }
193
194    /// Create a task with given input.
195    pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
196        let completed = self.completed_but_unretrieved.clone();
197
198        let fut = (self.factory)(input).inspect(move |_| {
199            completed.fetch_add(1, Ordering::Relaxed);
200        });
201
202        self.executor.execute(fut)
203    }
204
205    /// Execute the task with given input.
206    ///
207    /// - Execute the task in the current thread if is not concurrent.
208    /// - Execute the task in the background if there are available slots.
209    /// - Await the first task in the queue if there is no available slots.
210    pub async fn execute(&mut self, input: I) -> Result<()> {
211        if self.errored {
212            return Err(Error::new(
213                ErrorKind::Unexpected,
214                "concurrent tasks met an unrecoverable error",
215            ));
216        }
217
218        // Short path for non-concurrent case.
219        if !self.is_concurrent() {
220            let (_, o) = (self.factory)(input).await;
221            return match o {
222                Ok(o) => {
223                    self.results.push_back(o);
224                    Ok(())
225                }
226                // We don't need to rebuild the future if it's not concurrent.
227                Err(err) => Err(err),
228            };
229        }
230
231        if !self.has_remaining() {
232            let (i, o) = self
233                .tasks
234                .front_mut()
235                .expect("tasks must be available")
236                .await;
237            self.completed_but_unretrieved
238                .fetch_sub(1, Ordering::Relaxed);
239            match o {
240                Ok(o) => {
241                    let _ = self.tasks.pop_front();
242                    self.results.push_back(o)
243                }
244                Err(err) => {
245                    // Retry this task if the error is temporary
246                    if err.is_temporary() {
247                        let task = self.create_task(i);
248                        self.tasks
249                            .front_mut()
250                            .expect("tasks must be available")
251                            .replace(task)
252                    } else {
253                        self.clear();
254                        self.errored = true;
255                    }
256                    return Err(err);
257                }
258            }
259        }
260
261        self.tasks.push_back(self.create_task(input));
262        Ok(())
263    }
264
265    /// Fetch the successful result from the result queue.
266    pub async fn next(&mut self) -> Option<Result<O>> {
267        if self.errored {
268            return Some(Err(Error::new(
269                ErrorKind::Unexpected,
270                "concurrent tasks met an unrecoverable error",
271            )));
272        }
273
274        if let Some(result) = self.results.pop_front() {
275            return Some(Ok(result));
276        }
277
278        if let Some(task) = self.tasks.front_mut() {
279            let (i, o) = task.await;
280            self.completed_but_unretrieved
281                .fetch_sub(1, Ordering::Relaxed);
282            return match o {
283                Ok(o) => {
284                    let _ = self.tasks.pop_front();
285                    Some(Ok(o))
286                }
287                Err(err) => {
288                    // Retry this task if the error is temporary
289                    if err.is_temporary() {
290                        let task = self.create_task(i);
291                        self.tasks
292                            .front_mut()
293                            .expect("tasks must be available")
294                            .replace(task)
295                    } else {
296                        self.clear();
297                        self.errored = true;
298                    }
299                    Some(Err(err))
300                }
301            };
302        }
303
304        None
305    }
306}
307
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::Rng;
    use tokio::time::sleep;

    use super::*;
    use pretty_assertions::assert_eq;

    /// Submits 10240 randomly delayed, randomly failing tasks with a concurrency of 16 and
    /// checks that every output is yielded exactly once, in submission order.
    #[tokio::test]
    async fn test_concurrent_tasks() {
        let executor = Executor::new();

        let mut tasks = ConcurrentTasks::new(executor, 16, |(i, dur)| {
            Box::pin(async move {
                sleep(dur).await;

                // Fail with a temporary error when the roll lands in 91..=99,
                // i.e. for roughly 9% of the attempts.
                if rand::thread_rng().gen_range(0..100) > 90 {
                    (
                        (i, dur),
                        Err(Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary()),
                    )
                } else {
                    ((i, dur), Ok(i))
                }
            })
        });

        for i in 0..10240 {
            // Sleep up to 10ms
            let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
            // A failed `execute` did not schedule the input, so resubmit until it is accepted.
            while tasks.execute((i, dur)).await.is_err() {}
        }

        let mut ans = vec![];
        loop {
            match tasks.next().await {
                Some(Ok(i)) => ans.push(i),
                // The failed task has been rebuilt in place; just poll again.
                Some(Err(_)) => continue,
                None => break,
            }
        }

        assert_eq!(ans, (0..10240).collect::<Vec<_>>())
    }
}
360}