mirror of https://github.com/hwchase17/langchain
Add Batch Size kwarg to the llm start callback (#13483)
So you can more easily use the token counts directly from the API endpoint for a batch size of 1.
parent
23566cbea9
commit
163bf165ed
@ -0,0 +1,71 @@
|
||||
"""Test base chat model."""
|
||||
import pytest
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from tests.unit_tests.fake.chat_model import FakeListChatModel
|
||||
|
||||
|
||||
@pytest.fixture
def messages() -> list:
    """Fixture: a two-message conversation (system prompt + one human turn)."""
    system = SystemMessage(content="You are a test user.")
    human = HumanMessage(content="Hello, I am a test user.")
    return [system, human]
|
||||
|
||||
|
||||
@pytest.fixture
def messages_2() -> list:
    """Fixture: a second, distinct conversation used alongside ``messages``."""
    turns = [
        SystemMessage(content="You are a test user."),
    ]
    turns.append(HumanMessage(content="Hello, I not a test user."))
    return turns
|
||||
|
||||
|
||||
def test_batch_size(messages: list, messages_2: list) -> None:
    """Check that sync entry points record ``batch_size`` in each run's extras.

    The base endpoint doesn't support native batching, so every traced run —
    whether produced by ``batch()``, ``invoke()``, or ``stream()`` — is
    expected to carry a ``batch_size`` of 1.
    """
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    with collect_runs() as cb:
        llm.batch([messages, messages_2], {"callbacks": [cb]})
        assert len(cb.traced_runs) == 2
        # Pass a generator (not a list) to all() so it can short-circuit.
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
    with collect_runs() as cb:
        llm.batch([messages], {"callbacks": [cb]})
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        llm.invoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        # Drain the stream; a single streamed generation is still one run.
        list(llm.stream(messages))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
||||
|
||||
|
||||
async def test_async_batch_size(messages: list, messages_2: list) -> None:
    """Check that async entry points record ``batch_size`` in each run's extras.

    The base endpoint doesn't support native batching, so every traced run —
    whether produced by ``abatch()``, ``ainvoke()``, or ``astream()`` — is
    expected to carry a ``batch_size`` of 1.
    """
    llm = FakeListChatModel(responses=[str(i) for i in range(100)])
    with collect_runs() as cb:
        await llm.abatch([messages, messages_2], {"callbacks": [cb]})
        # Pass a generator (not a list) to all() so it can short-circuit.
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 2
    with collect_runs() as cb:
        await llm.abatch([messages], {"callbacks": [cb]})
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 1

    with collect_runs() as cb:
        await llm.ainvoke(messages)
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    with collect_runs() as cb:
        # Drain the async stream; a single streamed generation is still one run.
        async for _ in llm.astream(messages):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue