core[patch]: testing add chat model for unit-tests (#16209)

This PR adds a fake chat model for testing purposes.

Used in this PR: https://github.com/langchain-ai/langchain/pull/16172
Eugene Yurtsev 6 months ago committed by GitHub
parent 27ad65cc68
commit ecd4f0a7ec
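For context, a rough sketch of how the new fake model is meant to be driven in a unit test (this mirrors the tests added at the bottom of this diff; the snippet itself is not part of the committed code):

from itertools import cycle

from langchain_core.messages import AIMessage
from tests.unit_tests.fake.chat_model import GenericFakeChatModel

# The fake model replays a scripted sequence of AIMessages instead of calling an LLM.
model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello goodbye")]))

# invoke() returns the next scripted message verbatim.
assert model.invoke("any input") == AIMessage(content="hello goodbye")

# stream() splits the same message into whitespace-preserving chunks.
assert [c.content for c in model.stream("any input")] == ["hello", " ", "goodbye"]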

@@ -1,15 +1,21 @@
-"""Fake ChatModel for testing purposes."""
+"""Fake Chat Model wrapper for testing purposes."""
import asyncio
import re
import time
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor


class FakeMessagesListChatModel(BaseChatModel):
@@ -114,3 +120,184 @@ class FakeListChatModel(SimpleChatModel):
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"responses": self.responses}


class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        return "fake response"

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = "fake response"
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {"key": "fake"}
class GenericFakeChatModel(BaseChatModel):
    """A generic fake chat model that can be used to test the chat model interface.

    * Chat model should be usable in both sync and async tests
    * Invokes on_llm_new_token to allow for testing of callback related code for new
      tokens.
    * Includes logic to break messages into message chunks to facilitate testing of
      streaming.
    """

    messages: Iterator[AIMessage]
    """Get an iterator over messages.

    This can be expanded to accept other types like Callables / dicts / strings
    to make the interface more generic if needed.

    Note: if you want to pass a list, you can use `iter` to convert it to an iterator.

    Please note that streaming is not implemented yet. We should try to implement it
    in the future by delegating to invoke and then breaking the resulting output
    into message chunks.
    """
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        message = next(self.messages)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model."""
        chat_result = self._generate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        if not isinstance(chat_result, ChatResult):
            raise ValueError(
                f"Expected generate to return a ChatResult, "
                f"but got {type(chat_result)} instead."
            )

        message = chat_result.generations[0].message

        if not isinstance(message, AIMessage):
            raise ValueError(
                f"Expected invoke to return an AIMessage, "
                f"but got {type(message)} instead."
            )

        content = message.content

        if content:
            # Use a regular expression to split on whitespace with a capture group
            # so that we can preserve the whitespace in the output.
            assert isinstance(content, str)
            content_chunks = cast(List[str], re.split(r"(\s)", content))

            for token in content_chunks:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(token, chunk=chunk)
        if message.additional_kwargs:
            for key, value in message.additional_kwargs.items():
                # We should further break down the additional kwargs into chunks
                # Special case for function call
                if key == "function_call":
                    for fkey, fvalue in value.items():
                        if isinstance(fvalue, str):
                            # Break function call by `,`
                            fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
                            for fvalue_chunk in fvalue_chunks:
                                chunk = ChatGenerationChunk(
                                    message=AIMessageChunk(
                                        content="",
                                        additional_kwargs={
                                            "function_call": {fkey: fvalue_chunk}
                                        },
                                    )
                                )
                                yield chunk
                                if run_manager:
                                    run_manager.on_llm_new_token(
                                        "",
                                        chunk=chunk,  # No token for function call
                                    )
                        else:
                            chunk = ChatGenerationChunk(
                                message=AIMessageChunk(
                                    content="",
                                    additional_kwargs={"function_call": {fkey: fvalue}},
                                )
                            )
                            yield chunk
                            if run_manager:
                                run_manager.on_llm_new_token(
                                    "",
                                    chunk=chunk,  # No token for function call
                                )
                else:
                    chunk = ChatGenerationChunk(
                        message=AIMessageChunk(
                            content="", additional_kwargs={key: value}
                        )
                    )
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(
                            "",
                            chunk=chunk,  # No token for function call
                        )
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Stream the output of the model."""
        result = await run_in_executor(
            None,
            self._stream,
            messages,
            stop=stop,
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        for chunk in result:
            yield chunk

    @property
    def _llm_type(self) -> str:
        return "generic-fake-chat-model"

@@ -0,0 +1,184 @@
"""Tests for verifying that testing utility code works as expected."""
from itertools import cycle
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
response = model.invoke("meow")
assert response == AIMessage(content="hello")
async def test_generic_fake_chat_model_ainvoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
async def test_generic_fake_chat_model_stream() -> None:
    """Test streaming."""
    infinite_cycle = cycle(
        [
            AIMessage(content="hello goodbye"),
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

    chunks = [chunk for chunk in model.stream("meow")]
    assert chunks == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]

    # Test streaming of additional kwargs.
    # Relying on insertion order of the additional kwargs dict
    message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
    model = GenericFakeChatModel(messages=cycle([message]))
    chunks = [chunk async for chunk in model.astream("meow")]
    assert chunks == [
        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
    ]

    message = AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "move_file",
                "arguments": '{\n "source_path": "foo",\n "'
                'destination_path": "bar"\n}',
            }
        },
    )
    model = GenericFakeChatModel(messages=cycle([message]))
    chunks = [chunk async for chunk in model.astream("meow")]

    assert chunks == [
        AIMessageChunk(
            content="", additional_kwargs={"function_call": {"name": "move_file"}}
        ),
        AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '{\n "source_path": "foo"'}
            },
        ),
        AIMessageChunk(
            content="", additional_kwargs={"function_call": {"arguments": ","}}
        ),
        AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
            },
        ),
    ]

    accumulate_chunks = None
    for chunk in chunks:
        if accumulate_chunks is None:
            accumulate_chunks = chunk
        else:
            accumulate_chunks += chunk

    assert accumulate_chunks == AIMessageChunk(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "move_file",
                "arguments": '{\n "source_path": "foo",\n "'
                'destination_path": "bar"\n}',
            }
        },
    )
async def test_generic_fake_chat_model_astream_log() -> None:
    """Test streaming."""
    infinite_cycle = cycle([AIMessage(content="hello goodbye")])
    model = GenericFakeChatModel(messages=infinite_cycle)
    log_patches = [
        log_patch async for log_patch in model.astream_log("meow", diff=False)
    ]
    final = log_patches[-1]
    assert final.state["streamed_output"] == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]
async def test_callback_handlers() -> None:
    """Verify that model is implemented correctly with handlers working."""

    class MyCustomAsyncHandler(AsyncCallbackHandler):
        def __init__(self, store: List[str]) -> None:
            self.store = store

        async def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            *,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            metadata: Optional[Dict[str, Any]] = None,
            **kwargs: Any,
        ) -> Any:
            # Do nothing
            # Required to implement since this is an abstract method
            pass

        async def on_llm_new_token(
            self,
            token: str,
            *,
            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
            run_id: UUID,
            parent_run_id: Optional[UUID] = None,
            tags: Optional[List[str]] = None,
            **kwargs: Any,
        ) -> None:
            self.store.append(token)

    infinite_cycle = cycle(
        [
            AIMessage(content="hello goodbye"),
        ]
    )
    model = GenericFakeChatModel(messages=infinite_cycle)
    tokens: List[str] = []
    # New model
    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
    assert results == [
        AIMessageChunk(content="hello"),
        AIMessageChunk(content=" "),
        AIMessageChunk(content="goodbye"),
    ]
    assert tokens == ["hello", " ", "goodbye"]
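As the `messages` docstring notes, a finite list of replies can also be supplied by wrapping it in `iter`. A small hypothetical test along those lines (not part of this commit):

def test_finite_message_script() -> None:
    # iter() turns the list into the iterator the model expects; once the
    # scripted replies are consumed the iterator is simply exhausted.
    model = GenericFakeChatModel(messages=iter([AIMessage(content="only answer")]))
    assert model.invoke("hi") == AIMessage(content="only answer")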