You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/core/tests/unit_tests/language_models/llms/test_base.py

78 lines
2.8 KiB
Python

from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM
def test_batch() -> None:
    """Batch invocation yields one response per input, with or without a concurrency cap."""
    llm = FakeListLLM(responses=["foo"] * 3)

    # Default (unbounded) concurrency.
    results = llm.batch(["foo", "bar", "foo"])
    assert results == ["foo", "foo", "foo"]

    # Capping concurrency must not change the outputs or their order.
    results = llm.batch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert results == ["foo", "foo", "foo"]
async def test_abatch() -> None:
    """Async batch invocation yields one response per input, with or without a concurrency cap."""
    llm = FakeListLLM(responses=["foo"] * 3)

    # Default (unbounded) concurrency.
    results = await llm.abatch(["foo", "bar", "foo"])
    assert results == ["foo", "foo", "foo"]

    # Capping concurrency must not change the outputs or their order.
    results = await llm.abatch(["foo", "bar", "foo"], config={"max_concurrency": 2})
    assert results == ["foo", "foo", "foo"]
def test_batch_size() -> None:
    """Verify traced runs record the correct ``batch_size`` in their extras.

    Covers ``batch`` (multi- and single-input), ``invoke``, ``stream``, and the
    legacy ``predict`` entry point: every run produced by a call over N inputs
    must carry ``batch_size == N``.
    """
    # Three inputs in one batch -> three runs, each tagged batch_size=3.
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        llm.batch(["foo", "bar", "foo"], {"callbacks": [cb]})
        # Generator inside all() — no throwaway list (ruff C419).
        assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 3

    # A single-input batch still goes through the batch path: batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.batch(["foo"], {"callbacks": [cb]})
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 1

    # invoke() is a single call: batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.invoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    # Streaming a single input produces one run with batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        list(llm.stream("foo"))
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    # predict() is the deprecated entry point; it must still report batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        llm.predict("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
async def test_async_batch_size() -> None:
    """Verify traced runs record the correct ``batch_size`` for async entry points.

    Covers ``abatch`` (multi- and single-input), ``ainvoke``, and ``astream``:
    every run produced by a call over N inputs must carry ``batch_size == N``.
    """
    # Three inputs in one async batch -> three runs, each tagged batch_size=3.
    llm = FakeListLLM(responses=["foo"] * 3)
    with collect_runs() as cb:
        await llm.abatch(["foo", "bar", "foo"], {"callbacks": [cb]})
        # Generator inside all() — no throwaway list (ruff C419).
        assert all((r.extra or {}).get("batch_size") == 3 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 3

    # A single-input async batch still reports batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.abatch(["foo"], {"callbacks": [cb]})
        assert all((r.extra or {}).get("batch_size") == 1 for r in cb.traced_runs)
        assert len(cb.traced_runs) == 1

    # ainvoke() is a single call: batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        await llm.ainvoke("foo")
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

    # Async streaming a single input produces one run with batch_size=1.
    llm = FakeListLLM(responses=["foo"])
    with collect_runs() as cb:
        async for _ in llm.astream("foo"):
            pass
        assert len(cb.traced_runs) == 1
        assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1