mirror of https://github.com/hwchase17/langchain
Add on_chat_message_start (#4499)
### Add on_chat_message_start to callback manager and base tracer. Goal: trace messages directly to permit reloading as chat messages (stored in an integration-agnostic way). Add an `on_chat_message_start` method, falling back to `on_llm_start()` for handlers that don't implement it. Does so in a non-backwards-compatibility-breaking way (for now). (branch: pull/4529/head)
parent
bbf76dbb52
commit
4ee47926ca
@ -1,42 +0,0 @@
|
|||||||
"""Client Utils."""
|
|
||||||
import re
|
|
||||||
from typing import Dict, List, Optional, Sequence, Type, Union
|
|
||||||
|
|
||||||
from langchain.schema import (
|
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
|
||||||
ChatMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Alias for the message classes that have a dedicated, role-specific type.
_DEFAULT_MESSAGES_T = Union[Type[HumanMessage], Type[SystemMessage], Type[AIMessage]]

# Maps a role prefix (as it appears in a transcript, e.g. "Human: ...") to the
# message class it should be parsed into. Roles absent from this table are
# handled with a generic ChatMessage carrying the role string.
_RESOLUTION_MAP: Dict[str, _DEFAULT_MESSAGES_T] = {
    "Human": HumanMessage,
    "AI": AIMessage,
    "System": SystemMessage,
}
|
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_messages(
    input_text: str, roles: Optional[Sequence[str]] = None
) -> List[BaseMessage]:
    """Parse chat messages from a role-prefixed transcript string.

    A message starts at a ``"<Role>: "`` prefix and runs (possibly across
    several lines) until the next recognized role prefix or the end of the
    input. This is not robust.

    Args:
        input_text: The transcript text to parse.
        roles: Role prefixes to recognize. Defaults to
            ``["Human", "AI", "System"]``.

    Returns:
        The parsed messages in order of appearance. Roles present in
        ``_RESOLUTION_MAP`` ("Human", "AI", "System") resolve to their
        dedicated message class; any other role yields a ``ChatMessage``
        with that role string.
    """
    roles = roles or ["Human", "AI", "System"]
    # Escape each role so custom roles containing regex metacharacters
    # (e.g. "User (admin)" or "C++") are matched literally instead of
    # corrupting the pattern or raising re.error.
    roles_pattern = "|".join(re.escape(role) for role in roles)
    pattern = (
        rf"(?P<entity>{roles_pattern}): (?P<message>"
        rf"(?:.*\n?)*?)(?=(?:{roles_pattern}): |\Z)"
    )
    matches = re.finditer(pattern, input_text, re.MULTILINE)

    results: List[BaseMessage] = []
    for match in matches:
        entity = match.group("entity")
        # Drop the trailing newline captured before the next role prefix.
        message = match.group("message").rstrip("\n")
        if entity in _RESOLUTION_MAP:
            results.append(_RESOLUTION_MAP[entity](content=message))
        else:
            results.append(ChatMessage(role=entity, content=message))

    return results
|
|
@ -1,70 +0,0 @@
|
|||||||
"""Test LangChain+ Client Utils."""
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
from langchain.client.utils import parse_chat_messages
|
|
||||||
from langchain.schema import (
|
|
||||||
AIMessage,
|
|
||||||
BaseMessage,
|
|
||||||
ChatMessage,
|
|
||||||
HumanMessage,
|
|
||||||
SystemMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages() -> None:
    """Test that chat messages are parsed correctly."""
    transcript = (
        "Human: I am human roar\nAI: I am AI beep boop\nSystem: I am a system message"
    )
    assert parse_chat_messages(transcript) == [
        HumanMessage(content="I am human roar"),
        AIMessage(content="I am AI beep boop"),
        SystemMessage(content="I am a system message"),
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_empty_input() -> None:
    """Test that an empty input string returns an empty list."""
    parsed: List[BaseMessage] = parse_chat_messages("")
    assert parsed == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_multiline_messages() -> None:
    """Test that multiline messages are parsed correctly."""
    transcript = (
        "Human: I am a human\nand I roar\nAI: I am an AI\nand I"
        " beep boop\nSystem: I am a system\nand a message"
    )
    # Continuation lines (no role prefix) belong to the preceding message.
    assert parse_chat_messages(transcript) == [
        HumanMessage(content="I am a human\nand I roar"),
        AIMessage(content="I am an AI\nand I beep boop"),
        SystemMessage(content="I am a system\nand a message"),
    ]
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_custom_roles() -> None:
    """Test that custom roles are parsed correctly."""
    transcript = "Client: I need help\nAgent: I'm here to help\nClient: Thank you"
    # Roles outside the default map resolve to generic ChatMessage objects.
    expected = [
        ChatMessage(role=role, content=content)
        for role, content in [
            ("Client", "I need help"),
            ("Agent", "I'm here to help"),
            ("Client", "Thank you"),
        ]
    ]
    assert parse_chat_messages(transcript, roles=["Client", "Agent"]) == expected
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_embedded_roles() -> None:
    """Test that messages with embedded role references are parsed correctly."""
    transcript = (
        "Human: Oh ai what if you said AI: foo bar?"
        "\nAI: Well, that would be interesting!"
    )
    # A role prefix appearing mid-line must not split the message.
    assert parse_chat_messages(transcript) == [
        HumanMessage(content="Oh ai what if you said AI: foo bar?"),
        AIMessage(content="Well, that would be interesting!"),
    ]
|
|
@ -0,0 +1,32 @@
|
|||||||
|
"""Fake Chat Model wrapper for testing purposes."""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain.chat_models.base import SimpleChatModel
|
||||||
|
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
|
||||||
|
|
||||||
|
|
||||||
|
class FakeChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        # Ignore the input entirely and answer with a canned reply.
        return "fake response"

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        # Wrap the same canned reply in the full ChatResult structure
        # expected from the async generation path.
        reply = AIMessage(content="fake response")
        return ChatResult(generations=[ChatGeneration(message=reply)])
|
Loading…
Reference in New Issue