Add Batch Size kwarg to the llm start callback (#13483)

This makes it easier to use the token counts reported directly by the API endpoint when the batch size is 1.
William FH 2023-11-22 14:47:57 -08:00 committed by GitHub
parent 23566cbea9
commit 163bf165ed
6 changed files with 147 additions and 8 deletions
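In practice, a callback handler now learns at `on_llm_start` time how many prompts share the run, so provider-reported token usage can be attributed to a single prompt whenever the reported batch size is 1. Below is a minimal sketch of such a handler, assuming only the standard `langchain_core` callback interface; the class name and the provider-specific `token_usage` key are illustrative, not part of this change.

```python
from typing import Any, Dict, List
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class TokenAttributionHandler(BaseCallbackHandler):
    """Illustrative handler: only trust per-prompt token counts for single-prompt runs."""

    def __init__(self) -> None:
        self._batch_sizes: Dict[UUID, int] = {}

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        # After this change, the start callback receives batch_size via **kwargs.
        self._batch_sizes[run_id] = kwargs.get("batch_size", len(prompts))

    def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
        usage = (response.llm_output or {}).get("token_usage")  # provider-specific key
        if usage is not None and self._batch_sizes.pop(run_id, None) == 1:
            # Exactly one prompt in this run, so the reported counts belong to it.
            print(f"run {run_id}: {usage.get('total_tokens')} total tokens")
```

Chat models report through `on_chat_model_start` instead; handlers that do not implement it fall back to `on_llm_start`, so the same bookkeeping applies.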

View File

@@ -206,6 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[ChatGenerationChunk] = None
@@ -259,6 +260,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[ChatGenerationChunk] = None
@@ -334,6 +336,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=run_name,
batch_size=len(messages),
)
results = []
for i, m in enumerate(messages):
@@ -396,6 +399,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
invocation_params=params,
options=options,
name=run_name,
batch_size=len(messages),
)
results = await asyncio.gather(

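In the chat-model paths above, the token-by-token `stream`/`astream` branches pass `batch_size=1`, while `generate`/`agenerate` pass the number of message lists. A handler can read that value straight off `on_chat_model_start`; here is a minimal sketch (the class name is illustrative, the types are the standard `langchain_core` ones):

```python
from typing import Any, Dict, List
from uuid import UUID

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class BatchSizeLogger(BaseCallbackHandler):
    """Illustrative handler: print the batch size reported when a chat run starts."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        # stream()/invoke() report 1; generate() over N message lists reports N.
        print(f"chat run {run_id}: batch_size={kwargs.get('batch_size')}")
```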
View File

@@ -382,6 +382,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[GenerationChunk] = None
@@ -433,6 +434,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=config.get("run_name"),
batch_size=1,
)
try:
generation: Optional[GenerationChunk] = None
@@ -645,6 +647,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=run_name,
batch_size=len(prompts),
)[0]
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
@@ -662,6 +665,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=run_name_list[idx],
batch_size=len(missing_prompts),
)[0]
for idx in missing_prompt_idxs
]
@@ -810,6 +814,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=run_name,
batch_size=len(prompts),
)
for callback_manager, prompt, run_name in zip(
callback_managers, prompts, run_name_list
@@ -830,6 +835,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
invocation_params=params,
options=options,
name=run_name_list[idx],
batch_size=len(missing_prompts),
)
for idx in missing_prompt_idxs
]

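For completion-style models the pattern is the same: the streaming paths report 1, `generate`/`agenerate` report `len(prompts)`, and the cache-aware branches report only the prompts actually sent to the model (`len(missing_prompts)`). The value also surfaces on the traced run's `extra` dict, which is what the new tests below assert. A quick illustration of reading it, reusing `collect_runs` and the repo's test-only `FakeListLLM` exactly as those tests do:

```python
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM  # test-only fake, same as in the tests below

llm = FakeListLLM(responses=["foo", "bar"])
with collect_runs() as cb:
    # Two prompts in one batch: each resulting run is started with batch_size=2.
    llm.batch(["first prompt", "second prompt"], {"callbacks": [cb]})
    print([(r.extra or {}).get("batch_size") for r in cb.traced_runs])  # -> [2, 2]
```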
View File

@@ -0,0 +1,71 @@
"""Test base chat model."""

import pytest

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.chat_model import FakeListChatModel


@pytest.fixture
def messages() -> list:
    return [
        SystemMessage(content="You are a test user."),
        HumanMessage(content="Hello, I am a test user."),
    ]


@pytest.fixture
def messages_2() -> list:
    return [
        SystemMessage(content="You are a test user."),
        HumanMessage(content="Hello, I not a test user."),
    ]


def test_batch_size(messages: list, messages_2: list) -> None:
    # The base endpoint doesn't support native batching,
    # so we expect batch_size to always be 1
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    with collect_runs() as cb:
        llm.batch([messages, messages_2], {"callbacks": [cb]})
        assert len(cb.traced_runs) == 2
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
    with collect_runs() as cb:
        llm.batch([messages], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        llm.invoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        list(llm.stream(messages))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size(messages: list, messages_2: list) -> None:
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    # The base endpoint doesn't support native batching,
    # so we expect batch_size to always be 1
    with collect_runs() as cb:
        await llm.abatch([messages, messages_2], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 2
    with collect_runs() as cb:
        await llm.abatch([messages], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        await llm.ainvoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        async for _ in llm.astream(messages):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

View File

@@ -1,3 +1,4 @@
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM
@@ -17,3 +18,60 @@ async def test_abatch() -> None:
    output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert output == ["foo"] * 3


def test_batch_size() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 3
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.batch(["foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.invoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        list(llm.stream("foo"))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
    llm = FakeListLLM(responses=["foo"] * 1)
    with collect_runs() as cb:
        llm.predict("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1


async def test_async_batch_size() -> None:
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 3
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.abatch(["foo"], {"callbacks": [cb]})
        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
        assert len(cb.traced_runs) == 1
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.ainvoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        async for _ in llm.astream("foo"):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

File diff suppressed because one or more lines are too long