diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py
index 717ab02533..98f05b6ca6 100644
--- a/libs/core/tests/unit_tests/fake/chat_model.py
+++ b/libs/core/tests/unit_tests/fake/chat_model.py
@@ -1,15 +1,21 @@
-"""Fake ChatModel for testing purposes."""
+"""Fake Chat Model wrapper for testing purposes."""
 import asyncio
+import re
 import time
-from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union, cast
 
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
-from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+)
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.runnables import run_in_executor
 
 
 class FakeMessagesListChatModel(BaseChatModel):
@@ -114,3 +120,184 @@ class FakeListChatModel(SimpleChatModel):
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         return {"responses": self.responses}
+
+
+class FakeChatModel(SimpleChatModel):
+    """Fake Chat Model wrapper for testing purposes."""
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        return "fake response"
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        output_str = "fake response"
+        message = AIMessage(content=output_str)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake-chat-model"
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return {"key": "fake"}
+
+
+class GenericFakeChatModel(BaseChatModel):
+    """A generic fake chat model that can be used to test the chat model interface.
+
+    * Chat model should be usable in both sync and async tests
+    * Invokes on_llm_new_token to allow for testing of callback related code for new
+      tokens.
+    * Includes logic to break messages into message chunks to facilitate testing of
+      streaming.
+    """
+
+    messages: Iterator[AIMessage]
+    """Get an iterator over messages.
+
+    This can be expanded to accept other types like Callables / dicts / strings
+    to make the interface more generic if needed.
+
+    Note: if you want to pass a list, you can use `iter` to convert it to an iterator.
+
+    Streaming is implemented by delegating to `_generate` and then breaking the
+    resulting message into chunks: content is split on whitespace, and function-call
+    arguments are split on commas, so token callbacks can be exercised in tests.
+    """
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top-level call."""
+        message = next(self.messages)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Stream the output of the model."""
+        chat_result = self._generate(
+            messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        if not isinstance(chat_result, ChatResult):
+            raise ValueError(
+                f"Expected generate to return a ChatResult, "
+                f"but got {type(chat_result)} instead."
+            )
+
+        message = chat_result.generations[0].message
+
+        if not isinstance(message, AIMessage):
+            raise ValueError(
+                f"Expected invoke to return an AIMessage, "
+                f"but got {type(message)} instead."
+            )
+
+        content = message.content
+
+        if content:
+            # Use a regular expression to split on whitespace with a capture group
+            # so that we can preserve the whitespace in the output.
+            assert isinstance(content, str)
+            content_chunks = cast(List[str], re.split(r"(\s)", content))
+
+            for token in content_chunks:
+                chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
+                yield chunk
+                if run_manager:
+                    run_manager.on_llm_new_token(token, chunk=chunk)
+
+        if message.additional_kwargs:
+            for key, value in message.additional_kwargs.items():
+                # We should further break down the additional kwargs into chunks
+                # Special case for function call
+                if key == "function_call":
+                    for fkey, fvalue in value.items():
+                        if isinstance(fvalue, str):
+                            # Break function call by `,`
+                            fvalue_chunks = cast(List[str], re.split(r"(,)", fvalue))
+                            for fvalue_chunk in fvalue_chunks:
+                                chunk = ChatGenerationChunk(
+                                    message=AIMessageChunk(
+                                        content="",
+                                        additional_kwargs={
+                                            "function_call": {fkey: fvalue_chunk}
+                                        },
+                                    )
+                                )
+                                yield chunk
+                                if run_manager:
+                                    run_manager.on_llm_new_token(
+                                        "",
+                                        chunk=chunk,  # No token for function call
+                                    )
+                        else:
+                            chunk = ChatGenerationChunk(
+                                message=AIMessageChunk(
+                                    content="",
+                                    additional_kwargs={"function_call": {fkey: fvalue}},
+                                )
+                            )
+                            yield chunk
+                            if run_manager:
+                                run_manager.on_llm_new_token(
+                                    "",
+                                    chunk=chunk,  # No token for function call
+                                )
+                else:
+                    chunk = ChatGenerationChunk(
+                        message=AIMessageChunk(
+                            content="", additional_kwargs={key: value}
+                        )
+                    )
+                    yield chunk
+                    if run_manager:
+                        run_manager.on_llm_new_token(
+                            "",
+                            chunk=chunk,  # No token for function call
+                        )
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        """Stream the output of the model."""
+        result = await run_in_executor(
+            None,
+            self._stream,
+            messages,
+            stop=stop,
+            run_manager=run_manager.get_sync() if run_manager else None,
+            **kwargs,
+        )
+        for chunk in result:
+            yield chunk
+
+    @property
+    def _llm_type(self) -> str:
+        return "generic-fake-chat-model"
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
new file mode 100644
index 0000000000..8700f0751c
--- /dev/null
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -0,0 +1,184 @@
+"""Tests for verifying that testing utility code works as
+expected."""
+from itertools import cycle
+from typing import Any, Dict, List, Optional, Union
+from uuid import UUID
+
+from langchain_core.callbacks.base import AsyncCallbackHandler
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from tests.unit_tests.fake.chat_model import GenericFakeChatModel
+
+
+def test_generic_fake_chat_model_invoke() -> None:
+    # Will alternate between responding with hello and goodbye
+    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
+    model = GenericFakeChatModel(messages=infinite_cycle)
+    response = model.invoke("meow")
+    assert response == AIMessage(content="hello")
+    response = model.invoke("kitty")
+    assert response == AIMessage(content="goodbye")
+    response = model.invoke("meow")
+    assert response == AIMessage(content="hello")
+
+
+async def test_generic_fake_chat_model_ainvoke() -> None:
+    # Will alternate between responding with hello and goodbye
+    infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
+    model = GenericFakeChatModel(messages=infinite_cycle)
+    response = await model.ainvoke("meow")
+    assert response == AIMessage(content="hello")
+    response = await model.ainvoke("kitty")
+    assert response == AIMessage(content="goodbye")
+    response = await model.ainvoke("meow")
+    assert response == AIMessage(content="hello")
+
+
+async def test_generic_fake_chat_model_stream() -> None:
+    """Test streaming."""
+    infinite_cycle = cycle(
+        [
+            AIMessage(content="hello goodbye"),
+        ]
+    )
+    model = GenericFakeChatModel(messages=infinite_cycle)
+    chunks = [chunk async for chunk in model.astream("meow")]
+    assert chunks == [
+        AIMessageChunk(content="hello"),
+        AIMessageChunk(content=" "),
+        AIMessageChunk(content="goodbye"),
+    ]
+
+    chunks = [chunk for chunk in model.stream("meow")]
+    assert chunks == [
+        AIMessageChunk(content="hello"),
+        AIMessageChunk(content=" "),
+        AIMessageChunk(content="goodbye"),
+    ]
+
+    # Test streaming of additional kwargs.
+    # Relying on insertion order of the additional kwargs dict
+    message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
+    model = GenericFakeChatModel(messages=cycle([message]))
+    chunks = [chunk async for chunk in model.astream("meow")]
+    assert chunks == [
+        AIMessageChunk(content="", additional_kwargs={"foo": 42}),
+        AIMessageChunk(content="", additional_kwargs={"bar": 24}),
+    ]
+
+    message = AIMessage(
+        content="",
+        additional_kwargs={
+            "function_call": {
+                "name": "move_file",
+                "arguments": '{\n "source_path": "foo",\n "'
+                'destination_path": "bar"\n}',
+            }
+        },
+    )
+    model = GenericFakeChatModel(messages=cycle([message]))
+    chunks = [chunk async for chunk in model.astream("meow")]
+
+    assert chunks == [
+        AIMessageChunk(
+            content="", additional_kwargs={"function_call": {"name": "move_file"}}
+        ),
+        AIMessageChunk(
+            content="",
+            additional_kwargs={
+                "function_call": {"arguments": '{\n "source_path": "foo"'}
+            },
+        ),
+        AIMessageChunk(
+            content="", additional_kwargs={"function_call": {"arguments": ","}}
+        ),
+        AIMessageChunk(
+            content="",
+            additional_kwargs={
+                "function_call": {"arguments": '\n "destination_path": "bar"\n}'}
+            },
+        ),
+    ]
+
+    accumulate_chunks = None
+    for chunk in chunks:
+        if accumulate_chunks is None:
+            accumulate_chunks = chunk
+        else:
+            accumulate_chunks += chunk
+
+    assert accumulate_chunks == AIMessageChunk(
+        content="",
+        additional_kwargs={
+            "function_call": {
+                "name": "move_file",
+                "arguments": '{\n "source_path": "foo",\n "'
+                'destination_path": "bar"\n}',
+            }
+        },
+    )
+
+
+async def test_generic_fake_chat_model_astream_log() -> None:
+    """Test streaming via astream_log."""
+    infinite_cycle = cycle([AIMessage(content="hello goodbye")])
+    model = GenericFakeChatModel(messages=infinite_cycle)
+    log_patches = [
+        log_patch async for log_patch in model.astream_log("meow", diff=False)
+    ]
+    final = log_patches[-1]
+    assert final.state["streamed_output"] == [
+        AIMessageChunk(content="hello"),
+        AIMessageChunk(content=" "),
+        AIMessageChunk(content="goodbye"),
+    ]
+
+
+async def test_callback_handlers() -> None:
+    """Verify that the model is implemented correctly with handlers working."""
+
+    class MyCustomAsyncHandler(AsyncCallbackHandler):
+        def __init__(self, store: List[str]) -> None:
+            self.store = store
+
+        async def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List[BaseMessage]],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+        ) -> Any:
+            # Do nothing
+            # Required to implement since this is an abstract method
+            pass
+
+        async def on_llm_new_token(
+            self,
+            token: str,
+            *,
+            chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            **kwargs: Any,
+        ) -> None:
+            self.store.append(token)
+
+    infinite_cycle = cycle(
+        [
+            AIMessage(content="hello goodbye"),
+        ]
+    )
+    model = GenericFakeChatModel(messages=infinite_cycle)
+    tokens: List[str] = []
+    # New model
+    results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
+    assert results == [
+        AIMessageChunk(content="hello"),
+        AIMessageChunk(content=" "),
+        AIMessageChunk(content="goodbye"),
+    ]
+    assert tokens == ["hello", " ", "goodbye"]
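
Usage sketch: a minimal example of how the GenericFakeChatModel introduced in this diff can be driven, mirroring the tests above; the prompt strings and message contents are placeholders, not part of the change.

from itertools import cycle

from langchain_core.messages import AIMessage

from tests.unit_tests.fake.chat_model import GenericFakeChatModel

# `messages` must be an iterator: wrap a finite list with `iter`, or use
# `cycle` when the fake model should keep producing responses indefinitely.
model = GenericFakeChatModel(messages=iter([AIMessage(content="hello world")]))

# invoke() returns the next message from the iterator unchanged.
assert model.invoke("hi") == AIMessage(content="hello world")

# stream() generates the next message and splits its content on whitespace,
# yielding one AIMessageChunk per token (whitespace chunks included).
model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))
assert [chunk.content for chunk in model.stream("hi")] == ["hello", " ", "world"]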