You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/libs/core/tests/unit_tests/fake/test_fake_chat_model.py

194 lines
6.5 KiB
Python

"""Tests for verifying that testing utility code works as expected."""
from itertools import cycle
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
response = model.invoke("meow")
assert response == AIMessage(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
async def test_generic_fake_chat_model_stream() -> None:
"""Test streaming."""
infinite_cycle = cycle(
[
AIMessage(content="hello goodbye"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
]
chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
]
# Test streaming of additional kwargs.
# Relying on insertion order of the additional kwargs dict
message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
]
message = AIMessage(
content="",
additional_kwargs={
"function_call": {
"name": "move_file",
"arguments": '{\n "source_path": "foo",\n "'
'destination_path": "bar"\n}',
}
},
)
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
},
),
AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
},
),
]
accumulate_chunks = None
for chunk in chunks:
if accumulate_chunks is None:
accumulate_chunks = chunk
else:
accumulate_chunks += chunk
assert accumulate_chunks == AIMessageChunk(
content="",
additional_kwargs={
"function_call": {
"name": "move_file",
"arguments": '{\n "source_path": "foo",\n "'
'destination_path": "bar"\n}',
}
},
)
async def test_generic_fake_chat_model_astream_log() -> None:
"""Test streaming."""
infinite_cycle = cycle([AIMessage(content="hello goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
log_patches = [
log_patch async for log_patch in model.astream_log("meow", diff=False)
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
]
async def test_callback_handlers() -> None:
"""Verify that model is implemented correctly with handlers working."""
class MyCustomAsyncHandler(AsyncCallbackHandler):
def __init__(self, store: List[str]) -> None:
self.store = store
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
# Do nothing
# Required to implement since this is an abstract method
pass
async def on_llm_new_token(
self,
token: str,
*,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
self.store.append(token)
infinite_cycle = cycle(
[
AIMessage(content="hello goodbye"),
]
)
model = GenericFakeChatModel(messages=infinite_cycle)
tokens: List[str] = []
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == HumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")