Mirror of https://github.com/hwchase17/langchain (synced 2024-11-10 01:10:59 +00:00)
core[minor]: Add BaseModel.rate_limiter, RateLimiter abstraction and in-memory implementation (#24669)
This PR adds a rate limiter directly to the chat model and would replace https://github.com/langchain-ai/langchain/pull/21992. It resolves most of the constraints that the Runnable rate limiter introduced:

1. Applying the rate limiter to existing code is painless; the change can be rolled out at the location where the model is instantiated, rather than at every location where the model is used (which would be necessary if the model is used in different ways in a given application).
2. Batch rate limiting is enforced properly.
3. The rate limiter works correctly with streaming.
4. The rate limiter is aware of the cache.
5. The rate limiter can take into account information about the inputs to the model (we can add optional inputs to it down the road, together with outputs).

The only downside is that this information will not be properly reflected in tracing, since we don't emit any metadata events about the rate limiter. The total time reported for a model invocation will therefore include both:

* time spent waiting for the rate limiter
* time spent on the actual model request

## Example

```python
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq

groq = ChatGroq(rate_limiter=InMemoryRateLimiter(check_every_n_seconds=1))
groq.invoke('hello')
```
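Because the limiter lives on the model instance, a single `InMemoryRateLimiter` can also be shared by several models so they draw from one token bucket. The sketch below is illustrative only: it uses the `GenericFakeChatModel` test helper that appears later in this diff instead of a real provider, and the parameter values are arbitrary.

```python
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter

# One shared bucket: roughly 2 requests/second across both models.
shared_limiter = InMemoryRateLimiter(
    requests_per_second=2, check_every_n_seconds=0.1, max_bucket_size=2
)

# Both models consult the same limiter before issuing a request.
model_a = GenericFakeChatModel(messages=iter(["a1", "a2"]), rate_limiter=shared_limiter)
model_b = GenericFakeChatModel(messages=iter(["b1", "b2"]), rate_limiter=shared_limiter)

print(model_a.invoke("hello").content)  # waits for a token, then returns "a1"
print(model_b.invoke("hello").content)  # same bucket, so this call may also wait
```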
This commit is contained in:
parent c623ae6661
commit 20690db482
@@ -60,6 +60,7 @@ from langchain_core.pydantic_v1 import (
    Field,
    root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """[DEPRECATED] Callback manager to add to the run trace."""

    rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
    """An optional rate limiter to use for limiting the number of requests."""

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used.
@@ -341,6 +345,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            batch_size=1,
        )
        generation: Optional[ChatGenerationChunk] = None

        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        try:
            for chunk in self._stream(messages, stop=stop, **kwargs):
                if chunk.message.id is None:
@@ -412,6 +420,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
            batch_size=1,
        )

        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        generation: Optional[ChatGenerationChunk] = None
        try:
            async for chunk in self._astream(
@@ -742,6 +753,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )

        # Apply the rate limiter after checking the cache, since
        # we usually don't want to rate limit cache lookups, but
        # we do want to rate limit API requests.
        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _stream not implemented
        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
@@ -822,6 +840,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )

        # Apply the rate limiter after checking the cache, since
        # we usually don't want to rate limit cache lookups, but
        # we do want to rate limit API requests.
        if self.rate_limiter:
            self.rate_limiter.acquire(blocking=True)

        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _astream not implemented
        if (
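The comment repeated in the two hunks above captures the key design decision: the limiter is only consulted after the cache lookup misses, so cache hits are never throttled. A minimal sketch of that behavior, mirroring the `test_rate_limit_skips_cache` unit test added later in this diff (the timing comparison is illustrative, not a guarantee):

```python
import time

from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter

model = GenericFakeChatModel(
    messages=iter(["hello", "world"]),
    rate_limiter=InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
    ),
    cache=InMemoryCache(),
)

tic = time.time()
model.invoke("foo")   # cache miss: waits for a token, then calls the model
first = time.time() - tic

tic = time.time()
model.invoke("foo")   # cache hit: returned from the cache, limiter is skipped
second = time.time() - tic

assert second < first  # the cached call does not pay the rate-limit wait
```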
@@ -1,11 +1,4 @@
"""Interface and implementation for time based rate limiters.

This module defines an interface for rate limiting requests based on time.

The interface cannot account for the size of the request or any other factors.

The module also provides an in-memory implementation of the rate limiter.
"""
"""Interface for a rate limiter and an in-memory rate limiter."""

from __future__ import annotations
@@ -14,22 +7,14 @@ import asyncio
import threading
import time
from typing import (
    Any,
    Optional,
    cast,
)

from langchain_core._api import beta
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import (
    Input,
    Output,
    Runnable,
)


@beta(message="Introduced in 0.2.24. API subject to change.")
class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
class BaseRateLimiter(abc.ABC):
    """Base class for rate limiters.

    Usage of the base limiter is through the acquire and aacquire methods depending
@@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):

    Current limitations:

    - The rate limiter is not designed to work across different processes. It is
      an in-memory rate limiter, but it is thread safe.
    - The rate limiter only supports time-based rate limiting. It does not take
      into account the size of the request or any other factors.
    - The current implementation does not handle streaming inputs well and will
      consume all inputs even if the rate limit has not been reached. Better support
      for streaming inputs will be added in the future.
    - When the rate limiter is combined with another runnable via a RunnableSequence,
      usage of .batch() or .abatch() will only respect the average rate limit.
      There will be bursty behavior as .batch() and .abatch() wait for each step
      to complete before starting the next step. One way to mitigate this is to
      use batch_as_completed() or abatch_as_completed().
    - Rate limiting information is not surfaced in tracing or callbacks. This means
      that the total time it takes to invoke a chat model will encompass both
      the time spent waiting for tokens and the time spent making the request.

    .. versionadded:: 0.2.24
    """
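For readers who want to plug in their own strategy, the abstract surface described in the docstring above is small: implement `acquire` and `aacquire`. The sketch below is a hypothetical no-op subclass; the exact signatures (a `blocking` keyword and a `bool` return) are inferred from the docstrings in this diff, so treat them as an assumption rather than the canonical API.

```python
import asyncio

from langchain_core.rate_limiters import BaseRateLimiter


class UnlimitedRateLimiter(BaseRateLimiter):
    """Hypothetical limiter that always grants a token (e.g. for tests)."""

    def acquire(self, *, blocking: bool = True) -> bool:
        # Always report that a token was acquired.
        return True

    async def aacquire(self, *, blocking: bool = True) -> bool:
        # Async variant; yield control once so the event loop is not starved.
        await asyncio.sleep(0)
        return True
```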
@@ -95,55 +72,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
            True if the tokens were successfully acquired, False otherwise.
        """

    def invoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the rate limiter.

        This is a blocking call that waits until the given number of tokens are
        available.

        Args:
            input: The input to the rate limiter.
            config: The configuration for the rate limiter.
            **kwargs: Additional keyword arguments.

        Returns:
            The output of the rate limiter.
        """

        def _invoke(input: Input) -> Output:
            """Invoke the rate limiter. Internal function."""
            self.acquire(blocking=True)
            return cast(Output, input)

        return self._call_with_config(_invoke, input, config, **kwargs)

    async def ainvoke(
        self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> Output:
        """Invoke the rate limiter. Async version.

        This is a blocking call that waits until the given number of tokens are
        available.

        Args:
            input: The input to the rate limiter.
            config: The configuration for the rate limiter.
            **kwargs: Additional keyword arguments.
        """

        async def _ainvoke(input: Input) -> Output:
            """Invoke the rate limiter. Internal function."""
            await self.aacquire(blocking=True)
            return cast(Output, input)

        return await self._acall_with_config(_ainvoke, input, config, **kwargs)


@beta(message="Introduced in 0.2.24. API subject to change.")
class InMemoryRateLimiter(BaseRateLimiter):
    """An in memory rate limiter.
    """An in memory rate limiter based on a token bucket algorithm.

    This is an in memory rate limiter, so it cannot rate limit across
    different processes.
@@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter):
      an in-memory rate limiter, but it is thread safe.
    - The rate limiter only supports time-based rate limiting. It does not take
      into account the size of the request or any other factors.
    - The current implementation does not handle streaming inputs well and will
      consume all inputs even if the rate limit has not been reached. Better support
      for streaming inputs will be added in the future.
    - When the rate limiter is combined with another runnable via a RunnableSequence,
      usage of .batch() or .abatch() will only respect the average rate limit.
      There will be bursty behavior as .batch() and .abatch() wait for each step
      to complete before starting the next step. One way to mitigate this is to
      use batch_as_completed() or abatch_as_completed().

    Example:

        .. code-block:: python

            from langchain_core import InMemoryRateLimiter

            from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter

            rate_limiter = InMemoryRateLimiter(
@@ -239,7 +165,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
        self.check_every_n_seconds = check_every_n_seconds

    def _consume(self) -> bool:
        """Consume the given amount of tokens if possible.
        """Try to consume a token.

        Returns:
            True means that the tokens were consumed, and the caller can proceed to
@@ -317,3 +243,9 @@ class InMemoryRateLimiter(BaseRateLimiter):
        while not self._consume():
            await asyncio.sleep(self.check_every_n_seconds)
        return True


__all__ = [
    "BaseRateLimiter",
    "InMemoryRateLimiter",
]
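The new docstring describes `InMemoryRateLimiter` as a token bucket: tokens accumulate over time at `requests_per_second`, capped at `max_bucket_size`, and each request consumes one. The snippet below is a stripped-down, hypothetical illustration of that idea, not the actual `_consume` implementation from this commit.

```python
import threading
import time


class TinyTokenBucket:
    """Illustrative token bucket (not the langchain_core implementation)."""

    def __init__(self, requests_per_second: float, max_bucket_size: float) -> None:
        self.rate = requests_per_second
        self.max_tokens = max_bucket_size
        self.tokens = 0.0
        self.last = time.monotonic()
        self._lock = threading.Lock()

    def try_consume(self) -> bool:
        """Refill based on elapsed time, then take one token if available."""
        with self._lock:
            now = time.monotonic()
            self.tokens = min(self.max_tokens, self.tokens + (now - self.last) * self.rate)
            self.last = now
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False


bucket = TinyTokenBucket(requests_per_second=2, max_bucket_size=2)
while not bucket.try_consume():  # same spirit as InMemoryRateLimiter.acquire
    time.sleep(0.05)             # poll, akin to check_every_n_seconds
print("token acquired")
```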
@@ -43,7 +43,6 @@ from langchain_core.runnables.passthrough import (
    RunnablePassthrough,
    RunnablePick,
)
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
    AddableDict,
@@ -65,7 +64,6 @@ __all__ = [
    "ensure_config",
    "run_in_executor",
    "patch_config",
    "InMemoryRateLimiter",
    "RouterInput",
    "RouterRunnable",
    "Runnable",
@@ -0,0 +1,258 @@
import time

from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter


def test_rate_limit_invoke() -> None:
    """Add rate limiter."""

    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # The second time we call the model, we should have 1 extra token
    # to proceed immediately.
    assert toc - tic < 0.005

    # The third time we call the model, we need to wait again for a token
    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02


async def test_rate_limit_ainvoke() -> None:
    """Add rate limiter."""

    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
        ),
    )
    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.1 < toc - tic < 0.2

    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # The second time we call the model, we should have 1 extra token
    # to proceed immediately.
    assert toc - tic < 0.01

    # The third time we call the model, we need to wait again for a token
    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.1 < toc - tic < 0.2

def test_rate_limit_batch() -> None:
    """Test that batch and stream calls work with rate limiters."""
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Need 2 tokens to proceed
    time_to_fill = 2 / 200.0
    tic = time.time()
    model.batch(["foo", "foo"])
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert time_to_fill < toc - tic < time_to_fill + 0.01


async def test_rate_limit_abatch() -> None:
    """Test that batch and stream calls work with rate limiters."""
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Need 2 tokens to proceed
    time_to_fill = 2 / 200.0
    tic = time.time()
    await model.abatch(["foo", "foo"])
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert time_to_fill < toc - tic < time_to_fill + 0.01


def test_rate_limit_stream() -> None:
    """Test rate limit by stream."""
    model = GenericFakeChatModel(
        messages=iter(["hello world", "hello world", "hello world"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
        ),
    )
    # Check astream
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert 0.01 < toc - tic < 0.02  # Slightly smaller than check every n seconds

    # Second time around we should have 1 token left
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert toc - tic < 0.005  # Slightly smaller than check every n seconds

    # Third time around we should have 0 tokens left
    tic = time.time()
    response = list(model.stream("foo"))
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert 0.01 < toc - tic < 0.02  # Slightly smaller than check every n seconds

async def test_rate_limit_astream() -> None:
    """Test rate limiting astream."""
    rate_limiter = InMemoryRateLimiter(
        requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
    )
    model = GenericFakeChatModel(
        messages=iter(["hello world", "hello world", "hello world"]),
        rate_limiter=rate_limiter,
    )
    # Check astream
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    assert 0.1 < toc - tic < 0.2

    # Second time around we should have 1 token left
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    assert toc - tic < 0.01  # Slightly smaller than check every n seconds

    # Third time around we should have 0 tokens left
    tic = time.time()
    response = [chunk async for chunk in model.astream("foo")]
    assert [msg.content for msg in response] == ["hello", " ", "world"]
    toc = time.time()
    assert 0.1 < toc - tic < 0.2


def test_rate_limit_skips_cache() -> None:
    """Test that rate limiting does not rate limit cache look ups."""
    cache = InMemoryCache()
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
        ),
        cache=cache,
    )

    tic = time.time()
    model.invoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    for _ in range(2):
        # Cache hits
        tic = time.time()
        model.invoke("foo")
        toc = time.time()
        # Should be larger than check every n seconds since the token bucket starts
        # with 0 tokens.
        assert toc - tic < 0.005

    # Test verifies that there's only a single key
    # Test also verifies that rate_limiter information is not part of the
    # cache key
    assert list(cache._cache) == [
        (
            '[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", '
            '"messages", '
            '"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]',
            "[('_type', 'generic-fake-chat-model'), ('stop', None)]",
        )
    ]

class SerializableModel(GenericFakeChatModel):
    @classmethod
    def is_lc_serializable(cls) -> bool:
        return True


def test_serialization_with_rate_limiter() -> None:
    """Test model serialization with rate limiter."""
    from langchain_core.load import dumps

    model = SerializableModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
        ),
    )
    serialized_model = dumps(model)
    assert InMemoryRateLimiter.__name__ not in serialized_model


async def test_rate_limit_skips_cache_async() -> None:
    """Test that rate limiting does not rate limit cache look ups."""
    cache = InMemoryCache()
    model = GenericFakeChatModel(
        messages=iter(["hello", "world", "!"]),
        rate_limiter=InMemoryRateLimiter(
            requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
        ),
        cache=cache,
    )

    tic = time.time()
    await model.ainvoke("foo")
    toc = time.time()
    # Should be larger than check every n seconds since the token bucket starts
    # with 0 tokens.
    assert 0.01 < toc - tic < 0.02

    for _ in range(2):
        # Cache hits
        tic = time.time()
        await model.ainvoke("foo")
        toc = time.time()
        # Should be larger than check every n seconds since the token bucket starts
        # with 0 tokens.
        assert toc - tic < 0.005
@@ -5,8 +5,7 @@ import time
import pytest
from freezegun import freeze_time

from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.rate_limiters import InMemoryRateLimiter


@pytest.fixture
@@ -109,37 +108,3 @@ async def test_async_wait_max_bucket_size() -> None:
    # Assert that sync wait can proceed without blocking
    # since we have enough tokens
    await rate_limiter.aacquire(blocking=True)


def test_add_rate_limiter() -> None:
    """Add rate limiter."""

    def foo(x: int) -> int:
        """Return x."""
        return x

    rate_limiter = InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
    )

    foo_ = RunnableLambda(foo)
    chain = rate_limiter | foo_
    assert chain.invoke(1) == 1


async def test_async_add_rate_limiter() -> None:
    """Add rate limiter."""

    async def foo(x: int) -> int:
        """Return x."""
        return x

    rate_limiter = InMemoryRateLimiter(
        requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
    )

    # mypy is unable to follow the type information when
    # RunnableLambda is used with an async function
    foo_ = RunnableLambda(foo)  # type: ignore
    chain = rate_limiter | foo_
    assert (await chain.ainvoke(1)) == 1
@@ -11,7 +11,6 @@ EXPECTED_ALL = [
    "run_in_executor",
    "patch_config",
    "RouterInput",
    "InMemoryRateLimiter",
    "RouterRunnable",
    "Runnable",
    "RunnableSerializable",