forked from Archives/langchain
Add on_chat_message_start (#4499)
### Add on_chat_model_start to callback manager and base tracer

Goal: trace chat messages directly so runs can be reloaded as chat messages (stored in an integration-agnostic way).

Adds an `on_chat_model_start` method and falls back to `on_llm_start()` for handlers that don't implement it, so the change is non-backwards-compat breaking (for now).
This commit is contained in: parent bbf76dbb52, commit 4ee47926ca
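The core of the change is a dispatch-with-fallback pattern. Below is a minimal, dependency-free sketch of that pattern (class and function names here are illustrative, not taken from the diff); in the real code the raising default lives on `CallbackManagerMixin.on_chat_model_start` and the fallback lives in `_handle_event` in langchain/callbacks/manager.py, both shown further down.

from typing import Any, Dict, List


class OldStyleHandler:
    """Stand-in for an existing handler that only knows on_llm_start."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        print("on_llm_start:", prompts)

    def on_chat_model_start(
        self, serialized: Dict[str, Any], messages: List[List[Dict[str, str]]], **kwargs: Any
    ) -> None:
        # Same default the diff adds to the base mixin: raise, so the manager
        # knows to fall back to on_llm_start for this handler.
        raise NotImplementedError


class ChatAwareHandler(OldStyleHandler):
    """Stand-in for a handler that opts in to the new event."""

    def on_chat_model_start(
        self, serialized: Dict[str, Any], messages: List[List[Dict[str, str]]], **kwargs: Any
    ) -> None:
        print("on_chat_model_start:", messages)


def handle_chat_model_start(
    handler: OldStyleHandler, serialized: Dict[str, Any], messages: List[List[Dict[str, str]]]
) -> None:
    """Simplified mirror of the fallback in _handle_event."""
    try:
        handler.on_chat_model_start(serialized, messages)
    except NotImplementedError:
        # Flatten each message batch to one prompt string; the real code uses
        # langchain.schema.get_buffer_string for this.
        prompts = [
            "\n".join(f"{m['role']}: {m['content']}" for m in batch) for batch in messages
        ]
        handler.on_llm_start(serialized, prompts)


batch = [[{"role": "Human", "content": "hello"}]]
handle_chat_model_start(OldStyleHandler(), {"name": "SomeChatModel"}, batch)   # falls back
handle_chat_model_start(ChatAwareHandler(), {"name": "SomeChatModel"}, batch)  # new hook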
langchain/callbacks/base.py:

@@ -4,7 +4,12 @@ from __future__ import annotations
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID
 
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+)
 
 
 class LLMManagerMixin:
@@ -123,6 +128,20 @@ class CallbackManagerMixin:
     ) -> Any:
         """Run when LLM starts running."""
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -184,6 +203,11 @@ class BaseCallbackHandler(
         """Whether to ignore agent callbacks."""
         return False
 
+    @property
+    def ignore_chat_model(self) -> bool:
+        """Whether to ignore chat model callbacks."""
+        return False
+
 
 class AsyncCallbackHandler(BaseCallbackHandler):
     """Async callback handler that can be used to handle callbacks from langchain."""
@@ -199,6 +223,20 @@ class AsyncCallbackHandler(BaseCallbackHandler):
     ) -> None:
         """Run when LLM starts running."""
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run when a chat model starts running."""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement `on_chat_model_start`"
+        )
+
     async def on_llm_new_token(
         self,
         token: str,
langchain/callbacks/manager.py:

@@ -2,6 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import functools
+import logging
 import os
 import warnings
 from contextlib import contextmanager
@@ -22,8 +23,15 @@ from langchain.callbacks.stdout import StdOutCallbackHandler
 from langchain.callbacks.tracers.base import TracerSession
 from langchain.callbacks.tracers.langchain import LangChainTracer, LangChainTracerV2
 from langchain.callbacks.tracers.schemas import TracerSessionV2
-from langchain.schema import AgentAction, AgentFinish, LLMResult
+from langchain.schema import (
+    AgentAction,
+    AgentFinish,
+    BaseMessage,
+    LLMResult,
+    get_buffer_string,
+)
 
+logger = logging.getLogger(__name__)
 Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
 
 openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
@@ -87,15 +95,31 @@ def _handle_event(
     *args: Any,
     **kwargs: Any,
 ) -> None:
+    """Generic event handler for CallbackManager."""
+    message_strings: Optional[List[str]] = None
     for handler in handlers:
         try:
             if ignore_condition_name is None or not getattr(
                 handler, ignore_condition_name
             ):
                 getattr(handler, event_name)(*args, **kwargs)
+        except NotImplementedError as e:
+            if event_name == "on_chat_model_start":
+                if message_strings is None:
+                    message_strings = [get_buffer_string(m) for m in args[1]]
+                _handle_event(
+                    [handler],
+                    "on_llm_start",
+                    "ignore_llm",
+                    args[0],
+                    message_strings,
+                    *args[2:],
+                    **kwargs,
+                )
+            else:
+                logger.warning(f"Error in {event_name} callback: {e}")
         except Exception as e:
-            # TODO: switch this to use logging
-            print(f"Error in {event_name} callback: {e}")
+            logging.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event_for_handler(
@@ -114,9 +138,22 @@ async def _ahandle_event_for_handler(
         await asyncio.get_event_loop().run_in_executor(
             None, functools.partial(event, *args, **kwargs)
         )
+    except NotImplementedError as e:
+        if event_name == "on_chat_model_start":
+            message_strings = [get_buffer_string(m) for m in args[1]]
+            await _ahandle_event_for_handler(
+                handler,
+                "on_llm",
+                "ignore_llm",
+                args[0],
+                message_strings,
+                *args[2:],
+                **kwargs,
+            )
+        else:
+            logger.warning(f"Error in {event_name} callback: {e}")
     except Exception as e:
-        # TODO: switch this to use logging
-        print(f"Error in {event_name} callback: {e}")
+        logger.warning(f"Error in {event_name} callback: {e}")
 
 
 async def _ahandle_event(
@@ -531,6 +568,33 @@ class CallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> CallbackManagerForLLMRun:
+        """Run when LLM starts running."""
+        if run_id is None:
+            run_id = uuid4()
+        _handle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
+            parent_run_id=self.parent_run_id,
+            **kwargs,
+        )
+
+        # Re-use the LLM Run Manager since the outputs are treated
+        # the same for now
+        return CallbackManagerForLLMRun(
+            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
+        )
+
     def on_chain_start(
         self,
         serialized: Dict[str, Any],
@@ -629,6 +693,31 @@ class AsyncCallbackManager(BaseCallbackManager):
             run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
         )
 
+    async def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        if run_id is None:
+            run_id = uuid4()
+
+        await _ahandle_event(
+            self.handlers,
+            "on_chat_model_start",
+            "ignore_chat_model",
+            serialized,
+            messages,
+            run_id=run_id,
+            parent_run_id=self.parent_run_id,
+            **kwargs,
+        )
+
+        return AsyncCallbackManagerForLLMRun(
+            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
+        )
+
     async def on_chain_start(
         self,
         serialized: Dict[str, Any],
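For a handler that falls back, the visible difference is only that each List[BaseMessage] batch arrives flattened to a single prompt string via get_buffer_string (imported above). A rough, assumed illustration of that flattening:

from langchain.schema import AIMessage, HumanMessage, get_buffer_string

batch = [HumanMessage(content="hi"), AIMessage(content="hello!")]
print(get_buffer_string(batch))
# With the default prefixes this should print:
# Human: hi
# AI: hello!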
langchain/callbacks/tracers/langchain.py:

@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import logging
 import os
+from datetime import datetime
 from typing import Any, Dict, List, Optional, Union
 from uuid import UUID, uuid4
 
@@ -19,6 +20,7 @@ from langchain.callbacks.tracers.schemas import (
     TracerSessionV2,
     TracerSessionV2Create,
 )
+from langchain.schema import BaseMessage, messages_to_dict
 from langchain.utils import raise_for_status_with_text
 
 
@@ -193,6 +195,36 @@ class LangChainTracerV2(LangChainTracer):
         """Load the default tracing session and set it as the Tracer's session."""
         return self.load_session("default")
 
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Start a trace for an LLM run."""
+        if self.session is None:
+            self.session = self.load_default_session()
+
+        run_id_ = str(run_id)
+        parent_run_id_ = str(parent_run_id) if parent_run_id else None
+
+        execution_order = self._get_execution_order(parent_run_id_)
+        llm_run = LLMRun(
+            uuid=run_id_,
+            parent_uuid=parent_run_id_,
+            serialized=serialized,
+            prompts=[],
+            extra={**kwargs, "messages": messages},
+            start_time=datetime.utcnow(),
+            execution_order=execution_order,
+            child_execution_order=execution_order,
+            session_id=self.session.id,
+        )
+        self._start_trace(llm_run)
+
     def _convert_run(self, run: Union[LLMRun, ChainRun, ToolRun]) -> RunCreate:
         """Convert a run to a Run."""
         session = self.session or self.load_default_session()
@@ -201,7 +233,12 @@ class LangChainTracerV2(LangChainTracer):
         child_runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
         if isinstance(run, LLMRun):
             run_type = "llm"
-            inputs = {"prompts": run.prompts}
+            if run.extra is not None and "messages" in run.extra:
+                messages: List[List[BaseMessage]] = run.extra.pop("messages")
+                converted_messages = [messages_to_dict(batch) for batch in messages]
+                inputs = {"messages": converted_messages}
+            else:
+                inputs = {"prompts": run.prompts}
             outputs = run.response.dict() if run.response else {}
             child_runs = []
         elif isinstance(run, ChainRun):
langchain/callbacks/tracers/schemas.py:

@@ -117,6 +117,7 @@ class RunBase(BaseModel):
     session_id: UUID
     reference_example_id: Optional[UUID]
     run_type: RunTypeEnum
+    parent_run_id: Optional[UUID]
 
 
 class RunCreate(RunBase):
@@ -130,7 +131,6 @@ class Run(RunBase):
     """Run schema when loading from the DB."""
 
     name: str
-    parent_run_id: Optional[UUID]
 
 
 ChainRun.update_forward_refs()
langchain/chat_models/base.py:

@@ -24,7 +24,6 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
-    get_buffer_string,
 )
 
 
@@ -69,9 +68,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = CallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        message_strings = [get_buffer_string(m) for m in messages]
-        run_manager = callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, message_strings
+        run_manager = callback_manager.on_chat_model_start(
+            {"name": self.__class__.__name__}, messages
         )
 
         new_arg_supported = inspect.signature(self._generate).parameters.get(
@@ -104,9 +102,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         callback_manager = AsyncCallbackManager.configure(
             callbacks, self.callbacks, self.verbose
         )
-        message_strings = [get_buffer_string(m) for m in messages]
-        run_manager = await callback_manager.on_llm_start(
-            {"name": self.__class__.__name__}, message_strings
+        run_manager = await callback_manager.on_chat_model_start(
+            {"name": self.__class__.__name__}, messages
         )
 
         new_arg_supported = inspect.signature(self._agenerate).parameters.get(
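The practical effect of these two hunks: handlers previously received pre-flattened List[str] prompts through on_llm_start; now the raw batches reach on_chat_model_start, and flattening happens only in the manager's fallback path. A rough sketch of the two payload shapes (illustrative values, not code from the diff):

from langchain.schema import HumanMessage

# Before this commit: the chat model flattened batches itself.
prompts = ["Human: foo"]                    # List[str], sent via on_llm_start

# After: raw batches are forwarded as-is.
messages = [[HumanMessage(content="foo")]]  # List[List[BaseMessage]], via on_chat_model_start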
langchain/client (LangChainPlusClient):

@@ -31,9 +31,8 @@ from langchain.callbacks.tracers.langchain import LangChainTracerV2
 from langchain.chains.base import Chain
 from langchain.chat_models.base import BaseChatModel
 from langchain.client.models import Dataset, DatasetCreate, Example, ExampleCreate
-from langchain.client.utils import parse_chat_messages
 from langchain.llms.base import BaseLLM
-from langchain.schema import ChatResult, LLMResult
+from langchain.schema import ChatResult, LLMResult, messages_from_dict
 from langchain.utils import raise_for_status_with_text, xor_args
 
 if TYPE_CHECKING:
@@ -96,7 +95,6 @@ class LangChainPlusClient(BaseSettings):
                 "Unable to get seeded tenant ID. Please manually provide."
             ) from e
         results: List[dict] = response.json()
-        breakpoint()
         if len(results) == 0:
             raise ValueError("No seeded tenant found")
         return results[0]["id"]
@@ -296,13 +294,15 @@ class LangChainPlusClient(BaseSettings):
         langchain_tracer: LangChainTracerV2,
     ) -> Union[LLMResult, ChatResult]:
         if isinstance(llm, BaseLLM):
+            if "prompts" not in inputs:
+                raise ValueError(f"LLM Run requires 'prompts' input. Got {inputs}")
             llm_prompts: List[str] = inputs["prompts"]
             llm_output = await llm.agenerate(llm_prompts, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            chat_prompts: List[str] = inputs["prompts"]
-            messages = [
-                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
-            ]
+            if "messages" not in inputs:
+                raise ValueError(f"Chat Run requires 'messages' input. Got {inputs}")
+            raw_messages: List[List[dict]] = inputs["messages"]
+            messages = [messages_from_dict(batch) for batch in raw_messages]
             llm_output = await llm.agenerate(messages, callbacks=[langchain_tracer])
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
@@ -454,13 +454,17 @@ class LangChainPlusClient(BaseSettings):
     ) -> Union[LLMResult, ChatResult]:
         """Run the language model on the example."""
         if isinstance(llm, BaseLLM):
+            if "prompts" not in inputs:
+                raise ValueError(f"LLM Run must contain 'prompts' key. Got {inputs}")
             llm_prompts: List[str] = inputs["prompts"]
             llm_output = llm.generate(llm_prompts, callbacks=[langchain_tracer])
         elif isinstance(llm, BaseChatModel):
-            chat_prompts: List[str] = inputs["prompts"]
-            messages = [
-                parse_chat_messages(chat_prompt) for chat_prompt in chat_prompts
-            ]
+            if "messages" not in inputs:
+                raise ValueError(
+                    f"Chat Model Run must contain 'messages' key. Got {inputs}"
+                )
+            raw_messages: List[List[dict]] = inputs["messages"]
+            messages = [messages_from_dict(batch) for batch in raw_messages]
             llm_output = llm.generate(messages, callbacks=[langchain_tracer])
         else:
             raise ValueError(f"Unsupported LLM type {type(llm)}")
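The tracer stores each batch as plain dicts via messages_to_dict (see the _convert_run hunk above), and the client now rebuilds message objects with messages_from_dict before calling the chat model. A small sketch of that round trip (the exact dict shape in the comment is an assumption):

from langchain.schema import HumanMessage, messages_from_dict, messages_to_dict

batch = [HumanMessage(content="foo")]
as_dicts = messages_to_dict(batch)       # e.g. [{"type": "human", "data": {...}}]
restored = messages_from_dict(as_dicts)
assert restored[0].content == "foo"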
langchain/client/utils.py (deleted):

@@ -1,42 +0,0 @@
-"""Client Utils."""
-import re
-from typing import Dict, List, Optional, Sequence, Type, Union
-
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]
-_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
-    "Human": HumanMessage,
-    "AI": AIMessage,
-    "System": SystemMessage,
-}
-
-
-def parse_chat_messages(
-    input_text: str, roles: Optional[Sequence[str]] = None
-) -> List[BaseMessage]:
-    """Parse chat messages from a string. This is not robust."""
-    roles = roles or ["Human", "AI", "System"]
-    roles_pattern = "|".join(roles)
-    pattern = (
-        rf"(?P<entity>{roles_pattern}): (?P<message>"
-        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
-    )
-    matches = re.finditer(pattern, input_text, re.MULTILINE)
-
-    results: List[BaseMessage] = []
-    for match in matches:
-        entity = match.group("entity")
-        message = match.group("message").rstrip("\n")
-        if entity in _RESOLUTION_MAP:
-            results.append(_RESOLUTION_MAP[entity](content=message))
-        else:
-            results.append(ChatMessage(role=entity, content=message))
-
-    return results
tests/unit_tests/callbacks/fake_callback_handler.py:

@@ -1,9 +1,12 @@
 """A fake callback handler for testing purposes."""
-from typing import Any
+from itertools import chain
+from typing import Any, Dict, List, Optional
+from uuid import UUID
 
 from pydantic import BaseModel
 
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
+from langchain.schema import BaseMessage
 
 
 class BaseFakeCallbackHandler(BaseModel):
@@ -16,6 +19,7 @@ class BaseFakeCallbackHandler(BaseModel):
     ignore_llm_: bool = False
     ignore_chain_: bool = False
     ignore_agent_: bool = False
+    ignore_chat_model_: bool = False
 
     # add finer-grained counters for easier debugging of failing tests
     chain_starts: int = 0
@@ -27,6 +31,7 @@ class BaseFakeCallbackHandler(BaseModel):
     tool_ends: int = 0
     agent_actions: int = 0
     agent_ends: int = 0
+    chat_model_starts: int = 0
 
 
 class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
@@ -47,6 +52,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.llm_streams += 1
 
     def on_chain_start_common(self) -> None:
+        print("CHAIN START")
         self.chain_starts += 1
         self.starts += 1
 
@@ -69,6 +75,7 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.errors += 1
 
     def on_agent_action_common(self) -> None:
+        print("AGENT ACTION")
        self.agent_actions += 1
         self.starts += 1
 
@@ -76,6 +83,11 @@ class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
         self.agent_ends += 1
         self.ends += 1
 
+    def on_chat_model_start_common(self) -> None:
+        print("STARTING CHAT MODEL")
+        self.chat_model_starts += 1
+        self.starts += 1
+
     def on_text_common(self) -> None:
         self.text += 1
 
@@ -193,6 +205,20 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         return self
 
 
+class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
+    def on_chat_model_start(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        parent_run_id: Optional[UUID] = None,
+        **kwargs: Any,
+    ) -> Any:
+        assert all(isinstance(m, BaseMessage) for m in chain(*messages))
+        self.on_chat_model_start_common()
+
+
 class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixin):
     """Fake async callback handler for testing."""
 
Callback tracer unit tests:

@@ -10,6 +10,7 @@ from uuid import UUID, uuid4
 import pytest
 from freezegun import freeze_time
 
+from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.tracers.base import (
     BaseTracer,
     ChainRun,
@@ -96,6 +97,33 @@ def test_tracer_llm_run() -> None:
     assert tracer.runs == [compare_run]
 
 
+@freeze_time("2023-01-01")
+def test_tracer_chat_model_run() -> None:
+    """Test tracer on a Chat Model run."""
+    uuid = uuid4()
+    compare_run = LLMRun(
+        uuid=str(uuid),
+        parent_uuid=None,
+        start_time=datetime.utcnow(),
+        end_time=datetime.utcnow(),
+        extra={},
+        execution_order=1,
+        child_execution_order=1,
+        serialized={},
+        prompts=[""],
+        response=LLMResult(generations=[[]]),
+        session_id=TEST_SESSION_ID,
+        error=None,
+    )
+    tracer = FakeTracer()
+
+    tracer.new_session()
+    manager = CallbackManager(handlers=[tracer])
+    run_manager = manager.on_chat_model_start(serialized={}, messages=[[]], run_id=uuid)
+    run_manager.on_llm_end(response=LLMResult(generations=[[]]))
+    assert tracer.runs == [compare_run]
+
+
 @freeze_time("2023-01-01")
 def test_tracer_llm_run_errors_no_start() -> None:
     """Test tracer on an LLM run without a start."""
Client utils unit tests (deleted):

@@ -1,70 +0,0 @@
-"""Test LangChain+ Client Utils."""
-
-from typing import List
-
-from langchain.client.utils import parse_chat_messages
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatMessage,
-    HumanMessage,
-    SystemMessage,
-)
-
-
-def test_parse_chat_messages() -> None:
-    """Test that chat messages are parsed correctly."""
-    input_text = (
-        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
-    )
-    expected = [
-        HumanMessage(content="I am human roar"),
-        AIMessage(content="I am AI beep boop"),
-        SystemMessage(content="I am a system message"),
-    ]
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_empty_input() -> None:
-    """Test that an empty input string returns an empty list."""
-    input_text = ""
-    expected: List[BaseMessage] = []
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_multiline_messages() -> None:
-    """Test that multiline messages are parsed correctly."""
-    input_text = (
-        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
-        " beep boop\nSystem: I am a system\nand a message"
-    )
-    expected = [
-        HumanMessage(content="I am a human\nand I roar"),
-        AIMessage(content="I am an AI\nand I beep boop"),
-        SystemMessage(content="I am a system\nand a message"),
-    ]
-    assert parse_chat_messages(input_text) == expected
-
-
-def test_parse_chat_messages_custom_roles() -> None:
-    """Test that custom roles are parsed correctly."""
-    input_text = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
-    expected = [
-        ChatMessage(role="Client", content="I need help"),
-        ChatMessage(role="Agent", content="I'm here to help"),
-        ChatMessage(role="Client", content="Thank you"),
-    ]
-    assert parse_chat_messages(input_text, roles=["Client", "Agent"]) == expected
-
-
-def test_parse_chat_messages_embedded_roles() -> None:
-    """Test that messages with embedded role references are parsed correctly."""
-    input_text = (
-        "Human: Oh ai what if you said AI: foo bar?"
-        "\nAI: Well, that would be interesting!"
-    )
-    expected = [
-        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
-        AIMessage(content="Well, that would be interesting!"),
-    ]
-    assert parse_chat_messages(input_text) == expected
tests/unit_tests/llms/fake_chat_model.py (new file, 32 lines):

@@ -0,0 +1,32 @@
+"""Fake Chat Model wrapper for testing purposes."""
+from typing import List, Optional
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.chat_models.base import SimpleChatModel
+from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
+
+
+class FakeChatModel(SimpleChatModel):
+    """Fake Chat Model wrapper for testing purposes."""
+
+    def _call(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        return "fake response"
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+    ) -> ChatResult:
+        output_str = "fake response"
+        message = AIMessage(content=output_str)
+        generation = ChatGeneration(message=message)
+        return ChatResult(generations=[generation])
LLM callback unit tests (tests/unit_tests/llms):

@@ -1,5 +1,10 @@
 """Test LLM callbacks."""
-from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+from langchain.schema import HumanMessage
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeCallbackHandler,
+    FakeCallbackHandlerWithChatStart,
+)
+from tests.unit_tests.llms.fake_chat_model import FakeChatModel
 from tests.unit_tests.llms.fake_llm import FakeLLM
 
 
@@ -12,3 +17,30 @@ def test_llm_with_callbacks() -> None:
     assert handler.starts == 1
     assert handler.ends == 1
     assert handler.errors == 0
+
+
+def test_chat_model_with_v1_callbacks() -> None:
+    """Test chat model callbacks fall back to on_llm_start."""
+    handler = FakeCallbackHandler()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+    assert handler.llm_starts == 1
+    assert handler.llm_ends == 1
+
+
+def test_chat_model_with_v2_callbacks() -> None:
+    """Test chat model callbacks fall back to on_llm_start."""
+    handler = FakeCallbackHandlerWithChatStart()
+    llm = FakeChatModel(callbacks=[handler], verbose=True)
+    output = llm([HumanMessage(content="foo")])
+    assert output.content == "fake response"
+    assert handler.starts == 1
+    assert handler.ends == 1
+    assert handler.errors == 0
+    assert handler.llm_starts == 0
+    assert handler.llm_ends == 1
+    assert handler.chat_model_starts == 1