From 4ee47926cafba0eb00851972783c1d66236f6f00 Mon Sep 17 00:00:00 2001
From: Zander Chase <130414180+vowelparrot@users.noreply.github.com>
Date: Thu, 11 May 2023 11:06:39 -0700
Subject: [PATCH] Add on_chat_model_start (#4499)

### Add on_chat_model_start to callback manager and base tracer

Goal: trace chat messages directly so runs can be reloaded as chat
messages (stored in an integration-agnostic way).

Adds an `on_chat_model_start` method. Handlers that do not implement it
fall back to `on_llm_start()` with the message batches rendered as
strings, so existing handlers keep working unchanged (i.e. this does
not break backwards compatibility, for now).
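Usage sketch (illustrative only; neither handler class below ships with
this PR): a handler that implements the new callback receives the
structured message batches, while a v1-style handler is transparently
re-dispatched to `on_llm_start`:

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, HumanMessage


class ChatAwareHandler(BaseCallbackHandler):
    """Illustrative handler that opts into the new callback."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # Receives the raw message batches instead of flattened strings.
        print(f"chat model start: {len(messages)} message batch(es)")


class LegacyHandler(BaseCallbackHandler):
    """Illustrative v1-style handler with no on_chat_model_start; the
    callback manager catches the resulting NotImplementedError and
    re-dispatches to on_llm_start with get_buffer_string()-rendered
    prompts."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> Any:
        print(f"llm start (fallback prompts): {prompts}")


# Requires OpenAI credentials; any BaseChatModel behaves the same way.
chat = ChatOpenAI(callbacks=[ChatAwareHandler(), LegacyHandler()])
chat([HumanMessage(content="hello")])  # both handlers fire
```

Tracers that understand messages (e.g. `LangChainTracerV2` in this diff)
implement `on_chat_model_start` directly and store the message dicts;
all other handlers keep the old string-based behavior.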
---
 langchain/callbacks/base.py              | 40 +++++++-
 langchain/callbacks/manager.py           | 99 ++++++++++++++++++-
 langchain/callbacks/tracers/langchain.py | 39 +++++++-
 langchain/callbacks/tracers/schemas.py   |  2 +-
 langchain/chat_models/base.py            | 11 +--
 langchain/client/langchain.py            | 26 ++---
 langchain/client/utils.py                | 42 --------
 .../callbacks/fake_callback_handler.py   | 25 ++++-
 .../callbacks/tracers/test_tracer.py     | 28 ++++++
 tests/unit_tests/client/test_utils.py    | 70 -------------
 tests/unit_tests/llms/fake_chat_model.py | 32 ++++++
 tests/unit_tests/llms/test_callbacks.py  | 34 ++++++-
 12 files changed, 308 insertions(+), 140 deletions(-)
 delete mode 100644 langchain/client/utils.py
 delete mode 100644 tests/unit_tests/client/test_utils.py
 create mode 100644 tests/unit_tests/llms/fake_chat_model.py

diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py
index d9f63b73..cd0f0d77 100644
--- a/langchain/callbacks/base.py
+++ b/langchain/callbacks/base.py
@@ -4,7 +4,12 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+)
 
 
 class LLMManagerMixin:
@@ -123,6 +128,20 @@
     ) -> Any:
         """Run when LLM starts running."""
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -184,6 +203,11 @@ class BaseCallbackHandler(
         """Whether to ignore agent callbacks."""
         return False
 
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return False
+
 
 class AsyncCallbackHandler(BaseCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
@@ -199,6 +223,20 @@
     ) -> None:
         """Run when LLM starts running."""
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     async def on_llm_new_token(
         self,
         token: str,
diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py
index 745f6dd2..4622f2a3 100644
--- a/langchain/callbacks/manager.py
+++ b/langchain/callbacks/manager.py
@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import logging
 import os
 import warnings
 from contextlib import contextmanager
@@ -22,8 +23,15 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.tracers.base import TracerSession
 from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
 from langchain.callbacks.tracers.schemas import TracerSessionV2
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+    get_buffer_string,
+)
 
+logger = logging.getLogger(__name__)
 Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
 
 openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
@@ -87,15 +95,31 @@ def _handle_event(
     *args: Any,
     **kwargs: Any,
 ) -> None:
+    """Generic event handler for CallbackManager."""
+    message_strings: Optional[List[str]] = None
     for handler in handlers:
         try:
             if ignore_condition_name is None or not getattr(
                 handler, ignore_condition_name
             ):
                 getattr(handler, event_name)(*args, **kwargs)
+        except NotImplementedError as e:
+            if event_name == "on_chat_model_start":
+                if message_strings is None:
+                    message_strings = [get_buffer_string(m) for m in args[1]]
+                _handle_event(
+                    [handler],
+                    "on_llm_start",
+                    "ignore_llm",
+                    args[0],
+                    message_strings,
+                    *args[2:],
+                    **kwargs,
+                )
+            else:
+                logger.warning(f"Error in {event_name} callback: {e}")
         except Exception as e:
-            # TODO: switch this to use logging
-            print(f"Error in {event_name} callback: {e}")
+            logger.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event_for_handler(
@@ -114,9 +138,22 @@
             await asyncio.get_event_loop().run_in_executor(
                 None, functools.partial(event, *args, **kwargs)
             )
+    except NotImplementedError as e:
+        if event_name == "on_chat_model_start":
+            message_strings = [get_buffer_string(m) for m in args[1]]
+            await _ahandle_event_for_handler(
+                handler,
+                "on_llm_start",
+                "ignore_llm",
+                args[0],
+                message_strings,
+                *args[2:],
+                **kwargs,
+            )
+        else:
+            logger.warning(f"Error in {event_name} callback: {e}")
     except Exception as e:
-        # TODO: switch this to use logging
-        print(f"Error in {event_name} callback: {e}")
+        logger.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event(
@@ -531,6 +568,33 @@ class CallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> CallbackManagerForLLMRun:
+        """Run when a chat model starts running."""
+        if run_id is None:
+            run_id = uuid4()
+        _handle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
+            parent_run_id=self.parent_run_id,
+            **kwargs,
+        )
+
+        # Re-use the LLM run manager since the outputs are treated
+        # the same for now
+        return CallbackManagerForLLMRun(
+            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -629,6 +693,31 @@ class AsyncCallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        if run_id is None:
+            run_id = uuid4()
+
+        await _ahandle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
parent_run_id=self.parent_run_id, + **kwargs, + ) + + return AsyncCallbackManagerForLLMRun( + run_id, self.handlers, self.inheritable_handlers, self.parent_run_id + ) + async def on_chain_start( self, serialized: Dict[str, Any], diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 1d581d35..65bddfcc 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import os +from datetime import datetime from typing import Any, Dict, List, Optional, Union from uuid import UUID, uuid4 @@ -19,6 +20,7 @@ from langchain.callbacks.tracers.schemas import ( TracerSessionV2, TracerSessionV2Create, ) +from langchain.schema import BaseMessage, messages_to_dict from langchain.utils import raise_for_status_with_text @@ -193,6 +195,36 @@ class LangChainTracerV2(LangChainTracer): """Load the default tracing session and set it as the Tracer's session.""" return self.load_session("default") + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> None: + """Start a trace for an LLM run.""" + if self.session is None: + self.session = self.load_default_session() + + run_id_ = str(run_id) + parent_run_id_ = str(parent_run_id) if parent_run_id else None + + execution_order = self._get_execution_order(parent_run_id_) + llm_run = LLMRun( + uuid=run_id_, + parent_uuid=parent_run_id_, + serialized=serialized, + prompts=[], + extra={**kwargs, "messages": messages}, + start_time=datetime.utcnow(), + execution_order=execution_order, + child_execution_order=execution_order, + session_id=self.session.id, + ) + self._start_trace(llm_run) + def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate: """Convert a run to a Run.""" session = self.session or self.load_default_session() @@ -201,7 +233,12 @@ class LangChainTracerV2(LangChainTracer): child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = [] if isinstance(run, LLMRun): run_type = "llm" - inputs = {"prompts": run.prompts} + if run.extra is not None and "messages" in run.extra: + messages: List[List[BaseMessage]] = run.extra.pop("messages") + converted_messages = [messages_to_dict(batch) for batch in messages] + inputs = {"messages": converted_messages} + else: + inputs = {"prompts": run.prompts} outputs = run.response.dict() if run.response else {} child_runs = [] elif isinstance(run, ChainRun): diff --git a/langchain/callbacks/tracers/schemas.py b/langchain/callbacks/tracers/schemas.py index f38094ae..39a0ab0e 100644 --- a/langchain/callbacks/tracers/schemas.py +++ b/langchain/callbacks/tracers/schemas.py @@ -117,6 +117,7 @@ class RunBase(BaseModel): session_id: UUID reference_example_id: Optional[UUID] run_type: RunTypeEnum + parent_run_id: Optional[UUID] class RunCreate(RunBase): @@ -130,7 +131,6 @@ class Run(RunBase): """Run schema when loading from the DB.""" name: str - parent_run_id: Optional[UUID] ChainRun.update_forward_refs() diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index b41aa237..3a8bfb1a 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -24,7 +24,6 @@ from langchain.schema import ( HumanMessage, LLMResult, PromptValue, - get_buffer_string, ) @@ -69,9 +68,8 @@ class BaseChatModel(BaseLanguageModel, ABC): callback_manager = CallbackManager.configure( callbacks, self.callbacks, self.verbose ) - 
message_strings = [get_buffer_string(m) for m in messages] - run_manager = callback_manager.on_llm_start( - {"name": self.__class__.__name__}, message_strings + run_manager = callback_manager.on_chat_model_start( + {"name": self.__class__.__name__}, messages ) new_arg_supported = inspect.signature(self._generate).parameters.get( @@ -104,9 +102,8 @@ class BaseChatModel(BaseLanguageModel, ABC): callback_manager = AsyncCallbackManager.configure( callbacks, self.callbacks, self.verbose ) - message_strings = [get_buffer_string(m) for m in messages] - run_manager = await callback_manager.on_llm_start( - {"name": self.__class__.__name__}, message_strings + run_manager = await callback_manager.on_chat_model_start( + {"name": self.__class__.__name__}, messages ) new_arg_supported = inspect.signature(self._agenerate).parameters.get( diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index 5e01e190..e197f055 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -31,9 +31,8 @@ from langchain.callbacks.tracers.langchain import LangChainTracerV2 from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate -from langchain.client.utils import parse_chat_messages from langchain.llms.base import BaseLLM -from langchain.schema import ChatResult, LLMResult +from langchain.schema import ChatResult, LLMResult, messages_from_dict from langchain.utils import raise_for_status_with_text, xor_args if TYPE_CHECKING: @@ -96,7 +95,6 @@ class LangChainPlusClient(BaseSettings): "Unable to get seeded tenant ID. Please manually provide." ) from e results: List[dict] = response.json() - breakpoint() if len(results) == 0: raise ValueError("No seeded tenant found") return results[0]["id"] @@ -296,13 +294,15 @@ class LangChainPlusClient(BaseSettings): langchain_tracer: LangChainTracerV2, ) -> Union[LLMResult, ChatResult]: if isinstance(llm, BaseLLM): + if "prompts" not in inputs: + raise ValueError(f"LLM Run requires 'prompts' input. Got {inputs}") llm_prompts: List[str] = inputs["prompts"] llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer]) elif isinstance(llm, BaseChatModel): - chat_prompts: List[str] = inputs["prompts"] - messages = [ - parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts - ] + if "messages" not in inputs: + raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}") + raw_messages: List[List[dict]] = inputs["messages"] + messages = [messages_from_dict(batch) for batch in raw_messages] llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer]) else: raise ValueError(f"Unsupported LLM type {type(llm)}") @@ -454,13 +454,17 @@ class LangChainPlusClient(BaseSettings): ) -> Union[LLMResult, ChatResult]: """Run the language model on the example.""" if isinstance(llm, BaseLLM): + if "prompts" not in inputs: + raise ValueError(f"LLM Run must contain 'prompts' key. Got {inputs}") llm_prompts: List[str] = inputs["prompts"] llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer]) elif isinstance(llm, BaseChatModel): - chat_prompts: List[str] = inputs["prompts"] - messages = [ - parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts - ] + if "messages" not in inputs: + raise ValueError( + f"Chat Model Run must contain 'messages' key. 
Got {inputs}" + ) + raw_messages: List[List[dict]] = inputs["messages"] + messages = [messages_from_dict(batch) for batch in raw_messages] llm_output = llm.generate(messages, callbacks=[langchain_tracer]) else: raise ValueError(f"Unsupported LLM type {type(llm)}") diff --git a/langchain/client/utils.py b/langchain/client/utils.py deleted file mode 100644 index f7ce264c..00000000 --- a/langchain/client/utils.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Client Utils.""" -import re -from typing import Dict, List, Optional, Sequence, Type, Union - -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) - -_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]] -_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = { - "Human": HumanMessage, - "AI": AIMessage, - "System": SystemMessage, -} - - -def parse_chat_messages( - input_text: str, roles: Optional[Sequence[str]] = None -) -> List[BaseMessage]: - """Parse chat messages from a string. This is not robust.""" - roles = roles or ["Human", "AI", "System"] - roles_pattern = "|".join(roles) - pattern = ( - rf"(?P{roles_pattern}): (?P" - rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)" - ) - matches = re.finditer(pattern, input_text, re.MULTILINE) - - results: List[BaseMessage] = [] - for match in matches: - entity = match.group("entity") - message = match.group("message").rstrip("\n") - if entity in _RESOLUTION_MAP: - results.append(_RESOLUTION_MAP[entity](content=message)) - else: - results.append(ChatMessage(role=entity, content=message)) - - return results diff --git a/tests/unit_tests/callbacks/fake_callback_handler.py b/tests/unit_tests/callbacks/fake_callback_handler.py index ef2b8171..e167a70f 100644 --- a/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/tests/unit_tests/callbacks/fake_callback_handler.py @@ -1,9 +1,12 @@ """A fake callback handler for testing purposes.""" -from typing import Any +from itertools import chain +from typing import Any, Dict, List, Optional +from uuid import UUID from pydantic import BaseModel from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler +from langchain.schema import BaseMessage class BaseFakeCallbackHandler(BaseModel): @@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel): ignore_llm_: bool = False ignore_chain_: bool = False ignore_agent_: bool = False + ignore_chat_model_: bool = False # add finer-grained counters for easier debugging of failing tests chain_starts: int = 0 @@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel): tool_ends: int = 0 agent_actions: int = 0 agent_ends: int = 0 + chat_model_starts: int = 0 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): @@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): self.llm_streams += 1 def on_chain_start_common(self) -> None: + print("CHAIN START") self.chain_starts += 1 self.starts += 1 @@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): self.errors += 1 def on_agent_action_common(self) -> None: + print("AGENT ACTION") self.agent_actions += 1 self.starts += 1 @@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler): self.agent_ends += 1 self.ends += 1 + def on_chat_model_start_common(self) -> None: + print("STARTING CHAT MODEL") + self.chat_model_starts += 1 + self.starts += 1 + def on_text_common(self) -> None: self.text += 1 @@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): return self +class 
FakeCallbackHandlerWithChatStart(FakeCallbackHandler): + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + *, + run_id: UUID, + parent_run_id: Optional[UUID] = None, + **kwargs: Any, + ) -> Any: + assert all(isinstance(m, BaseMessage) for m in chain(*messages)) + self.on_chat_model_start_common() + + class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin): """Fake async callback handler for testing.""" diff --git a/tests/unit_tests/callbacks/tracers/test_tracer.py b/tests/unit_tests/callbacks/tracers/test_tracer.py index 488d1f70..5c0c4b11 100644 --- a/tests/unit_tests/callbacks/tracers/test_tracer.py +++ b/tests/unit_tests/callbacks/tracers/test_tracer.py @@ -10,6 +10,7 @@ from uuid import UUID, uuid4 import pytest from freezegun import freeze_time +from langchain.callbacks.manager import CallbackManager from langchain.callbacks.tracers.base import ( BaseTracer, ChainRun, @@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None: assert tracer.runs == [compare_run] +@freeze_time("2023-01-01") +def test_tracer_chat_model_run() -> None: + """Test tracer on a Chat Model run.""" + uuid = uuid4() + compare_run = LLMRun( + uuid=str(uuid), + parent_uuid=None, + start_time=datetime.utcnow(), + end_time=datetime.utcnow(), + extra={}, + execution_order=1, + child_execution_order=1, + serialized={}, + prompts=[""], + response=LLMResult(generations=[[]]), + session_id=TEST_SESSION_ID, + error=None, + ) + tracer = FakeTracer() + + tracer.new_session() + manager = CallbackManager(handlers=[tracer]) + run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid) + run_manager.on_llm_end(response=LLMResult(generations=[[]])) + assert tracer.runs == [compare_run] + + @freeze_time("2023-01-01") def test_tracer_llm_run_errors_no_start() -> None: """Test tracer on an LLM run without a start.""" diff --git a/tests/unit_tests/client/test_utils.py b/tests/unit_tests/client/test_utils.py deleted file mode 100644 index 7ee405c5..00000000 --- a/tests/unit_tests/client/test_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Test LangChain+ Client Utils.""" - -from typing import List - -from langchain.client.utils import parse_chat_messages -from langchain.schema import ( - AIMessage, - BaseMessage, - ChatMessage, - HumanMessage, - SystemMessage, -) - - -def test_parse_chat_messages() -> None: - """Test that chat messages are parsed correctly.""" - input_text = ( - "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message" - ) - expected = [ - HumanMessage(content="I am human roar"), - AIMessage(content="I am AI beep boop"), - SystemMessage(content="I am a system message"), - ] - assert parse_chat_messages(input_text) == expected - - -def test_parse_chat_messages_empty_input() -> None: - """Test that an empty input string returns an empty list.""" - input_text = "" - expected: List[BaseMessage] = [] - assert parse_chat_messages(input_text) == expected - - -def test_parse_chat_messages_multiline_messages() -> None: - """Test that multiline messages are parsed correctly.""" - input_text = ( - "Human: I am a human\nand I roar\nAI: I am an AI\nand I" - " beep boop\nSystem: I am a system\nand a message" - ) - expected = [ - HumanMessage(content="I am a human\nand I roar"), - AIMessage(content="I am an AI\nand I beep boop"), - SystemMessage(content="I am a system\nand a message"), - ] - assert parse_chat_messages(input_text) == expected - - -def test_parse_chat_messages_custom_roles() -> None: - """Test that 
custom roles are parsed correctly."""
-    input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
-    expected = [
-        ChatMessage(role="Client", content="I need help"),
-        ChatMessage(role="Agent", content="I'm here to help"),
-        ChatMessage(role="Client", content="Thank you"),
-    ]
-    assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
-
-
-def test_parse_chat_messages_embedded_roles() -> None:
-    """Test that messages with embedded role references are parsed correctly."""
-    input_text = (
-        "Human: Oh ai what if you said AI: foo bar?"
-        "\nAI: Well, that would be interesting!"
-    )
-    expected = [
-        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
-        AIMessage(content="Well, that would be interesting!"),
-    ]
-    assert parse_chat_messages(input_text) == expected
diff --git a/tests/unit_tests/llms/fake_chat_model.py b/tests/unit_tests/llms/fake_chat_model.py
new file mode 100644
index 00000000..1f0a8e28
--- /dev/null
+++ b/tests/unit_tests/llms/fake_chat_model.py
@@ -0,0 +1,32 @@
+"""Fake Chat Model wrapper for testing purposes."""
+from typing import List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import SimpleChatModel
+from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
+
+
+class FakeChatModel(SimpleChatModel):
+    """Fake Chat Model wrapper for testing purposes."""
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        return "fake response"
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        output_str = "fake response"
+        message = AIMessage(content=output_str)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
diff --git a/tests/unit_tests/llms/test_callbacks.py b/tests/unit_tests/llms/test_callbacks.py
index ce0cf77f..8d0f8487 100644
--- a/tests/unit_tests/llms/test_callbacks.py
+++ b/tests/unit_tests/llms/test_callbacks.py
@@ -1,5 +1,10 @@
 """Test LLM callbacks."""
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from langchain.schema import HumanMessage
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeCallbackHandler,
+    FakeCallbackHandlerWithChatStart,
+)
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
@@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_chat_model_with_v1_callbacks() -> None:
+    """Test chat model callbacks fall back to on_llm_start."""
+    handler = FakeCallbackHandler()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+    assert handler.llm_starts == 1
+    assert handler.llm_ends == 1
+
+
+def test_chat_model_with_v2_callbacks() -> None:
+    """Test chat model callbacks use on_chat_model_start when implemented."""
+    handler = FakeCallbackHandlerWithChatStart()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+ assert handler.ends == 1 + assert handler.errors == 0 + assert handler.llm_starts == 0 + assert handler.llm_ends == 1 + assert handler.chat_model_starts == 1