Correct number of elements in config list in `batch()` and `abatch()` of `BaseLLM` (#12713)

- **Description:** Correct number of elements in config list in
`batch()` and `abatch()` of `BaseLLM` in case `max_concurrency` is not
None.
- **Issue:** #12643
- **Twitter handle:** @akionux

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/12817/head
Akio Nishimura 11 months ago committed by GitHub
parent 88b506b321
commit c04647bb4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -297,9 +297,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
return [
output
for batch in batches
for i, batch in enumerate(batches)
for output in self.batch(
batch, config=config, return_exceptions=return_exceptions, **kwargs
batch,
config=config[i * max_concurrency : (i + 1) * max_concurrency],
return_exceptions=return_exceptions,
**kwargs,
)
]
@ -340,9 +343,12 @@ class BaseLLM(BaseLanguageModel[str], ABC):
config = [{**c, "max_concurrency": None} for c in config] # type: ignore[misc]
return [
output
for batch in batches
for i, batch in enumerate(batches)
for output in await self.abatch(
batch, config=config, return_exceptions=return_exceptions, **kwargs
batch,
config=config[i * max_concurrency : (i + 1) * max_concurrency],
return_exceptions=return_exceptions,
**kwargs,
)
]

@ -6,6 +6,8 @@ try:
except ImportError:
from sqlalchemy.ext.declarative import declarative_base
import pytest
from langchain.cache import InMemoryCache, SQLAlchemyCache
from langchain.globals import get_llm_cache, set_llm_cache
from langchain.schema import Generation, LLMResult
@ -73,3 +75,22 @@ def test_custom_caching() -> None:
llm_output=None,
)
assert output == expected_output
def test_batch() -> None:
llm = FakeLLM()
output = llm.batch(["foo", "bar", "foo"])
assert output == ["foo"] * 3
output = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
assert output == ["foo"] * 3
@pytest.mark.asyncio
async def test_abatch() -> None:
llm = FakeLLM()
output = await llm.abatch(["foo", "bar", "foo"])
assert output == ["foo"] * 3
output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
assert output == ["foo"] * 3

Loading…
Cancel
Save