opendal/raw/futures_util.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::VecDeque;
19use std::sync::atomic::AtomicUsize;
20use std::sync::atomic::Ordering;
21use std::sync::Arc;
22
23use futures::FutureExt;
24
25use crate::*;
26
/// BoxedFuture is the type alias of [`futures::future::BoxFuture`].
///
/// The alias is selected at compile time by `target_arch`:
///
/// - On every target except wasm32 it is [`futures::future::BoxFuture`].
/// - On wasm32 it is [`futures::future::LocalBoxFuture`] instead.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedFuture<'a, T> = futures::future::BoxFuture<'a, T>;
/// BoxedFuture is the type alias of [`futures::future::LocalBoxFuture`] on the wasm32 target.
#[cfg(target_arch = "wasm32")]
pub type BoxedFuture<'a, T> = futures::future::LocalBoxFuture<'a, T>;
34
/// BoxedStaticFuture is the type alias of [`futures::future::BoxFuture`] with a `'static`
/// lifetime.
///
/// The alias is selected at compile time by `target_arch`:
///
/// - On every target except wasm32 it is `BoxFuture<'static, T>`.
/// - On wasm32 it is [`futures::future::LocalBoxFuture`] with a `'static` lifetime instead.
#[cfg(not(target_arch = "wasm32"))]
pub type BoxedStaticFuture<T> = futures::future::BoxFuture<'static, T>;
/// BoxedStaticFuture is the type alias of `LocalBoxFuture<'static, T>` on the wasm32 target.
#[cfg(target_arch = "wasm32")]
pub type BoxedStaticFuture<T> = futures::future::LocalBoxFuture<'static, T>;
42
/// MaybeSend is a marker trait that requires `Send` only where the target supports it.
/// We use this trait to wrap the `Send` requirement for wasm32 target.
///
/// # Behavior per target
///
/// - On non-wasm32 targets [`MaybeSend`] has `Send` as a supertrait and is implemented for
///   every `T: Send`, so it is equivalent to `Send`.
/// - On wasm32 it is an empty trait implemented for every type, so it places no `Send`
///   requirement on the type at all.
#[cfg(not(target_arch = "wasm32"))]
pub trait MaybeSend: Send {}
/// MaybeSend on wasm32: an empty marker without any `Send` requirement.
#[cfg(target_arch = "wasm32")]
pub trait MaybeSend {}

// Blanket impls: callers never implement `MaybeSend` by hand.
#[cfg(not(target_arch = "wasm32"))]
impl<T: Send> MaybeSend for T {}
#[cfg(target_arch = "wasm32")]
impl<T> MaybeSend for T {}
59
/// ConcurrentTasks is used to execute tasks concurrently.
///
/// ConcurrentTasks has two generic types:
///
/// - `I` represents the input type of the task.
/// - `O` represents the output type of the task.
///
/// # Implementation Notes
///
/// The code patterns below are intentional; please do not modify them unless you fully understand these notes.
///
/// ```skip
/// let (i, o) = self
///     .tasks
///     .front_mut() // Use `front_mut` instead of `pop_front`
///     .expect("tasks must be available")
///     .await;
/// ...
/// match o {
///     Ok(o) => {
///         let _ = self.tasks.pop_front(); // `pop_front` only after we got `Ok(o)`
///         self.results.push_back(o)
///     }
///     Err(err) => {
///         if err.is_temporary() {
///             let task = self.create_task(i);
///             self.tasks
///                 .front_mut()
///                 .expect("tasks must be available")
///                 .replace(task) // Use `replace` here instead of `push_front`
///         } else {
///             self.clear();
///             self.errored = true;
///         }
///         return Err(err);
///     }
/// }
/// ```
///
/// Please keep in mind that there is no guarantee the task will be `await`ed until completion. It's possible
/// the task may be dropped before it resolves. Therefore, we should keep the `Task` in the `tasks` queue until
/// it is resolved.
///
/// For example, users may have a timeout for the task, and the task will be dropped if it exceeds the timeout.
/// If we `pop_front` the task before it resolves, the task will be canceled and the result will be lost.
pub struct ConcurrentTasks<I, O> {
    /// The executor to execute the tasks.
    ///
    /// If user doesn't provide an executor, the tasks will be executed with the default executor.
    executor: Executor,
    /// The factory to create the task.
    ///
    /// Caller of ConcurrentTasks must provide a factory to create the task for executing.
    ///
    /// The factory must accept an input and return a future that resolves to a tuple of input and
    /// output result. The input is handed back so that a failed task can be re-created from it.
    ///
    /// If the result is an error, the error is returned to users. A temporary error
    /// (`err.is_temporary()`) additionally re-creates the task from the returned input; any other
    /// error clears the queues and marks this instance as errored.
    factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,

    /// `tasks` holds the ongoing tasks.
    ///
    /// Please keep in mind that all tasks are running in the background by `Executor`. We only need
    /// to poll the tasks to see if they are ready.
    ///
    /// Dropping a task without `await`ing it will cancel the task.
    tasks: VecDeque<Task<(I, Result<O>)>>,
    /// `results` stores the successful results.
    results: VecDeque<O>,

    /// The maximum number of concurrent tasks.
    concurrent: usize,
    /// Tracks the number of tasks that have finished execution but have not yet been collected.
    ///
    /// Every task increments this counter when its future completes, and the counter is
    /// decremented each time the front task of the queue has been awaited. The count is added
    /// to the queue capacity, so finished-but-uncollected tasks do not occupy a concurrency
    /// slot and new tasks keep being scheduled at the user's desired concurrency level.
    ///
    /// Example: with `concurrent = 10` and `completed_but_unretrieved = 3`, the queue may hold
    /// up to 13 tasks, so 10 tasks can still be running while 3 finished ones wait to be
    /// collected.
    completed_but_unretrieved: Arc<AtomicUsize>,
    /// Whether an unrecoverable error has been hit.
    ///
    /// If concurrent tasks hit an unrecoverable error, it will stop executing new tasks and return
    /// an unrecoverable error to users.
    errored: bool,
}
145
146impl<I: Send + 'static, O: Send + 'static> ConcurrentTasks<I, O> {
147 /// Create a new concurrent tasks with given executor, concurrent and factory.
148 ///
149 /// The factory is a function pointer that shouldn't capture any context.
150 pub fn new(
151 executor: Executor,
152 concurrent: usize,
153 factory: fn(I) -> BoxedStaticFuture<(I, Result<O>)>,
154 ) -> Self {
155 Self {
156 executor,
157 factory,
158
159 tasks: VecDeque::with_capacity(concurrent),
160 results: VecDeque::with_capacity(concurrent),
161 concurrent,
162 completed_but_unretrieved: Arc::default(),
163 errored: false,
164 }
165 }
166
167 /// Return true if the tasks are running concurrently.
168 #[inline]
169 fn is_concurrent(&self) -> bool {
170 self.concurrent > 1
171 }
172
173 /// Clear all tasks and results.
174 ///
175 /// All ongoing tasks will be canceled.
176 pub fn clear(&mut self) {
177 self.tasks.clear();
178 self.results.clear();
179 }
180
181 /// Check if there are remaining space to push new tasks.
182 #[inline]
183 pub fn has_remaining(&self) -> bool {
184 self.tasks.len() < self.concurrent + self.completed_but_unretrieved.load(Ordering::Relaxed)
185 }
186
187 /// Chunk if there are remaining results to fetch.
188 #[inline]
189 pub fn has_result(&self) -> bool {
190 !self.results.is_empty()
191 }
192
193 /// Create a task with given input.
194 pub fn create_task(&self, input: I) -> Task<(I, Result<O>)> {
195 let completed = self.completed_but_unretrieved.clone();
196
197 let fut = (self.factory)(input).inspect(move |_| {
198 completed.fetch_add(1, Ordering::Relaxed);
199 });
200
201 self.executor.execute(fut)
202 }
203
204 /// Execute the task with given input.
205 ///
206 /// - Execute the task in the current thread if is not concurrent.
207 /// - Execute the task in the background if there are available slots.
208 /// - Await the first task in the queue if there is no available slots.
209 pub async fn execute(&mut self, input: I) -> Result<()> {
210 if self.errored {
211 return Err(Error::new(
212 ErrorKind::Unexpected,
213 "concurrent tasks met an unrecoverable error",
214 ));
215 }
216
217 // Short path for non-concurrent case.
218 if !self.is_concurrent() {
219 let (_, o) = (self.factory)(input).await;
220 return match o {
221 Ok(o) => {
222 self.results.push_back(o);
223 Ok(())
224 }
225 // We don't need to rebuild the future if it's not concurrent.
226 Err(err) => Err(err),
227 };
228 }
229
230 if !self.has_remaining() {
231 let (i, o) = self
232 .tasks
233 .front_mut()
234 .expect("tasks must be available")
235 .await;
236 self.completed_but_unretrieved
237 .fetch_sub(1, Ordering::Relaxed);
238 match o {
239 Ok(o) => {
240 let _ = self.tasks.pop_front();
241 self.results.push_back(o)
242 }
243 Err(err) => {
244 // Retry this task if the error is temporary
245 if err.is_temporary() {
246 let task = self.create_task(i);
247 self.tasks
248 .front_mut()
249 .expect("tasks must be available")
250 .replace(task)
251 } else {
252 self.clear();
253 self.errored = true;
254 }
255 return Err(err);
256 }
257 }
258 }
259
260 self.tasks.push_back(self.create_task(input));
261 Ok(())
262 }
263
264 /// Fetch the successful result from the result queue.
265 pub async fn next(&mut self) -> Option<Result<O>> {
266 if self.errored {
267 return Some(Err(Error::new(
268 ErrorKind::Unexpected,
269 "concurrent tasks met an unrecoverable error",
270 )));
271 }
272
273 if let Some(result) = self.results.pop_front() {
274 return Some(Ok(result));
275 }
276
277 if let Some(task) = self.tasks.front_mut() {
278 let (i, o) = task.await;
279 self.completed_but_unretrieved
280 .fetch_sub(1, Ordering::Relaxed);
281 return match o {
282 Ok(o) => {
283 let _ = self.tasks.pop_front();
284 Some(Ok(o))
285 }
286 Err(err) => {
287 // Retry this task if the error is temporary
288 if err.is_temporary() {
289 let task = self.create_task(i);
290 self.tasks
291 .front_mut()
292 .expect("tasks must be available")
293 .replace(task)
294 } else {
295 self.clear();
296 self.errored = true;
297 }
298 Some(Err(err))
299 }
300 };
301 }
302
303 None
304 }
305}
306
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use pretty_assertions::assert_eq;
    use rand::Rng;
    use tokio::time::sleep;

    use super::*;

    /// Push 10240 inputs through 16 concurrent slots with randomly failing tasks and verify
    /// that every input yields exactly one result, in submission order.
    #[tokio::test]
    async fn test_concurrent_tasks() {
        let mut tasks = ConcurrentTasks::new(Executor::new(), 16, |(i, dur)| {
            Box::pin(async move {
                sleep(dur).await;

                // Fail for the values 91..=99, which is a 9% rate. The error is temporary so
                // the task gets retried.
                if rand::thread_rng().gen_range(0..100) > 90 {
                    let err = Error::new(ErrorKind::Unexpected, "I'm lucky").set_temporary();
                    return ((i, dur), Err(err));
                }
                ((i, dur), Ok(i))
            })
        });

        for i in 0..10240 {
            // Sleep for 0 to 9 milliseconds.
            let dur = Duration::from_millis(rand::thread_rng().gen_range(0..10));
            // An error means the input has not been queued yet, so submit it again.
            while tasks.execute((i, dur)).await.is_err() {}
        }

        // Drain everything. An error only reports a retried task, so keep polling.
        let mut ans = Vec::new();
        while let Some(res) = tasks.next().await {
            if let Ok(i) = res {
                ans.push(i);
            }
        }

        assert_eq!(ans, (0..10240).collect::<Vec<_>>())
    }
}
359}