opendal/layers/
concurrent_limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::Debug;
19use std::pin::Pin;
20use std::sync::Arc;
21use std::task::Context;
22use std::task::Poll;
23
24use futures::Stream;
25use futures::StreamExt;
26use tokio::sync::OwnedSemaphorePermit;
27use tokio::sync::Semaphore;
28
29use crate::raw::*;
30use crate::*;
31
32/// Add concurrent request limit.
33///
34/// # Notes
35///
36/// Users can control how many concurrent connections could be established
37/// between OpenDAL and underlying storage services.
38///
39/// All operators wrapped by this layer will share a common semaphore. This
40/// allows you to reuse the same layer across multiple operators, ensuring
41/// that the total number of concurrent requests across the entire
42/// application does not exceed the limit.
43///
44/// # Examples
45///
46/// Add a concurrent limit layer to the operator:
47///
48/// ```no_run
49/// # use opendal::layers::ConcurrentLimitLayer;
50/// # use opendal::services;
51/// # use opendal::Operator;
52/// # use opendal::Result;
53/// # use opendal::Scheme;
54///
55/// # fn main() -> Result<()> {
56/// let _ = Operator::new(services::Memory::default())?
57///     .layer(ConcurrentLimitLayer::new(1024))
58///     .finish();
59/// Ok(())
60/// # }
61/// ```
62///
63/// Share a concurrent limit layer between the operators:
64///
65/// ```no_run
66/// # use opendal::layers::ConcurrentLimitLayer;
67/// # use opendal::services;
68/// # use opendal::Operator;
69/// # use opendal::Result;
70/// # use opendal::Scheme;
71///
72/// # fn main() -> Result<()> {
73/// let limit = ConcurrentLimitLayer::new(1024);
74///
75/// let _operator_a = Operator::new(services::Memory::default())?
76///     .layer(limit.clone())
77///     .finish();
78/// let _operator_b = Operator::new(services::Memory::default())?
79///     .layer(limit.clone())
80///     .finish();
81///
82/// Ok(())
83/// # }
84/// ```
85#[derive(Clone)]
86pub struct ConcurrentLimitLayer {
87    operation_semaphore: Arc<Semaphore>,
88    http_semaphore: Option<Arc<Semaphore>>,
89}
90
91impl ConcurrentLimitLayer {
92    /// Create a new ConcurrentLimitLayer will specify permits.
93    ///
94    /// This permits will applied to all operations.
95    pub fn new(permits: usize) -> Self {
96        Self {
97            operation_semaphore: Arc::new(Semaphore::new(permits)),
98            http_semaphore: None,
99        }
100    }
101
102    /// Set a concurrent limit for HTTP requests.
103    ///
104    /// This will limit the number of concurrent HTTP requests made by the
105    /// operator.
106    pub fn with_http_concurrent_limit(mut self, permits: usize) -> Self {
107        self.http_semaphore = Some(Arc::new(Semaphore::new(permits)));
108        self
109    }
110}
111
112impl<A: Access> Layer<A> for ConcurrentLimitLayer {
113    type LayeredAccess = ConcurrentLimitAccessor<A>;
114
115    fn layer(&self, inner: A) -> Self::LayeredAccess {
116        let info = inner.info();
117
118        // Update http client with metrics http fetcher.
119        info.update_http_client(|client| {
120            HttpClient::with(ConcurrentLimitHttpFetcher {
121                inner: client.into_inner(),
122                http_semaphore: self.http_semaphore.clone(),
123            })
124        });
125
126        ConcurrentLimitAccessor {
127            inner,
128            semaphore: self.operation_semaphore.clone(),
129        }
130    }
131}
132
133pub struct ConcurrentLimitHttpFetcher {
134    inner: HttpFetcher,
135    http_semaphore: Option<Arc<Semaphore>>,
136}
137
138impl HttpFetch for ConcurrentLimitHttpFetcher {
139    async fn fetch(&self, req: http::Request<Buffer>) -> Result<http::Response<HttpBody>> {
140        let Some(semaphore) = self.http_semaphore.clone() else {
141            return self.inner.fetch(req).await;
142        };
143
144        let permit = semaphore
145            .acquire_owned()
146            .await
147            .expect("semaphore must be valid");
148
149        let resp = self.inner.fetch(req).await?;
150        let (parts, body) = resp.into_parts();
151        let body = body.map_inner(|s| {
152            Box::new(ConcurrentLimitStream {
153                inner: s,
154                _permit: permit,
155            })
156        });
157        Ok(http::Response::from_parts(parts, body))
158    }
159}
160
161pub struct ConcurrentLimitStream<S> {
162    inner: S,
163    // Hold on this permit until this reader has been dropped.
164    _permit: OwnedSemaphorePermit,
165}
166
167impl<S> Stream for ConcurrentLimitStream<S>
168where
169    S: Stream<Item = Result<Buffer>> + Unpin + 'static,
170{
171    type Item = Result<Buffer>;
172
173    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
174        self.inner.poll_next_unpin(cx)
175    }
176}
177
178#[derive(Debug, Clone)]
179pub struct ConcurrentLimitAccessor<A: Access> {
180    inner: A,
181    semaphore: Arc<Semaphore>,
182}
183
184impl<A: Access> LayeredAccess for ConcurrentLimitAccessor<A> {
185    type Inner = A;
186    type Reader = ConcurrentLimitWrapper<A::Reader>;
187    type BlockingReader = ConcurrentLimitWrapper<A::BlockingReader>;
188    type Writer = ConcurrentLimitWrapper<A::Writer>;
189    type BlockingWriter = ConcurrentLimitWrapper<A::BlockingWriter>;
190    type Lister = ConcurrentLimitWrapper<A::Lister>;
191    type BlockingLister = ConcurrentLimitWrapper<A::BlockingLister>;
192    type Deleter = ConcurrentLimitWrapper<A::Deleter>;
193    type BlockingDeleter = ConcurrentLimitWrapper<A::BlockingDeleter>;
194
195    fn inner(&self) -> &Self::Inner {
196        &self.inner
197    }
198
199    async fn create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
200        let _permit = self
201            .semaphore
202            .acquire()
203            .await
204            .expect("semaphore must be valid");
205
206        self.inner.create_dir(path, args).await
207    }
208
209    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
210        let permit = self
211            .semaphore
212            .clone()
213            .acquire_owned()
214            .await
215            .expect("semaphore must be valid");
216
217        self.inner
218            .read(path, args)
219            .await
220            .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
221    }
222
223    async fn write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
224        let permit = self
225            .semaphore
226            .clone()
227            .acquire_owned()
228            .await
229            .expect("semaphore must be valid");
230
231        self.inner
232            .write(path, args)
233            .await
234            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
235    }
236
237    async fn stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
238        let _permit = self
239            .semaphore
240            .acquire()
241            .await
242            .expect("semaphore must be valid");
243
244        self.inner.stat(path, args).await
245    }
246
247    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
248        let permit = self
249            .semaphore
250            .clone()
251            .acquire_owned()
252            .await
253            .expect("semaphore must be valid");
254
255        self.inner
256            .delete()
257            .await
258            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
259    }
260
261    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
262        let permit = self
263            .semaphore
264            .clone()
265            .acquire_owned()
266            .await
267            .expect("semaphore must be valid");
268
269        self.inner
270            .list(path, args)
271            .await
272            .map(|(rp, s)| (rp, ConcurrentLimitWrapper::new(s, permit)))
273    }
274
275    fn blocking_create_dir(&self, path: &str, args: OpCreateDir) -> Result<RpCreateDir> {
276        let _permit = self
277            .semaphore
278            .try_acquire()
279            .expect("semaphore must be valid");
280
281        self.inner.blocking_create_dir(path, args)
282    }
283
284    fn blocking_read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::BlockingReader)> {
285        let permit = self
286            .semaphore
287            .clone()
288            .try_acquire_owned()
289            .expect("semaphore must be valid");
290
291        self.inner
292            .blocking_read(path, args)
293            .map(|(rp, r)| (rp, ConcurrentLimitWrapper::new(r, permit)))
294    }
295
296    fn blocking_write(&self, path: &str, args: OpWrite) -> Result<(RpWrite, Self::BlockingWriter)> {
297        let permit = self
298            .semaphore
299            .clone()
300            .try_acquire_owned()
301            .expect("semaphore must be valid");
302
303        self.inner
304            .blocking_write(path, args)
305            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
306    }
307
308    fn blocking_stat(&self, path: &str, args: OpStat) -> Result<RpStat> {
309        let _permit = self
310            .semaphore
311            .try_acquire()
312            .expect("semaphore must be valid");
313
314        self.inner.blocking_stat(path, args)
315    }
316
317    fn blocking_delete(&self) -> Result<(RpDelete, Self::BlockingDeleter)> {
318        let permit = self
319            .semaphore
320            .clone()
321            .try_acquire_owned()
322            .expect("semaphore must be valid");
323
324        self.inner
325            .blocking_delete()
326            .map(|(rp, w)| (rp, ConcurrentLimitWrapper::new(w, permit)))
327    }
328
329    fn blocking_list(&self, path: &str, args: OpList) -> Result<(RpList, Self::BlockingLister)> {
330        let permit = self
331            .semaphore
332            .clone()
333            .try_acquire_owned()
334            .expect("semaphore must be valid");
335
336        self.inner
337            .blocking_list(path, args)
338            .map(|(rp, it)| (rp, ConcurrentLimitWrapper::new(it, permit)))
339    }
340}
341
342pub struct ConcurrentLimitWrapper<R> {
343    inner: R,
344
345    // Hold on this permit until this reader has been dropped.
346    _permit: OwnedSemaphorePermit,
347}
348
349impl<R> ConcurrentLimitWrapper<R> {
350    fn new(inner: R, permit: OwnedSemaphorePermit) -> Self {
351        Self {
352            inner,
353            _permit: permit,
354        }
355    }
356}
357
358impl<R: oio::Read> oio::Read for ConcurrentLimitWrapper<R> {
359    async fn read(&mut self) -> Result<Buffer> {
360        self.inner.read().await
361    }
362}
363
364impl<R: oio::BlockingRead> oio::BlockingRead for ConcurrentLimitWrapper<R> {
365    fn read(&mut self) -> Result<Buffer> {
366        self.inner.read()
367    }
368}
369
370impl<R: oio::Write> oio::Write for ConcurrentLimitWrapper<R> {
371    async fn write(&mut self, bs: Buffer) -> Result<()> {
372        self.inner.write(bs).await
373    }
374
375    async fn close(&mut self) -> Result<Metadata> {
376        self.inner.close().await
377    }
378
379    async fn abort(&mut self) -> Result<()> {
380        self.inner.abort().await
381    }
382}
383
384impl<R: oio::BlockingWrite> oio::BlockingWrite for ConcurrentLimitWrapper<R> {
385    fn write(&mut self, bs: Buffer) -> Result<()> {
386        self.inner.write(bs)
387    }
388
389    fn close(&mut self) -> Result<Metadata> {
390        self.inner.close()
391    }
392}
393
394impl<R: oio::List> oio::List for ConcurrentLimitWrapper<R> {
395    async fn next(&mut self) -> Result<Option<oio::Entry>> {
396        self.inner.next().await
397    }
398}
399
400impl<R: oio::BlockingList> oio::BlockingList for ConcurrentLimitWrapper<R> {
401    fn next(&mut self) -> Result<Option<oio::Entry>> {
402        self.inner.next()
403    }
404}
405
406impl<R: oio::Delete> oio::Delete for ConcurrentLimitWrapper<R> {
407    fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
408        self.inner.delete(path, args)
409    }
410
411    async fn flush(&mut self) -> Result<usize> {
412        self.inner.flush().await
413    }
414}
415
416impl<R: oio::BlockingDelete> oio::BlockingDelete for ConcurrentLimitWrapper<R> {
417    fn delete(&mut self, path: &str, args: OpDelete) -> Result<()> {
418        self.inner.delete(path, args)
419    }
420
421    fn flush(&mut self) -> Result<usize> {
422        self.inner.flush()
423    }
424}