Add on_chat_message_start (#4499)

### Add on_chat_model_start to callback manager and base tracer

Goal: trace chat messages directly, stored in an integration-agnostic way, so they can be reloaded as chat messages later.

Adds an `on_chat_model_start` callback method. Handlers that don't implement it fall back to `on_llm_start()`, with the messages rendered as strings.

This is done in a backwards-compatible, non-breaking way (for now).
Author: Zander Chase, 2023-05-11 11:06:39 -07:00 (committed by GitHub)
Parent: bbf76dbb52
Commit: 4ee47926ca
12 changed files with 311 additions and 140 deletions
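For illustration, a minimal sketch of the two handler styles this change supports (not from the diff; `ChatAwareHandler` and `LegacyHandler` are hypothetical names, while the hook signature mirrors the diff below). A handler that overrides the new hook receives structured message batches; one that only implements `on_llm_start` still works because the callback manager converts the messages to strings and falls back.

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage


class ChatAwareHandler(BaseCallbackHandler):
    """Implements the new hook and receives structured message batches."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        # Each batch is a list of BaseMessage objects, not a flattened string.
        print([[m.content for m in batch] for batch in messages])


class LegacyHandler(BaseCallbackHandler):
    """Only implements the old hook; the manager falls back to it."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        # With the fallback, each prompt is a batch rendered via get_buffer_string.
        print(prompts)


# Hypothetical usage with any chat model, e.g. the FakeChatModel added below:
# chat = FakeChatModel(callbacks=[ChatAwareHandler(), LegacyHandler()])
# chat([HumanMessage(content="hello")])
```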

View File

@@ -4,7 +4,12 @@ from __future__ import annotations
 
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+)
 
 
 class LLMManagerMixin:
@@ -123,6 +128,20 @@ class CallbackManagerMixin:
     ) -> Any:
         """Run when LLM starts running."""
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -184,6 +203,11 @@ class BaseCallbackHandler(
         """Whether to ignore agent callbacks."""
         return False
 
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return False
+
 
 class AsyncCallbackHandler(BaseCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
@@ -199,6 +223,20 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Run when LLM starts running."""
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     async def on_llm_new_token(
         self,
         token: str,

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import logging
 import os
 import warnings
 from contextlib import contextmanager
@@ -22,8 +23,15 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.tracers.base import TracerSession
 from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
 from langchain.callbacks.tracers.schemas import TracerSessionV2
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+    get_buffer_string,
+)
+
+logger = logging.getLogger(__name__)
 
 Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
 
 openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
@@ -87,15 +95,31 @@ def _handle_event(
     *args: Any,
     **kwargs: Any,
 ) -> None:
+    """Generic event handler for CallbackManager."""
+    message_strings: Optional[List[str]] = None
     for handler in handlers:
         try:
             if ignore_condition_name is None or not getattr(
                 handler, ignore_condition_name
             ):
                 getattr(handler, event_name)(*args, **kwargs)
+        except NotImplementedError as e:
+            if event_name == "on_chat_model_start":
+                if message_strings is None:
+                    message_strings = [get_buffer_string(m) for m in args[1]]
+                _handle_event(
+                    [handler],
+                    "on_llm_start",
+                    "ignore_llm",
+                    args[0],
+                    message_strings,
+                    *args[2:],
+                    **kwargs,
+                )
+            else:
+                logger.warning(f"Error in {event_name} callback: {e}")
         except Exception as e:
-            # TODO: switch this to use logging
-            print(f"Error in {event_name} callback: {e}")
+            logging.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event_for_handler(
@@ -114,9 +138,22 @@ async def _ahandle_event_for_handler(
                 await asyncio.get_event_loop().run_in_executor(
                     None, functools.partial(event, *args, **kwargs)
                 )
+    except NotImplementedError as e:
+        if event_name == "on_chat_model_start":
+            message_strings = [get_buffer_string(m) for m in args[1]]
+            await _ahandle_event_for_handler(
+                handler,
+                "on_llm_start",
+                "ignore_llm",
+                args[0],
+                message_strings,
+                *args[2:],
+                **kwargs,
+            )
+        else:
+            logger.warning(f"Error in {event_name} callback: {e}")
     except Exception as e:
-        # TODO: switch this to use logging
-        print(f"Error in {event_name} callback: {e}")
+        logger.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event(
@@ -531,6 +568,33 @@ class CallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> CallbackManagerForLLMRun:
+        """Run when LLM starts running."""
+        if run_id is None:
+            run_id = uuid4()
+
+        _handle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
+            parent_run_id=self.parent_run_id,
+            **kwargs,
+        )
+
+        # Re-use the LLM Run Manager since the outputs are treated
+        # the same for now
+        return CallbackManagerForLLMRun(
+            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -629,6 +693,31 @@ class AsyncCallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        if run_id is None:
+            run_id = uuid4()
+
+        await _ahandle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
+            parent_run_id=self.parent_run_id,
+            **kwargs,
+        )
+
+        return AsyncCallbackManagerForLLMRun(
+            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
+        )
+
     async def on_chain_start(
         self,
         serialized: Dict[str, Any],
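For reference, a small sketch (not part of the diff) of what the fallback above hands to `on_llm_start`: each `List[BaseMessage]` batch is collapsed into one prompt string with `get_buffer_string`.

```python
from langchain.schema import AIMessage, HumanMessage, SystemMessage, get_buffer_string

batch = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is 2 + 2?"),
    AIMessage(content="4"),
]

# Mirrors `[get_buffer_string(m) for m in args[1]]` in _handle_event: one
# string per batch, with role prefixes and newlines between messages.
print(get_buffer_string(batch))
# System: You are a helpful assistant.
# Human: What is 2 + 2?
# AI: 4
```

Note that `on_chat_model_start` returns the existing `CallbackManagerForLLMRun`, so `on_llm_end` and `on_llm_error` are reused unchanged for chat models.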

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 import os
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID, uuid4
 
@@ -19,6 +20,7 @@ from langchain.callbacks.tracers.schemas import (
     TracerSessionV2,
     TracerSessionV2Create,
 )
+from langchain.schema import BaseMessage, messages_to_dict
 from langchain.utils import raise_for_status_with_text
 
@@ -193,6 +195,36 @@ class LangChainTracerV2(LangChainTracer):
         """Load the default tracing session and set it as the Tracer's session."""
         return self.load_session("default")
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Start a trace for an LLM run."""
+        if self.session is None:
+            self.session = self.load_default_session()
+        run_id_ = str(run_id)
+        parent_run_id_ = str(parent_run_id) if parent_run_id else None
+
+        execution_order = self._get_execution_order(parent_run_id_)
+        llm_run = LLMRun(
+            uuid=run_id_,
+            parent_uuid=parent_run_id_,
+            serialized=serialized,
+            prompts=[],
+            extra={**kwargs, "messages": messages},
+            start_time=datetime.utcnow(),
+            execution_order=execution_order,
+            child_execution_order=execution_order,
+            session_id=self.session.id,
+        )
+        self._start_trace(llm_run)
+
     def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
         """Convert a run to a Run."""
         session = self.session or self.load_default_session()
@@ -201,7 +233,12 @@ class LangChainTracerV2(LangChainTracer):
         child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
         if isinstance(run, LLMRun):
             run_type = "llm"
-            inputs = {"prompts": run.prompts}
+            if run.extra is not None and "messages" in run.extra:
+                messages: List[List[BaseMessage]] = run.extra.pop("messages")
+                converted_messages = [messages_to_dict(batch) for batch in messages]
+                inputs = {"messages": converted_messages}
+            else:
+                inputs = {"prompts": run.prompts}
             outputs = run.response.dict() if run.response else {}
             child_runs = []
         elif isinstance(run, ChainRun):
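The integration-agnostic storage format relies on the schema helpers; a quick round-trip sketch (illustrative values, not from the diff):

```python
from langchain.schema import AIMessage, HumanMessage, messages_from_dict, messages_to_dict

batch = [HumanMessage(content="hi"), AIMessage(content="hello")]

# What _convert_run stores under inputs["messages"] for one batch,
# roughly [{"type": "human", "data": {"content": "hi", ...}}, ...]
as_dicts = messages_to_dict(batch)

# What the client (further below) does to rebuild BaseMessage objects
# from a stored run.
assert messages_from_dict(as_dicts) == batch
```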

View File

@@ -117,6 +117,7 @@ class RunBase(BaseModel):
     session_id: UUID
     reference_example_id: Optional[UUID]
     run_type: RunTypeEnum
+    parent_run_id: Optional[UUID]
 
 
 class RunCreate(RunBase):
@@ -130,7 +131,6 @@ class Run(RunBase):
     """Run schema when loading from the DB."""
 
     name: str
-    parent_run_id: Optional[UUID]
 
 
 ChainRun.update_forward_refs()

View File

@@ -24,7 +24,6 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
-    get_buffer_string,
 )
 
@@ -69,9 +68,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        message_strings = [get_buffer_string(m) for m in messages]
-        run_manager = callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, message_strings
+        run_manager = callback_manager.on_chat_model_start(
+            {"name": self.__class__.__name__}, messages
         )
         new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -104,9 +102,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        message_strings = [get_buffer_string(m) for m in messages]
-        run_manager = await callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, message_strings
+        run_manager = await callback_manager.on_chat_model_start(
+            {"name": self.__class__.__name__}, messages
         )
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
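The `new_arg_supported` check visible in the surrounding context is what keeps `run_manager` optional for subclasses; a standalone sketch of that pattern (hypothetical helper, not part of the diff):

```python
import inspect
from typing import Any, Callable


def call_generate(generate: Callable[..., Any], messages: Any, run_manager: Any) -> Any:
    # Only pass run_manager if the subclass's _generate declares the parameter,
    # so older chat model implementations keep working.
    if inspect.signature(generate).parameters.get("run_manager"):
        return generate(messages, run_manager=run_manager)
    return generate(messages)
```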

View File

@@ -31,9 +31,8 @@ from langchain.callbacks.tracers.langchain import LangChainTracerV2
 from langchain.chains.base import Chain
 from langchain.chat_models.base import BaseChatModel
 from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
-from langchain.client.utils import parse_chat_messages
 from langchain.llms.base import BaseLLM
-from langchain.schema import ChatResult, LLMResult
+from langchain.schema import ChatResult, LLMResult, messages_from_dict
 from langchain.utils import raise_for_status_with_text, xor_args
 
 if TYPE_CHECKING:
@@ -96,7 +95,6 @@ class LangChainPlusClient(BaseSettings):
                 "Unable to get seeded tenant ID. Please manually provide."
             ) from e
         results: List[dict] = response.json()
-        breakpoint()
         if len(results) == 0:
             raise ValueError("No seeded tenant found")
         return results[0]["id"]
@@ -296,13 +294,15 @@ class LangChainPlusClient(BaseSettings):
         langchain_tracer: LangChainTracerV2,
     ) -> Union[LLMResult, ChatResult]:
         if isinstance(llm, BaseLLM):
+            if "prompts" not in inputs:
+                raise ValueError(f"LLM Run requires 'prompts' input. Got {inputs}")
             llm_prompts: List[str] = inputs["prompts"]
             llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            chat_prompts: List[str] = inputs["prompts"]
-            messages = [
-                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
-            ]
+            if "messages" not in inputs:
+                raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
+            raw_messages: List[List[dict]] = inputs["messages"]
+            messages = [messages_from_dict(batch) for batch in raw_messages]
             llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
@@ -454,13 +454,17 @@ class LangChainPlusClient(BaseSettings):
     ) -> Union[LLMResult, ChatResult]:
         """Run the language model on the example."""
         if isinstance(llm, BaseLLM):
+            if "prompts" not in inputs:
+                raise ValueError(f"LLM Run must contain 'prompts' key. Got {inputs}")
             llm_prompts: List[str] = inputs["prompts"]
             llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            chat_prompts: List[str] = inputs["prompts"]
-            messages = [
-                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
-            ]
+            if "messages" not in inputs:
+                raise ValueError(
+                    f"Chat Model Run must contain 'messages' key. Got {inputs}"
+                )
+            raw_messages: List[List[dict]] = inputs["messages"]
+            messages = [messages_from_dict(batch) for batch in raw_messages]
             llm_output = llm.generate(messages, callbacks=[langchain_tracer])
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")

View File

@@ -1,42 +0,0 @@
-"""Client Utils."""
-import re
-from typing import Dict, List, Optional, Sequence, Type, Union
-
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
-_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
-    "Human": HumanMessage,
-    "AI": AIMessage,
-    "System": SystemMessage,
-}
-
-
-def parse_chat_messages(
-    input_text: str, roles: Optional[Sequence[str]] = None
-) -> List[BaseMessage]:
-    """Parse chat messages from a string. This is not robust."""
-    roles = roles or ["Human", "AI", "System"]
-    roles_pattern = "|".join(roles)
-    pattern = (
-        rf"(?P<entity>{roles_pattern}): (?P<message>"
-        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
-    )
-    matches = re.finditer(pattern, input_text, re.MULTILINE)
-    results: List[BaseMessage] = []
-    for match in matches:
-        entity = match.group("entity")
-        message = match.group("message").rstrip("\n")
-        if entity in _RESOLUTION_MAP:
-            results.append(_RESOLUTION_MAP[entity](content=message))
-        else:
-            results.append(ChatMessage(role=entity, content=message))
-    return results
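One reason the string-based round trip is being dropped: role prefixes embedded in message bodies make the regex split ambiguous, while the structured format introduced above keeps roles explicit. A small sketch (illustrative message content, not from the diff):

```python
from langchain.schema import HumanMessage, messages_to_dict

# A single human message whose body happens to contain a role-prefixed line.
msg = HumanMessage(content="My log said:\nAI: hello\nand I want to quote it.")

# Rendered to a flat "Human: ..." string, the removed parser would split this
# at the embedded "AI:" line into two messages. Serialized as dicts, it stays
# one message with an unambiguous type.
print(messages_to_dict([msg]))
```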

View File

@@ -1,9 +1,12 @@
 """A fake callback handler for testing purposes."""
-from typing import Any
+from itertools import chain
+from typing import Any, Dict, List, Optional
+from uuid import UUID
 
 from pydantic import BaseModel
 
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain.schema import BaseMessage
 
 
 class BaseFakeCallbackHandler(BaseModel):
@@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
     ignore_llm_: bool = False
     ignore_chain_: bool = False
     ignore_agent_: bool = False
+    ignore_chat_model_: bool = False
 
     # add finer-grained counters for easier debugging of failing tests
     chain_starts: int = 0
@@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel):
     tool_ends: int = 0
     agent_actions: int = 0
     agent_ends: int = 0
+    chat_model_starts: int = 0
 
 
 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_streams += 1
 
     def on_chain_start_common(self) -> None:
+        print("CHAIN START")
         self.chain_starts += 1
         self.starts += 1
 
@@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1
 
     def on_agent_action_common(self) -> None:
+        print("AGENT ACTION")
         self.agent_actions += 1
         self.starts += 1
 
@@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.agent_ends += 1
         self.ends += 1
 
+    def on_chat_model_start_common(self) -> None:
+        print("STARTING CHAT MODEL")
+        self.chat_model_starts += 1
+        self.starts += 1
+
     def on_text_common(self) -> None:
         self.text += 1
 
@@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         return self
 
 
+class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        assert all(isinstance(m, BaseMessage) for m in chain(*messages))
+        self.on_chat_model_start_common()
+
+
 class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
     """Fake async callback handler for testing."""

View File

@@ -10,6 +10,7 @@ from uuid import UUID, uuid4
 import pytest
 from freezegun import freeze_time
 
+from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import (
     BaseTracer,
     ChainRun,
@@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None:
     assert tracer.runs == [compare_run]
 
 
+@freeze_time("2023-01-01")
+def test_tracer_chat_model_run() -> None:
+    """Test tracer on a Chat Model run."""
+    uuid = uuid4()
+    compare_run = LLMRun(
+        uuid=str(uuid),
+        parent_uuid=None,
+        start_time=datetime.utcnow(),
+        end_time=datetime.utcnow(),
+        extra={},
+        execution_order=1,
+        child_execution_order=1,
+        serialized={},
+        prompts=[""],
+        response=LLMResult(generations=[[]]),
+        session_id=TEST_SESSION_ID,
+        error=None,
+    )
+    tracer = FakeTracer()
+
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_manager = manager.on_chat_model_start(
+        serialized={}, messages=[[]], run_id=uuid
+    )
+    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    assert tracer.runs == [compare_run]
+
+
 @freeze_time("2023-01-01")
 def test_tracer_llm_run_errors_no_start() -> None:
     """Test tracer on an LLM run without a start."""

View File

@@ -1,70 +0,0 @@
-"""Test LangChain+ Client Utils."""
-from typing import List
-
-from langchain.client.utils import parse_chat_messages
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-def test_parse_chat_messages() -> None:
-    """Test that chat messages are parsed correctly."""
-    input_text = (
-        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
-    )
-    expected = [
-        HumanMessage(content="I am human roar"),
-        AIMessage(content="I am AI beep boop"),
-        SystemMessage(content="I am a system message"),
-    ]
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_empty_input() -> None:
-    """Test that an empty input string returns an empty list."""
-    input_text = ""
-    expected: List[BaseMessage] = []
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_multiline_messages() -> None:
-    """Test that multiline messages are parsed correctly."""
-    input_text = (
-        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
-        " beep boop\nSystem: I am a system\nand a message"
-    )
-    expected = [
-        HumanMessage(content="I am a human\nand I roar"),
-        AIMessage(content="I am an AI\nand I beep boop"),
-        SystemMessage(content="I am a system\nand a message"),
-    ]
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_custom_roles() -> None:
-    """Test that custom roles are parsed correctly."""
-    input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
-    expected = [
-        ChatMessage(role="Client", content="I need help"),
-        ChatMessage(role="Agent", content="I'm here to help"),
-        ChatMessage(role="Client", content="Thank you"),
-    ]
-    assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
-
-
-def test_parse_chat_messages_embedded_roles() -> None:
-    """Test that messages with embedded role references are parsed correctly."""
-    input_text = (
-        "Human: Oh ai what if you said AI: foo bar?"
-        "\nAI: Well, that would be interesting!"
-    )
-    expected = [
-        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
-        AIMessage(content="Well, that would be interesting!"),
-    ]
-    assert parse_chat_messages(input_text) == expected

View File

@@ -0,0 +1,32 @@
+"""Fake Chat Model wrapper for testing purposes."""
+from typing import List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import SimpleChatModel
+from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
+
+
+class FakeChatModel(SimpleChatModel):
+    """Fake Chat Model wrapper for testing purposes."""
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        return "fake response"
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        output_str = "fake response"
+        message = AIMessage(content=output_str)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])

View File

@@ -1,5 +1,10 @@
 """Test LLM callbacks."""
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from langchain.schema import HumanMessage
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeCallbackHandler,
+    FakeCallbackHandlerWithChatStart,
+)
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
@@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_chat_model_with_v1_callbacks() -> None:
+    """Test chat model callbacks fall back to on_llm_start."""
+    handler = FakeCallbackHandler()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+    assert handler.llm_starts == 1
+    assert handler.llm_ends == 1
+
+
+def test_chat_model_with_v2_callbacks() -> None:
+    """Test chat model callbacks are routed to on_chat_model_start."""
+    handler = FakeCallbackHandlerWithChatStart()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+    assert handler.llm_starts == 0
+    assert handler.llm_ends == 1
+    assert handler.chat_model_starts == 1