Mirror of https://github.com/hwchase17/langchain (synced 2024-11-04 06:00:26 +00:00)
Add Batch Size kwarg to the llm start callback (#13483)
So you can more easily use the token counts reported directly by the API endpoint when the batch size is 1.
parent 23566cbea9
commit 163bf165ed
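
For illustration only (not part of this commit's diff): a minimal sketch of a callback handler that consumes the new kwarg. The handler name and the print call are invented, the import paths assume langchain_core, and provider-reported usage is assumed to sit under llm_output["token_usage"] as the OpenAI integrations do; on_llm_start / on_chat_model_start / on_llm_end are the standard BaseCallbackHandler hooks.

from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult


class TokenUsagePerRequestHandler(BaseCallbackHandler):
    """Hypothetical handler: only trust endpoint-reported token counts
    when the traced run covers a single input (batch_size == 1)."""

    def __init__(self) -> None:
        self.batch_size: int = 0

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # batch_size is forwarded by BaseLLM as of this commit.
        self.batch_size = kwargs.get("batch_size", len(prompts))

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        # Same kwarg on the BaseChatModel side.
        self.batch_size = kwargs.get("batch_size", len(messages))

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        usage = (response.llm_output or {}).get("token_usage") or {}
        if self.batch_size == 1 and usage:
            # With a single request, the endpoint's counts map 1:1 to this run.
            print("total_tokens:", usage.get("total_tokens"))

Attached via e.g. llm.invoke("foo", config={"callbacks": [handler]}), such a handler would see batch_size == 1 for invoke/stream runs and len(prompts) / len(messages) for the generate-style paths touched below, which is what the new tests assert.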
@@ -206,6 +206,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            batch_size=1,
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -259,6 +260,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            batch_size=1,
         )
         try:
             generation: Optional[ChatGenerationChunk] = None
@@ -334,6 +336,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=run_name,
+            batch_size=len(messages),
         )
         results = []
         for i, m in enumerate(messages):
@@ -396,6 +399,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             invocation_params=params,
             options=options,
             name=run_name,
+            batch_size=len(messages),
         )

         results = await asyncio.gather(
@@ -382,6 +382,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            batch_size=1,
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -433,6 +434,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             invocation_params=params,
             options=options,
             name=config.get("run_name"),
+            batch_size=1,
         )
         try:
             generation: Optional[GenerationChunk] = None
@@ -645,6 +647,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 invocation_params=params,
                 options=options,
                 name=run_name,
+                batch_size=len(prompts),
             )[0]
             for callback_manager, prompt, run_name in zip(
                 callback_managers, prompts, run_name_list
@@ -662,6 +665,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 invocation_params=params,
                 options=options,
                 name=run_name_list[idx],
+                batch_size=len(missing_prompts),
             )[0]
             for idx in missing_prompt_idxs
         ]
@@ -810,6 +814,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 invocation_params=params,
                 options=options,
                 name=run_name,
+                batch_size=len(prompts),
             )
             for callback_manager, prompt, run_name in zip(
                 callback_managers, prompts, run_name_list
@@ -830,6 +835,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
                 invocation_params=params,
                 options=options,
                 name=run_name_list[idx],
+                batch_size=len(missing_prompts),
             )
             for idx in missing_prompt_idxs
         ]
@@ -0,0 +1,71 @@
+"""Test base chat model."""
+import pytest
+
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_core.tracers.context import collect_runs
+from tests.unit_tests.fake.chat_model import FakeListChatModel
+
+
+@pytest.fixture
+def messages() -> list:
+    return [
+        SystemMessage(content="You are a test user."),
+        HumanMessage(content="Hello, I am a test user."),
+    ]
+
+
+@pytest.fixture
+def messages_2() -> list:
+    return [
+        SystemMessage(content="You are a test user."),
+        HumanMessage(content="Hello, I not a test user."),
+    ]
+
+
+def test_batch_size(messages: list, messages_2: list) -> None:
+    # The base endpoint doesn't support native batching,
+    # so we expect batch_size to always be 1
+    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
+    with collect_runs() as cb:
+        llm.batch([messages, messages_2], {"callbacks": [cb]})
+        assert len(cb.traced_runs) == 2
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+    with collect_runs() as cb:
+        llm.batch([messages], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 1
+
+    with collect_runs() as cb:
+        llm.invoke(messages)
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+    with collect_runs() as cb:
+        list(llm.stream(messages))
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_async_batch_size(messages: list, messages_2: list) -> None:
+    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
+    # The base endpoint doesn't support native batching,
+    # so we expect batch_size to always be 1
+    with collect_runs() as cb:
+        await llm.abatch([messages, messages_2], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 2
+    with collect_runs() as cb:
+        await llm.abatch([messages], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 1
+
+    with collect_runs() as cb:
+        await llm.ainvoke(messages)
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+    with collect_runs() as cb:
+        async for _ in llm.astream(messages):
+            pass
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
@@ -1,3 +1,4 @@
+from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.llm import FakeListLLM


@@ -17,3 +18,60 @@ async def test_abatch() -> None:

     output = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
     assert output == ["foo"] * 3
+
+
+def test_batch_size() -> None:
+    llm = FakeListLLM(responses=["foo"] * 3)
+    with collect_runs() as cb:
+        llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 3
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        llm.batch(["foo"], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 1
+
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        llm.invoke("foo")
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        list(llm.stream("foo"))
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+    llm = FakeListLLM(responses=["foo"] * 1)
+    with collect_runs() as cb:
+        llm.predict("foo")
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+
+async def test_async_batch_size() -> None:
+    llm = FakeListLLM(responses=["foo"] * 3)
+    with collect_runs() as cb:
+        await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 3
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        await llm.abatch(["foo"], {"callbacks": [cb]})
+        assert all([(r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs])
+        assert len(cb.traced_runs) == 1
+
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        await llm.ainvoke("foo")
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
+
+    llm = FakeListLLM(responses=["foo"])
+    with collect_runs() as cb:
+        async for _ in llm.astream("foo"):
+            pass
+        assert len(cb.traced_runs) == 1
+        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
File diff suppressed because one or more lines are too long