Accept message-like things in Chat models, LLMs and MessagesPlaceholder (#16418)

pull/16657/head
Nuno Campos 5 months ago committed by GitHub
parent 570b4f8e66
commit 52ccae3fb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -16,7 +16,12 @@ from typing import (
from typing_extensions import TypeAlias
from langchain_core._api import deprecated
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.messages import (
AnyMessage,
BaseMessage,
MessageLikeRepresentation,
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
LanguageModelOutput = Union[BaseMessage, str]
LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)

@ -34,6 +34,7 @@ from langchain_core.messages import (
BaseMessage,
BaseMessageChunk,
HumanMessage,
convert_to_messages,
message_chunk_to_message,
)
from langchain_core.outputs import (
@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=input)
return ChatPromptValue(messages=convert_to_messages(input))
else:
raise ValueError(
f"Invalid input type {type(input)}. "

@ -48,7 +48,12 @@ from langchain_core.callbacks import (
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.messages import (
AIMessage,
BaseMessage,
convert_to_messages,
get_buffer_string,
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, Sequence):
return ChatPromptValue(messages=input)
return ChatPromptValue(messages=convert_to_messages(input))
else:
raise ValueError(
f"Invalid input type {type(input)}. "

@ -1,4 +1,4 @@
from typing import List, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from langchain_core.messages.ai import AIMessage, AIMessageChunk
from langchain_core.messages.base import (
@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
)
MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]]
def _create_message_from_message_type(
message_type: str,
content: str,
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
**additional_kwargs: Any,
) -> BaseMessage:
"""Create a message from a message type and content string.
Args:
message_type: str the type of the message (e.g., "human", "ai", etc.)
content: str the content string.
Returns:
a message of the appropriate type.
"""
kwargs: Dict[str, Any] = {}
if name is not None:
kwargs["name"] = name
if tool_call_id is not None:
kwargs["tool_call_id"] = tool_call_id
if additional_kwargs:
kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment]
if message_type in ("human", "user"):
message: BaseMessage = HumanMessage(content=content, **kwargs)
elif message_type in ("ai", "assistant"):
message = AIMessage(content=content, **kwargs)
elif message_type == "system":
message = SystemMessage(content=content, **kwargs)
elif message_type == "function":
message = FunctionMessage(content=content, **kwargs)
elif message_type == "tool":
message = ToolMessage(content=content, **kwargs)
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"
f" 'user', 'ai', 'assistant', or 'system'."
)
return message
def _convert_to_message(
message: MessageLikeRepresentation,
) -> BaseMessage:
"""Instantiate a message from a variety of message formats.
The message format can be one of the following:
- BaseMessagePromptTemplate
- BaseMessage
- 2-tuple of (role string, template); e.g., ("human", "{user_input}")
- dict: a message dict with role and content keys
- string: shorthand for ("human", template); e.g., "{user_input}"
Args:
message: a representation of a message in one of the supported formats
Returns:
an instance of a message or a message template
"""
if isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_message_from_message_type("human", message)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
_message = _create_message_from_message_type(message_type_str, template)
elif isinstance(message, dict):
msg_kwargs = message.copy()
try:
msg_type = msg_kwargs.pop("role")
msg_content = msg_kwargs.pop("content")
except KeyError:
raise ValueError(
f"Message dict must contain 'role' and 'content' keys, got {message}"
)
_message = _create_message_from_message_type(
msg_type, msg_content, **msg_kwargs
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
return _message
def convert_to_messages(
messages: Sequence[MessageLikeRepresentation],
) -> List[BaseMessage]:
"""Convert a sequence of messages to a list of messages.
Args:
messages: Sequence of messages to convert.
Returns:
List of messages (BaseMessages).
"""
return [_convert_to_message(m) for m in messages]
__all__ = [
"AIMessage",
"AIMessageChunk",
@ -133,6 +237,7 @@ __all__ = [
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"convert_to_messages",
"get_buffer_string",
"message_chunk_to_message",
"messages_from_dict",

@ -27,6 +27,7 @@ from langchain_core.messages import (
ChatMessage,
HumanMessage,
SystemMessage,
convert_to_messages,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue
@ -126,7 +127,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in value:
for v in convert_to_messages(value):
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"

@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
@property
def _llm_type(self) -> str:
return "generic-fake-chat-model"
class ParrotFakeChatModel(BaseChatModel):
"""A generic fake chat model that can be used to test the chat model interface.
* Chat model should be usable in both sync and async tests
"""
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
return ChatResult(generations=[ChatGeneration(message=messages[-1])])
@property
def _llm_type(self) -> str:
return "parrot-fake-chat-model"

@ -5,8 +5,9 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.fake.chat_model import GenericFakeChatModel
from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
def test_generic_fake_chat_model_invoke() -> None:
@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
AIMessageChunk(content="goodbye"),
]
assert tokens == ["hello", " ", "goodbye"]
def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()
assert fake.invoke("hello") == HumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")

@ -16,6 +16,7 @@ EXPECTED_ALL = [
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"convert_to_messages",
"get_buffer_string",
"message_chunk_to_message",
"messages_from_dict",

@ -369,3 +369,9 @@ def test_messages_placeholder() -> None:
prompt.format_messages()
prompt = MessagesPlaceholder("history", optional=True)
assert prompt.format_messages() == []
prompt.format_messages(
history=[("system", "You are an AI assistant."), "Hello!"]
) == [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello!"),
]

@ -14,6 +14,7 @@ from langchain_core.messages import (
HumanMessageChunk,
SystemMessage,
ToolMessage,
convert_to_messages,
get_buffer_string,
message_chunk_to_message,
messages_from_dict,
@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
]
},
)
def test_convert_to_messages() -> None:
# dicts
assert convert_to_messages(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
{"role": "ai", "content": "Hi!"},
{"role": "human", "content": "Hello!", "name": "Jane"},
{
"role": "assistant",
"content": "Hi!",
"name": "JaneBot",
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'},
},
{"role": "function", "name": "greet", "content": "Hi!"},
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
]
) == [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello!"),
AIMessage(content="Hi!"),
HumanMessage(content="Hello!", name="Jane"),
AIMessage(
content="Hi!",
name="JaneBot",
additional_kwargs={
"function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}
},
),
FunctionMessage(name="greet", content="Hi!"),
ToolMessage(tool_call_id="tool_id", content="Hi!"),
]
# tuples
assert convert_to_messages(
[
("system", "You are a helpful assistant."),
"hello!",
("ai", "Hi!"),
("human", "Hello!"),
("assistant", "Hi!"),
]
) == [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="hello!"),
AIMessage(content="Hi!"),
HumanMessage(content="Hello!"),
AIMessage(content="Hi!"),
]

Loading…
Cancel
Save