core[minor]: Add BaseModel.rate_limiter, RateLimiter abstraction and in-memory implementation (#24669)

This PR creates a rate limiter in the chat model directly and would replace https://github.com/langchain-ai/langchain/pull/21992.

It resolves most of the limitations that the Runnable rate limiter introduced:

1. It's not annoying to apply the rate limiter to existing code: the change can be rolled out at the location where the model is instantiated, rather than at every location where the model is used (which matters if the model is used in different ways in a given application).
2. Batch rate limiting is enforced properly (see the sketch under the example below).
3. The rate limiter works correctly with streaming.
4. The rate limiter is aware of the cache.
5. The rate limiter can take into account information about the inputs to the model (optional inputs, and eventually outputs, can be added down the road).

The only downside is that rate limiting will not be properly reflected in tracing, since we don't have any metadata events for the rate limiter. So the total time spent on a model invocation will be:

* time spent waiting for the rate limiter
* time spent on the actual model request

## Example

```python
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq

# The limiter is attached where the model is instantiated; every invoke,
# batch, and stream call on this model first acquires a token.
groq = ChatGroq(rate_limiter=InMemoryRateLimiter(check_every_n_seconds=1))
groq.invoke('hello')
```
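The constructor also accepts `requests_per_second` and `max_bucket_size` (exercised by the new unit tests below). A slightly fuller sketch, with illustrative values, showing that `batch()` draws from the same token bucket (point 2 above):

```python
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_groq import ChatGroq

rate_limiter = InMemoryRateLimiter(
    requests_per_second=2,      # refill rate of the token bucket
    check_every_n_seconds=0.1,  # how often to poll for an available token
    max_bucket_size=2,          # cap on bursts
)
model = ChatGroq(rate_limiter=rate_limiter)

# Both requests acquire a token from the same bucket before hitting the API,
# so together they are throttled to roughly requests_per_second.
model.batch(["hello", "world"])
```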
Eugene Yurtsev 2024-07-25 23:03:34 -04:00 committed by GitHub
parent c623ae6661
commit 20690db482
7 changed files with 300 additions and 123 deletions

View File

@@ -60,6 +60,7 @@ from langchain_core.pydantic_v1 import (
Field,
root_validator,
)
from langchain_core.rate_limiters import BaseRateLimiter
from langchain_core.runnables import RunnableMap, RunnablePassthrough
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers._streaming import _StreamingCallbackHandler
@@ -210,6 +211,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
rate_limiter: Optional[BaseRateLimiter] = Field(default=None, exclude=True)
"""An optional rate limiter to use for limiting the number of requests."""
@root_validator(pre=True)
def raise_deprecation(cls, values: Dict) -> Dict:
"""Raise deprecation warning if callback_manager is used.
@@ -341,6 +345,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=1,
)
generation: Optional[ChatGenerationChunk] = None
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
try:
for chunk in self._stream(messages, stop=stop, **kwargs):
if chunk.message.id is None:
@@ -412,6 +420,9 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=1,
)
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
generation: Optional[ChatGenerationChunk] = None
try:
async for chunk in self._astream(
@@ -742,6 +753,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _stream not implemented
if type(self)._stream != BaseChatModel._stream and kwargs.pop(
@@ -822,6 +840,13 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
# Apply the rate limiter after checking the cache, since
# we usually don't want to rate limit cache lookups, but
# we do want to rate limit API requests.
if self.rate_limiter:
self.rate_limiter.acquire(blocking=True)
# If stream is not explicitly set, check if implicitly requested by
# astream_events() or astream_log(). Bail out if _astream not implemented
if (

View File

@@ -1,11 +1,4 @@
"""Interface and implementation for time based rate limiters.
This module defines an interface for rate limiting requests based on time.
The interface cannot account for the size of the request or any other factors.
The module also provides an in-memory implementation of the rate limiter.
"""
"""Interface for a rate limiter and an in-memory rate limiter."""
from __future__ import annotations
@@ -14,22 +7,14 @@ import asyncio
import threading
import time
from typing import (
Any,
Optional,
cast,
)
from langchain_core._api import beta
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.base import (
Input,
Output,
Runnable,
)
@beta(message="Introduced in 0.2.24. API subject to change.")
class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
class BaseRateLimiter(abc.ABC):
"""Base class for rate limiters.
Usage of the base limiter is through the acquire and aacquire methods depending
@@ -41,18 +26,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
Current limitations:
- The rate limiter is not designed to work across different processes. It is
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
- Rate limiting information is not surfaced in tracing or callbacks. This means
that the total time it takes to invoke a chat model will encompass both
the time spent waiting for tokens and the time spent making the request.
.. versionadded:: 0.2.24
"""
@@ -95,55 +72,10 @@ class BaseRateLimiter(Runnable[Input, Output], abc.ABC):
True if the tokens were successfully acquired, False otherwise.
"""
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
Returns:
The output of the rate limiter.
"""
def _invoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
self.acquire(blocking=True)
return cast(Output, input)
return self._call_with_config(_invoke, input, config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
"""Invoke the rate limiter. Async version.
This is a blocking call that waits until the given number of tokens are
available.
Args:
input: The input to the rate limiter.
config: The configuration for the rate limiter.
**kwargs: Additional keyword arguments.
"""
async def _ainvoke(input: Input) -> Output:
"""Invoke the rate limiter. Internal function."""
await self.aacquire(blocking=True)
return cast(Output, input)
return await self._acall_with_config(_ainvoke, input, config, **kwargs)
@beta(message="Introduced in 0.2.24. API subject to change.")
class InMemoryRateLimiter(BaseRateLimiter):
"""An in memory rate limiter.
"""An in memory rate limiter based on a token bucket algorithm.
This is an in memory rate limiter, so it cannot rate limit across
different processes.
@@ -168,19 +100,13 @@ class InMemoryRateLimiter(BaseRateLimiter):
an in-memory rate limiter, but it is thread safe.
- The rate limiter only supports time-based rate limiting. It does not take
into account the size of the request or any other factors.
- The current implementation does not handle streaming inputs well and will
consume all inputs even if the rate limit has not been reached. Better support
for streaming inputs will be added in the future.
- When the rate limiter is combined with another runnable via a RunnableSequence,
usage of .batch() or .abatch() will only respect the average rate limit.
There will be bursty behavior as .batch() and .abatch() wait for each step
to complete before starting the next step. One way to mitigate this is to
use batch_as_completed() or abatch_as_completed().
Example:
.. code-block:: python
from langchain_core import InMemoryRateLimiter
from langchain_core.runnables import RunnableLambda, InMemoryRateLimiter
rate_limiter = InMemoryRateLimiter(
@@ -239,7 +165,7 @@ class InMemoryRateLimiter(BaseRateLimiter):
self.check_every_n_seconds = check_every_n_seconds
def _consume(self) -> bool:
"""Consume the given amount of tokens if possible.
"""Try to consume a token.
Returns:
True means that the tokens were consumed, and the caller can proceed to
@@ -317,3 +243,9 @@ class InMemoryRateLimiter(BaseRateLimiter):
while not self._consume():
await asyncio.sleep(self.check_every_n_seconds)
return True
__all__ = [
"BaseRateLimiter",
"InMemoryRateLimiter",
]
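For orientation (not part of the diff): the two-method surface above is all a custom limiter needs. Below is a minimal sketch of a simplified token bucket implementing `BaseRateLimiter`, with the `acquire`/`aacquire` signatures assumed from the call sites in `BaseChatModel` and all other names hypothetical.

```python
import asyncio
import threading
import time

from langchain_core.rate_limiters import BaseRateLimiter


class SimpleTokenBucketLimiter(BaseRateLimiter):
    """Hypothetical limiter: a one-token bucket refilled at a fixed rate."""

    def __init__(self, requests_per_second: float = 1.0, poll_interval: float = 0.05) -> None:
        self.requests_per_second = requests_per_second
        self.poll_interval = poll_interval
        self._tokens = 0.0
        self._last_refill = time.monotonic()
        self._lock = threading.Lock()

    def _consume(self) -> bool:
        """Refill the bucket based on elapsed time, then try to take one token."""
        with self._lock:
            now = time.monotonic()
            self._tokens = min(
                1.0, self._tokens + (now - self._last_refill) * self.requests_per_second
            )
            self._last_refill = now
            if self._tokens >= 1.0:
                self._tokens -= 1.0
                return True
            return False

    def acquire(self, blocking: bool = True) -> bool:
        if not blocking:
            return self._consume()
        while not self._consume():
            time.sleep(self.poll_interval)
        return True

    async def aacquire(self, blocking: bool = True) -> bool:
        if not blocking:
            return self._consume()
        while not self._consume():
            await asyncio.sleep(self.poll_interval)
        return True
```

Any such subclass can be passed to a chat model via `rate_limiter=...`, since `BaseChatModel` only ever calls `acquire` / `aacquire`.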

View File

@@ -43,7 +43,6 @@ from langchain_core.runnables.passthrough import (
RunnablePassthrough,
RunnablePick,
)
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.runnables.router import RouterInput, RouterRunnable
from langchain_core.runnables.utils import (
AddableDict,
@@ -65,7 +64,6 @@ __all__ = [
"ensure_config",
"run_in_executor",
"patch_config",
"InMemoryRateLimiter",
"RouterInput",
"RouterRunnable",
"Runnable",

View File

@@ -0,0 +1,258 @@
import time
from langchain_core.caches import InMemoryCache
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter
def test_rate_limit_invoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
tic = time.time()
model.invoke("foo")
toc = time.time()
# The second time we call the model, we should have 1 extra token
# to proceed immediately.
assert toc - tic < 0.005
# The third time we call the model, we need to wait again for a token
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
async def test_rate_limit_ainvoke() -> None:
"""Add rate limiter."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
),
)
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.1 < toc - tic < 0.2
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# The second time we call the model, we should have 1 extra token
# to proceed immediately.
assert toc - tic < 0.01
# The third time we call the model, we need to wait again for a token
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.1 < toc - tic < 0.2
def test_rate_limit_batch() -> None:
"""Test that batch and stream calls work with rate limiters."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Need 2 tokens to proceed
time_to_fill = 2 / 200.0
tic = time.time()
model.batch(["foo", "foo"])
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert time_to_fill < toc - tic < time_to_fill + 0.01
async def test_rate_limit_abatch() -> None:
"""Test that batch and stream calls work with rate limiters."""
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Need 2 tokens to proceed
time_to_fill = 2 / 200.0
tic = time.time()
await model.abatch(["foo", "foo"])
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert time_to_fill < toc - tic < time_to_fill + 0.01
def test_rate_limit_stream() -> None:
"""Test rate limit by stream."""
model = GenericFakeChatModel(
messages=iter(["hello world", "hello world", "hello world"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=200, check_every_n_seconds=0.01, max_bucket_size=10
),
)
# Check astream
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds
# Second time around we should have 1 token left
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert toc - tic < 0.005 # Slightly smaller than check every n seconds
# Third time around we should have 0 tokens left
tic = time.time()
response = list(model.stream("foo"))
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert 0.01 < toc - tic < 0.02 # Slightly smaller than check every n seconds
async def test_rate_limit_astream() -> None:
"""Test rate limiting astream."""
rate_limiter = InMemoryRateLimiter(
requests_per_second=20, check_every_n_seconds=0.1, max_bucket_size=10
)
model = GenericFakeChatModel(
messages=iter(["hello world", "hello world", "hello world"]),
rate_limiter=rate_limiter,
)
# Check astream
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
assert 0.1 < toc - tic < 0.2
# Second time around we should have 1 token left
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
assert toc - tic < 0.01 # Slightly smaller than check every n seconds
# Third time around we should have 0 tokens left
tic = time.time()
response = [chunk async for chunk in model.astream("foo")]
assert [msg.content for msg in response] == ["hello", " ", "world"]
toc = time.time()
assert 0.1 < toc - tic < 0.2
def test_rate_limit_skips_cache() -> None:
"""Test that rate limiting does not rate limit cache look ups."""
cache = InMemoryCache()
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
),
cache=cache,
)
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
for _ in range(2):
# Cache hits
tic = time.time()
model.invoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert toc - tic < 0.005
# Test verifies that there's only a single key
# Test also verifies that rate_limiter information is not part of the
# cache key
assert list(cache._cache) == [
(
'[{"lc": 1, "type": "constructor", "id": ["langchain", "schema", '
'"messages", '
'"HumanMessage"], "kwargs": {"content": "foo", "type": "human"}}]',
"[('_type', 'generic-fake-chat-model'), ('stop', None)]",
)
]
class SerializableModel(GenericFakeChatModel):
@classmethod
def is_lc_serializable(cls) -> bool:
return True
def test_serialization_with_rate_limiter() -> None:
"""Test model serialization with rate limiter."""
from langchain_core.load import dumps
model = SerializableModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
),
)
serialized_model = dumps(model)
assert InMemoryRateLimiter.__name__ not in serialized_model
async def test_rate_limit_skips_cache_async() -> None:
"""Test that rate limiting does not rate limit cache look ups."""
cache = InMemoryCache()
model = GenericFakeChatModel(
messages=iter(["hello", "world", "!"]),
rate_limiter=InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.01, max_bucket_size=1
),
cache=cache,
)
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert 0.01 < toc - tic < 0.02
for _ in range(2):
# Cache hits
tic = time.time()
await model.ainvoke("foo")
toc = time.time()
# Should be larger than check every n seconds since the token bucket starts
# with 0 tokens.
assert toc - tic < 0.005
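One usage pattern not covered by the tests above, but which the thread-safety note in the limiter docstring suggests should work: sharing a single in-memory limiter across several models so they draw from one budget. A hedged sketch, with illustrative parameter values, reusing the fake chat model from the tests:

```python
from langchain_core.language_models import GenericFakeChatModel
from langchain_core.rate_limiters import InMemoryRateLimiter

# One shared bucket: both models compete for the same tokens.
shared = InMemoryRateLimiter(
    requests_per_second=5, check_every_n_seconds=0.05, max_bucket_size=1
)
model_a = GenericFakeChatModel(messages=iter(["hello"]), rate_limiter=shared)
model_b = GenericFakeChatModel(messages=iter(["world"]), rate_limiter=shared)

model_a.invoke("foo")
model_b.invoke("foo")  # waits for the next token from the same bucket
```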

View File

@@ -5,8 +5,7 @@ import time
import pytest
from freezegun import freeze_time
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.rate_limiter import InMemoryRateLimiter
from langchain_core.rate_limiters import InMemoryRateLimiter
@pytest.fixture
@@ -109,37 +108,3 @@ async def test_async_wait_max_bucket_size() -> None:
# Assert that sync wait can proceed without blocking
# since we have enough tokens
await rate_limiter.aacquire(blocking=True)
def test_add_rate_limiter() -> None:
"""Add rate limiter."""
def foo(x: int) -> int:
"""Return x."""
return x
rate_limiter = InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
)
foo_ = RunnableLambda(foo)
chain = rate_limiter | foo_
assert chain.invoke(1) == 1
async def test_async_add_rate_limiter() -> None:
"""Add rate limiter."""
async def foo(x: int) -> int:
"""Return x."""
return x
rate_limiter = InMemoryRateLimiter(
requests_per_second=100, check_every_n_seconds=0.1, max_bucket_size=10
)
# mypy is unable to follow the type information when
# RunnableLambda is used with an async function
foo_ = RunnableLambda(foo) # type: ignore
chain = rate_limiter | foo_
assert (await chain.ainvoke(1)) == 1

View File

@@ -11,7 +11,6 @@ EXPECTED_ALL = [
"run_in_executor",
"patch_config",
"RouterInput",
"InMemoryRateLimiter",
"RouterRunnable",
"Runnable",
"RunnableSerializable",