opendal/raw/futures_util.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{
19 collections::VecDeque,
20 sync::atomic::Ordering,
21 sync::{atomic::AtomicUsize, Arc},
22};
23
24use futures::FutureExt;
25
26use crate::*;
27
28/// BoxedFuture is the type alias of [`futures::future::BoxFuture`].
29///
30/// We will switch to [`futures::future::LocalBoxFuture`] on wasm32 target.
31#[cfg(not(target_arch = "wasm32"))]
32pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
33#[cfg(target_arch = "wasm32")]
34pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;
35
36/// BoxedStaticFuture is the type alias of [`futures::future::BoxFuture`].
37///
38/// We will switch to [`futures::future::LocalBoxFuture`] on wasm32 target.
39#[cfg(not(target_arch = "wasm32"))]
40pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
41#[cfg(target_arch = "wasm32")]
42pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;
43
/// MaybeSend is a marker trait that stands for "`Send`, where the target needs it".
///
/// Writing a bound as `T: MaybeSend` instead of `T: Send` lets the same code
/// compile for wasm32, where the `Send` requirement is dropped.
///
/// # Notes
///
/// - On non-wasm32 targets [`MaybeSend`] has `Send` as its supertrait and is
///   implemented for every `Send` type, so it is equivalent to `Send`.
/// - On wasm32 [`MaybeSend`] is an empty trait implemented for every type, so it
///   does not require the type to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
/// wasm32 flavour of `MaybeSend`: an empty marker without the `Send` supertrait.
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}

// Blanket implementations: nobody has to implement `MaybeSend` by hand.
#[cfg(not(target_arch = "wasm32"))]
impl<T> MaybeSend for T where T: Send {}
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}
60
61/// ConcurrentTasks is used to execute tasks concurrently.
62///
63/// ConcurrentTasks has two generic types:
64///
65/// - `I` represents the input type of the task.
66/// - `O` represents the output type of the task.
67///
68/// # Implementation Notes
69///
70/// The code patterns below are intentional; please do not modify them unless you fully understand these notes.
71///
72/// ```skip
73/// let (i, o) = self
74/// .tasks
75/// .front_mut() // Use `front_mut` instead of `pop_front`
76/// .expect("tasks must be available")
77/// .await;
78/// ...
79/// match o {
80/// Ok(o) => {
81/// let _ = self.tasks.pop_front(); // `pop_front` after got `Ok(o)`
82/// self.results.push_back(o)
83/// }
84/// Err(err) => {
85/// if err.is_temporary() {
86/// let task = self.create_task(i);
87/// self.tasks
88/// .front_mut()
89/// .expect("tasks must be available")
90/// .replace(task) // Use replace here to instead of `push_front`
91/// } else {
92/// self.clear();
93/// self.errored = true;
94/// }
95/// return Err(err);
96/// }
97/// }
98/// ```
99///
100/// Please keep in mind that there is no guarantee the task will be `await`ed until completion. It's possible
101/// the task may be dropped before it resolves. Therefore, we should keep the `Task` in the `tasks` queue until
102/// it is resolved.
103///
104/// For example, users may have a timeout for the task, and the task will be dropped if it exceeds the timeout.
105/// If we `pop_front` the task before it resolves, the task will be canceled and the result will be lost.
106pub struct ConcurrentTasks<I, O> {
107 /// The executor to execute the tasks.
108 ///
109 /// If user doesn't provide an executor, the tasks will be executed with the default executor.
110 executor: Executor,
111 /// The factory to create the task.
112 ///
113 /// Caller of ConcurrentTasks must provides a factory to create the task for executing.
114 ///
115 /// The factory must accept an input and return a future that resolves to a tuple of input and
116 /// output result. If the given result is error, the error will be returned to users and the
117 /// task will be retried.
118 factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
119
120 /// `tasks` holds the ongoing tasks.
121 ///
122 /// Please keep in mind that all tasks are running in the background by `Executor`. We only need
123 /// to poll the tasks to see if they are ready.
124 ///
125 /// Dropping task without `await` it will cancel the task.
126 tasks: VecDeque<Task<(I, Result<O>)>>,
127 /// `results` stores the successful results.
128 results: VecDeque<O>,
129
130 /// The maximum number of concurrent tasks.
131 concurrent: usize,
132 /// Tracks the number of tasks that have finished execution but have not yet been collected.
133 /// This count is subtracted from the total concurrency capacity, ensuring that the system
134 /// always schedules new tasks to maintain the user's desired concurrency level.
135 ///
136 /// Example: If `concurrency = 10` and `completed_but_unretrieved = 3`,
137 /// the system can still spawn 7 new tasks (since 3 slots are "logically occupied"
138 /// by uncollected results).
139 completed_but_unretrieved: Arc<AtomicUsize>,
140 /// hitting the last unrecoverable error.
141 ///
142 /// If concurrent tasks hit an unrecoverable error, it will stop executing new tasks and return
143 /// an unrecoverable error to users.
144 errored: bool,
145}
146
147impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
148 /// Create a new concurrent tasks with given executor, concurrent and factory.
149 ///
150 /// The factory is a function pointer that shouldn't capture any context.
151 pub fn new(
152 executor: Executor,
153 concurrent: usize,
154 factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
155 ) -> Self {
156 Self {
157 executor,
158 factory,
159
160 tasks: VecDeque::with_capacity(concurrent),
161 results: VecDeque::with_capacity(concurrent),
162 concurrent,
163 completed_but_unretrieved: Arc::default(),
164 errored: false,
165 }
166 }
167
168 /// Return true if the tasks are running concurrently.
169 #[inline]
170 fn is_concurrent(&self) -> bool {
171 self.concurrent > 1
172 }
173
174 /// Clear all tasks and results.
175 ///
176 /// All ongoing tasks will be canceled.
177 pub fn clear(&mut self) {
178 self.tasks.clear();
179 self.results.clear();
180 }
181
182 /// Check if there are remaining space to push new tasks.
183 #[inline]
184 pub fn has_remaining(&self) -> bool {
185 self.tasks.len() < self.concurrent + self.completed_but_unretrieved.load(Ordering::Relaxed)
186 }
187
188 /// Chunk if there are remaining results to fetch.
189 #[inline]
190 pub fn has_result(&self) -> bool {
191 !self.results.is_empty()
192 }
193
194 /// Create a task with given input.
195 pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
196 let completed = self.completed_but_unretrieved.clone();
197
198 let fut = (self.factory)(input).inspect(move |_| {
199 completed.fetch_add(1, Ordering::Relaxed);
200 });
201
202 self.executor.execute(fut)
203 }
204
205 /// Execute the task with given input.
206 ///
207 /// - Execute the task in the current thread if is not concurrent.
208 /// - Execute the task in the background if there are available slots.
209 /// - Await the first task in the queue if there is no available slots.
210 pub async fn execute(&mut self, input: I) -> Result<()> {
211 if self.errored {
212 return Err(Error::new(
213 ErrorKind::Unexpected,
214 "concurrent tasks met an unrecoverable error",
215 ));
216 }
217
218 // Short path for non-concurrent case.
219 if !self.is_concurrent() {
220 let (_, o) = (self.factory)(input).await;
221 return match o {
222 Ok(o) => {
223 self.results.push_back(o);
224 Ok(())
225 }
226 // We don't need to rebuild the future if it's not concurrent.
227 Err(err) => Err(err),
228 };
229 }
230
231 if !self.has_remaining() {
232 let (i, o) = self
233 .tasks
234 .front_mut()
235 .expect("tasks must be available")
236 .await;
237 self.completed_but_unretrieved
238 .fetch_sub(1, Ordering::Relaxed);
239 match o {
240 Ok(o) => {
241 let _ = self.tasks.pop_front();
242 self.results.push_back(o)
243 }
244 Err(err) => {
245 // Retry this task if the error is temporary
246 if err.is_temporary() {
247 let task = self.create_task(i);
248 self.tasks
249 .front_mut()
250 .expect("tasks must be available")
251 .replace(task)
252 } else {
253 self.clear();
254 self.errored = true;
255 }
256 return Err(err);
257 }
258 }
259 }
260
261 self.tasks.push_back(self.create_task(input));
262 Ok(())
263 }
264
265 /// Fetch the successful result from the result queue.
266 pub async fn next(&mut self) -> Option<Result<O>> {
267 if self.errored {
268 return Some(Err(Error::new(
269 ErrorKind::Unexpected,
270 "concurrent tasks met an unrecoverable error",
271 )));
272 }
273
274 if let Some(result) = self.results.pop_front() {
275 return Some(Ok(result));
276 }
277
278 if let Some(task) = self.tasks.front_mut() {
279 let (i, o) = task.await;
280 self.completed_but_unretrieved
281 .fetch_sub(1, Ordering::Relaxed);
282 return match o {
283 Ok(o) => {
284 let _ = self.tasks.pop_front();
285 Some(Ok(o))
286 }
287 Err(err) => {
288 // Retry this task if the error is temporary
289 if err.is_temporary() {
290 let task = self.create_task(i);
291 self.tasks
292 .front_mut()
293 .expect("tasks must be available")
294 .replace(task)
295 } else {
296 self.clear();
297 self.errored = true;
298 }
299 Some(Err(err))
300 }
301 };
302 }
303
304 None
305 }
306}
307
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::Rng;
    use tokio::time::sleep;

    use super::*;
    use pretty_assertions::assert_eq;

    /// Submits 10240 tasks that randomly fail with temporary errors and checks
    /// that every output is eventually returned, in submission order.
    #[tokio::test]
    async fn test_concurrent_tasks() {
        let executor = Executor::new();

        let mut tasks = ConcurrentTasks::new(executor, 16, |(i, dur)| {
            Box::pin(async move {
                sleep(dur).await;

                // Fail with a temporary error when the roll lands in 91..=99,
                // i.e. roughly 9% of the time.
                if rand::thread_rng().gen_range(0..100) > 90 {
                    return (
                        (i, dur),
                        Err(Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary()),
                    );
                }
                ((i, dur), Ok(i))
            })
        });

        for i in 0..10240 {
            // Each task sleeps for 0..10 milliseconds.
            let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
            // Submit the same input again until it is accepted.
            while tasks.execute((i, dur)).await.is_err() {}
        }

        let mut ans = vec![];
        while let Some(outcome) = tasks.next().await {
            // An error here is temporary and the task has been re-created, so
            // simply ask for the next result again.
            if let Ok(value) = outcome {
                ans.push(value);
            }
        }

        assert_eq!(ans, (0..10240).collect::<Vec<_>>())
    }
}
360}