Add on_chat_model_start (#4499)

### Add on_chat_model_start to callback manager and base tracer

Goal: trace chat messages directly so they can be reloaded as chat messages
(stored in an integration-agnostic way).

Add an `on_chat_model_start` method. Handlers that don't implement it fall
back to `on_llm_start()`.

Does so in a non-backwards-compatibility-breaking way (for now).
Zander Chase · 2023-05-11 11:06:39 -07:00 · committed by GitHub
parent bbf76dbb52
commit 4ee47926ca
12 changed files with 311 additions and 140 deletions
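
For orientation, a minimal sketch of what a handler can do with the new hook. The handler name and the print call are illustrative (not part of this commit); the signature mirrors the `on_chat_model_start` definition added below.

```python
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import BaseMessage


class ChatLoggingHandler(BaseCallbackHandler):
    """Illustrative handler that receives chat messages as structured objects."""

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        # One List[BaseMessage] per prompt in the batch.
        for batch in messages:
            print([(m.type, m.content) for m in batch])
```

Handlers that don't override `on_chat_model_start` keep working: the callback manager catches the `NotImplementedError` and re-dispatches to `on_llm_start` with the messages rendered to strings via `get_buffer_string` (see `_handle_event` below).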

View File

@ -4,7 +4,12 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
)
class LLMManagerMixin:
@ -123,6 +128,20 @@ class CallbackManagerMixin:
) -> Any:
"""Run when LLM starts running."""
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
def on_chain_start(
self,
serialized: Dict[str, Any],
@ -184,6 +203,11 @@ class BaseCallbackHandler(
"""Whether to ignore agent callbacks."""
return False
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
return False
class AsyncCallbackHandler(BaseCallbackHandler):
"""Async callback handler that can be used to handle callbacks from langchain."""
@ -199,6 +223,20 @@ class AsyncCallbackHandler(BaseCallbackHandler):
) -> None:
"""Run when LLM starts running."""
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement `on_chat_model_start`"
)
async def on_llm_new_token(
self,
token: str,

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import functools
import logging
import os
import warnings
from contextlib import contextmanager
@ -22,8 +23,15 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers.base import TracerSession
from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
from langchain.callbacks.tracers.schemas import TracerSessionV2
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
LLMResult,
get_buffer_string,
)
logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
@ -87,15 +95,31 @@ def _handle_event(
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for CallbackManager."""
message_strings: Optional[List[str]] = None
for handler in handlers:
try:
if ignore_condition_name is None or not getattr(
handler, ignore_condition_name
):
getattr(handler, event_name)(*args, **kwargs)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
if message_strings is None:
message_strings = [get_buffer_string(m) for m in args[1]]
_handle_event(
[handler],
"on_llm_start",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
logger.warning(f"Error in {event_name} callback: {e}")
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
logging.warning(f"Error in {event_name} callback: {e}")
async def _ahandle_event_for_handler(
@ -114,9 +138,22 @@ async def _ahandle_event_for_handler(
await asyncio.get_event_loop().run_in_executor(
None, functools.partial(event, *args, **kwargs)
)
except NotImplementedError as e:
if event_name == "on_chat_model_start":
message_strings = [get_buffer_string(m) for m in args[1]]
await _ahandle_event_for_handler(
handler,
"on_llm",
"ignore_llm",
args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
logger.warning(f"Error in {event_name} callback: {e}")
except Exception as e:
# TODO: switch this to use logging
print(f"Error in {event_name} callback: {e}")
logger.warning(f"Error in {event_name} callback: {e}")
async def _ahandle_event(
@ -531,6 +568,33 @@ class CallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> CallbackManagerForLLMRun:
"""Run when LLM starts running."""
if run_id is None:
run_id = uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
# Re-use the LLM Run Manager since the outputs are treated
# the same for now
return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
def on_chain_start(
self,
serialized: Dict[str, Any],
@ -629,6 +693,31 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
if run_id is None:
run_id = uuid4()
await _ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
serialized,
messages,
run_id=run_id,
parent_run_id=self.parent_run_id,
**kwargs,
)
return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
)
async def on_chain_start(
self,
serialized: Dict[str, Any],

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from uuid import UUID, uuid4
@ -19,6 +20,7 @@ from langchain.callbacks.tracers.schemas import (
TracerSessionV2,
TracerSessionV2Create,
)
from langchain.schema import BaseMessage, messages_to_dict
from langchain.utils import raise_for_status_with_text
@ -193,6 +195,36 @@ class LangChainTracerV2(LangChainTracer):
"""Load the default tracing session and set it as the Tracer's session."""
return self.load_session("default")
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
if self.session is None:
self.session = self.load_default_session()
run_id_ = str(run_id)
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
llm_run = LLMRun(
uuid=run_id_,
parent_uuid=parent_run_id_,
serialized=serialized,
prompts=[],
extra={**kwargs, "messages": messages},
start_time=datetime.utcnow(),
execution_order=execution_order,
child_execution_order=execution_order,
session_id=self.session.id,
)
self._start_trace(llm_run)
def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
"""Convert a run to a Run."""
session = self.session or self.load_default_session()
@ -201,6 +233,11 @@ class LangChainTracerV2(LangChainTracer):
child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
if isinstance(run, LLMRun):
run_type = "llm"
if run.extra is not None and "messages" in run.extra:
messages: List[List[BaseMessage]] = run.extra.pop("messages")
converted_messages = [messages_to_dict(batch) for batch in messages]
inputs = {"messages": converted_messages}
else:
inputs = {"prompts": run.prompts}
outputs = run.response.dict() if run.response else {}
child_runs = []
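
The tracer keeps the raw messages on the run, and `_convert_run` serializes them with `messages_to_dict`, so they can later be reloaded as chat messages rather than flattened strings. A small sketch of that integration-agnostic round trip (the example messages are made up):

```python
from langchain.schema import (
    AIMessage,
    HumanMessage,
    messages_from_dict,
    messages_to_dict,
)

batch = [HumanMessage(content="hi"), AIMessage(content="hello!")]

# Roughly what the tracer persists for one batch of messages.
as_dicts = messages_to_dict(batch)

# Roughly what the client rebuilds before re-running a chat model.
restored = messages_from_dict(as_dicts)
assert restored == batch
```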

View File

@ -117,6 +117,7 @@ class RunBase(BaseModel):
session_id: UUID
reference_example_id: Optional[UUID]
run_type: RunTypeEnum
parent_run_id: Optional[UUID]
class RunCreate(RunBase):
@ -130,7 +131,6 @@ class Run(RunBase):
"""Run schema when loading from the DB."""
name: str
parent_run_id: Optional[UUID]
ChainRun.update_forward_refs()

View File

@ -24,7 +24,6 @@ from langchain.schema import (
HumanMessage,
LLMResult,
PromptValue,
get_buffer_string,
)
@ -69,9 +68,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
run_manager = callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
@ -104,9 +102,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose
)
message_strings = [get_buffer_string(m) for m in messages]
run_manager = await callback_manager.on_llm_start(
{"name": self.__class__.__name__}, message_strings
run_manager = await callback_manager.on_chat_model_start(
{"name": self.__class__.__name__}, messages
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(

View File

@ -31,9 +31,8 @@ from langchain.callbacks.tracers.langchain import LangChainTracerV2
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
from langchain.client.utils import parse_chat_messages
from langchain.llms.base import BaseLLM
from langchain.schema import ChatResult, LLMResult
from langchain.schema import ChatResult, LLMResult, messages_from_dict
from langchain.utils import raise_for_status_with_text, xor_args
if TYPE_CHECKING:
@ -96,7 +95,6 @@ class LangChainPlusClient(BaseSettings):
"Unable to get seeded tenant ID. Please manually provide."
) from e
results: List[dict] = response.json()
breakpoint()
if len(results) == 0:
raise ValueError("No seeded tenant found")
return results[0]["id"]
@ -296,13 +294,15 @@ class LangChainPlusClient(BaseSettings):
langchain_tracer: LangChainTracerV2,
) -> Union[LLMResult, ChatResult]:
if isinstance(llm, BaseLLM):
if "prompts" not in inputs:
raise ValueError(f"LLM Run requires 'prompts' input. Got {inputs}")
llm_prompts: List[str] = inputs["prompts"]
llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
if "messages" not in inputs:
raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
raw_messages: List[List[dict]] = inputs["messages"]
messages = [messages_from_dict(batch) for batch in raw_messages]
llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")
@ -454,13 +454,17 @@ class LangChainPlusClient(BaseSettings):
) -> Union[LLMResult, ChatResult]:
"""Run the language model on the example."""
if isinstance(llm, BaseLLM):
if "prompts" not in inputs:
raise ValueError(f"LLM Run must contain 'prompts' key. Got {inputs}")
llm_prompts: List[str] = inputs["prompts"]
llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
elif isinstance(llm, BaseChatModel):
chat_prompts: List[str] = inputs["prompts"]
messages = [
parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
]
if "messages" not in inputs:
raise ValueError(
f"Chat Model Run must contain 'messages' key. Got {inputs}"
)
raw_messages: List[List[dict]] = inputs["messages"]
messages = [messages_from_dict(batch) for batch in raw_messages]
llm_output = llm.generate(messages, callbacks=[langchain_tracer])
else:
raise ValueError(f"Unsupported LLM type {type(llm)}")

View File

@ -1,42 +0,0 @@
"""Client Utils."""
import re
from typing import Dict, List, Optional, Sequence, Type, Union
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
"Human": HumanMessage,
"AI": AIMessage,
"System": SystemMessage,
}
def parse_chat_messages(
input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
"""Parse chat messages from a string. This is not robust."""
roles = roles or ["Human", "AI", "System"]
roles_pattern = "|".join(roles)
pattern = (
rf"(?P<entity>{roles_pattern}): (?P<message>"
rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
)
matches = re.finditer(pattern, input_text, re.MULTILINE)
results: List[BaseMessage] = []
for match in matches:
entity = match.group("entity")
message = match.group("message").rstrip("\n")
if entity in _RESOLUTION_MAP:
results.append(_RESOLUTION_MAP[entity](content=message))
else:
results.append(ChatMessage(role=entity, content=message))
return results

View File

@ -1,9 +1,12 @@
"""A fake callback handler for testing purposes."""
from typing import Any
from itertools import chain
from typing import Any, Dict, List, Optional
from uuid import UUID
from pydantic import BaseModel
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain.schema import BaseMessage
class BaseFakeCallbackHandler(BaseModel):
@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
ignore_chat_model_: bool = False
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel):
tool_ends: int = 0
agent_actions: int = 0
agent_ends: int = 0
chat_model_starts: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.llm_streams += 1
def on_chain_start_common(self) -> None:
print("CHAIN START")
self.chain_starts += 1
self.starts += 1
@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.errors += 1
def on_agent_action_common(self) -> None:
print("AGENT ACTION")
self.agent_actions += 1
self.starts += 1
@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
self.agent_ends += 1
self.ends += 1
def on_chat_model_start_common(self) -> None:
print("STARTING CHAT MODEL")
self.chat_model_starts += 1
self.starts += 1
def on_text_common(self) -> None:
self.text += 1
@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
return self
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
def on_chat_model_start(
self,
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
assert all(isinstance(m, BaseMessage) for m in chain(*messages))
self.on_chat_model_start_common()
class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake async callback handler for testing."""

View File

@ -10,6 +10,7 @@ from uuid import UUID, uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import (
BaseTracer,
ChainRun,
@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None:
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={},
prompts=[""],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid)
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""

View File

@ -1,70 +0,0 @@
"""Test LangChain+ Client Utils."""
from typing import List
from langchain.client.utils import parse_chat_messages
from langchain.schema import (
AIMessage,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
def test_parse_chat_messages() -> None:
"""Test that chat messages are parsed correctly."""
input_text = (
"Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
)
expected = [
HumanMessage(content="I am human roar"),
AIMessage(content="I am AI beep boop"),
SystemMessage(content="I am a system message"),
]
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_empty_input() -> None:
"""Test that an empty input string returns an empty list."""
input_text = ""
expected: List[BaseMessage] = []
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_multiline_messages() -> None:
"""Test that multiline messages are parsed correctly."""
input_text = (
"Human: I am a human\nand I roar\nAI: I am an AI\nand I"
" beep boop\nSystem: I am a system\nand a message"
)
expected = [
HumanMessage(content="I am a human\nand I roar"),
AIMessage(content="I am an AI\nand I beep boop"),
SystemMessage(content="I am a system\nand a message"),
]
assert parse_chat_messages(input_text) == expected
def test_parse_chat_messages_custom_roles() -> None:
"""Test that custom roles are parsed correctly."""
input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
expected = [
ChatMessage(role="Client", content="I need help"),
ChatMessage(role="Agent", content="I'm here to help"),
ChatMessage(role="Client", content="Thank you"),
]
assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
def test_parse_chat_messages_embedded_roles() -> None:
"""Test that messages with embedded role references are parsed correctly."""
input_text = (
"Human: Oh ai what if you said AI: foo bar?"
"\nAI: Well, that would be interesting!"
)
expected = [
HumanMessage(content="Oh ai what if you said AI: foo bar?"),
AIMessage(content="Well, that would be interesting!"),
]
assert parse_chat_messages(input_text) == expected

View File

@ -0,0 +1,32 @@
"""Fake Chat Model wrapper for testing purposes."""
from typing import List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> str:
return "fake response"
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> ChatResult:
output_str = "fake response"
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])

View File

@ -1,5 +1,10 @@
"""Test LLM callbacks."""
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
from langchain.schema import HumanMessage
from tests.unit_tests.callbacks.fake_callback_handler import (
FakeCallbackHandler,
FakeCallbackHandlerWithChatStart,
)
from tests.unit_tests.llms.fake_chat_model import FakeChatModel
from tests.unit_tests.llms.fake_llm import FakeLLM
@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
def test_chat_model_with_v1_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandler()
llm = FakeChatModel(callbacks=[handler], verbose=True)
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
assert handler.llm_starts == 1
assert handler.llm_ends == 1
def test_chat_model_with_v2_callbacks() -> None:
"""Test chat model callbacks fall back to on_llm_start."""
handler = FakeCallbackHandlerWithChatStart()
llm = FakeChatModel(callbacks=[handler], verbose=True)
output = llm([HumanMessage(content="foo")])
assert output.content == "fake response"
assert handler.starts == 1
assert handler.ends == 1
assert handler.errors == 0
assert handler.llm_starts == 0
assert handler.llm_ends == 1
assert handler.chat_model_starts == 1
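
For handlers without `on_chat_model_start` (the `test_chat_model_with_v1_callbacks` case above), the fallback hands `on_llm_start` the same messages rendered as prompt strings. A rough sketch of what such a handler sees, assuming the default Human/AI prefixes of `get_buffer_string`:

```python
from langchain.schema import AIMessage, HumanMessage, get_buffer_string

batch = [HumanMessage(content="foo"), AIMessage(content="bar")]

# The fallback passes [get_buffer_string(m) for m in messages] as the prompts.
print(get_buffer_string(batch))
# Human: foo
# AI: bar
```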