diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py
index 577b277d5a..d01247991a 100644
--- a/libs/core/langchain_core/language_models/base.py
+++ b/libs/core/langchain_core/language_models/base.py
@@ -16,7 +16,12 @@ from typing import (
 from typing_extensions import TypeAlias
 
 from langchain_core._api import deprecated
-from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AnyMessage,
+    BaseMessage,
+    MessageLikeRepresentation,
+    get_buffer_string,
+)
 from langchain_core.prompt_values import PromptValue
 from langchain_core.runnables import Runnable, RunnableSerializable
 from langchain_core.utils import get_pydantic_field_names
@@ -49,7 +54,7 @@ def _get_token_ids_default_method(text: str) -> List[int]:
     return tokenizer.encode(text)
 
 
-LanguageModelInput = Union[PromptValue, str, Sequence[BaseMessage]]
+LanguageModelInput = Union[PromptValue, str, Sequence[MessageLikeRepresentation]]
 LanguageModelOutput = Union[BaseMessage, str]
 LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
 LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index 72320e2cb2..aaf61a7810 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -34,6 +34,7 @@ from langchain_core.messages import (
     BaseMessage,
     BaseMessageChunk,
     HumanMessage,
+    convert_to_messages,
     message_chunk_to_message,
 )
 from langchain_core.outputs import (
@@ -144,7 +145,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py
index 2b07c84027..7f987baf88 100644
--- a/libs/core/langchain_core/language_models/llms.py
+++ b/libs/core/langchain_core/language_models/llms.py
@@ -48,7 +48,12 @@ from langchain_core.callbacks import (
 from langchain_core.globals import get_llm_cache
 from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
 from langchain_core.load import dumpd
-from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    convert_to_messages,
+    get_buffer_string,
+)
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.pydantic_v1 import Field, root_validator, validator
@@ -210,7 +215,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
         elif isinstance(input, str):
             return StringPromptValue(text=input)
         elif isinstance(input, Sequence):
-            return ChatPromptValue(messages=input)
+            return ChatPromptValue(messages=convert_to_messages(input))
         else:
             raise ValueError(
                 f"Invalid input type {type(input)}. "
" diff --git a/libs/core/langchain_core/messages/__init__.py b/libs/core/langchain_core/messages/__init__.py index 44444c9d53..c35dac72b9 100644 --- a/libs/core/langchain_core/messages/__init__.py +++ b/libs/core/langchain_core/messages/__init__.py @@ -1,4 +1,4 @@ -from typing import List, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from langchain_core.messages.ai import AIMessage, AIMessageChunk from langchain_core.messages.base import ( @@ -117,6 +117,110 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage: ) +MessageLikeRepresentation = Union[BaseMessage, Tuple[str, str], str, Dict[str, Any]] + + +def _create_message_from_message_type( + message_type: str, + content: str, + name: Optional[str] = None, + tool_call_id: Optional[str] = None, + **additional_kwargs: Any, +) -> BaseMessage: + """Create a message from a message type and content string. + + Args: + message_type: str the type of the message (e.g., "human", "ai", etc.) + content: str the content string. + + Returns: + a message of the appropriate type. + """ + kwargs: Dict[str, Any] = {} + if name is not None: + kwargs["name"] = name + if tool_call_id is not None: + kwargs["tool_call_id"] = tool_call_id + if additional_kwargs: + kwargs["additional_kwargs"] = additional_kwargs # type: ignore[assignment] + if message_type in ("human", "user"): + message: BaseMessage = HumanMessage(content=content, **kwargs) + elif message_type in ("ai", "assistant"): + message = AIMessage(content=content, **kwargs) + elif message_type == "system": + message = SystemMessage(content=content, **kwargs) + elif message_type == "function": + message = FunctionMessage(content=content, **kwargs) + elif message_type == "tool": + message = ToolMessage(content=content, **kwargs) + else: + raise ValueError( + f"Unexpected message type: {message_type}. Use one of 'human'," + f" 'user', 'ai', 'assistant', or 'system'." + ) + return message + + +def _convert_to_message( + message: MessageLikeRepresentation, +) -> BaseMessage: + """Instantiate a message from a variety of message formats. 
+
+    The message format can be one of the following:
+
+    - BaseMessage
+    - 2-tuple of (role string, content); e.g., ("human", "Hello!")
+    - dict: a message dict with role and content keys
+    - string: shorthand for ("human", content); e.g., "Hello!"
+
+    Args:
+        message: a representation of a message in one of the supported formats
+
+    Returns:
+        an instance of a message
+    """
+    if isinstance(message, BaseMessage):
+        _message = message
+    elif isinstance(message, str):
+        _message = _create_message_from_message_type("human", message)
+    elif isinstance(message, tuple):
+        if len(message) != 2:
+            raise ValueError(f"Expected 2-tuple of (role, content), got {message}")
+        message_type_str, content = message
+        _message = _create_message_from_message_type(message_type_str, content)
+    elif isinstance(message, dict):
+        msg_kwargs = message.copy()
+        try:
+            msg_type = msg_kwargs.pop("role")
+            msg_content = msg_kwargs.pop("content")
+        except KeyError:
+            raise ValueError(
+                f"Message dict must contain 'role' and 'content' keys, got {message}"
+            )
+        _message = _create_message_from_message_type(
+            msg_type, msg_content, **msg_kwargs
+        )
+    else:
+        raise NotImplementedError(f"Unsupported message type: {type(message)}")
+
+    return _message
+
+
+def convert_to_messages(
+    messages: Sequence[MessageLikeRepresentation],
+) -> List[BaseMessage]:
+    """Convert a sequence of message-like objects to a list of messages.
+
+    Args:
+        messages: Sequence of messages to convert.
+
+    Returns:
+        List of messages (BaseMessages).
+    """
+    return [_convert_to_message(m) for m in messages]
+
+
 __all__ = [
     "AIMessage",
     "AIMessageChunk",
@@ -133,6 +236,7 @@ __all__ = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py
index f0de1aa88b..b03e0be291 100644
--- a/libs/core/langchain_core/prompts/chat.py
+++ b/libs/core/langchain_core/prompts/chat.py
@@ -27,6 +27,7 @@ from langchain_core.messages import (
     ChatMessage,
     HumanMessage,
     SystemMessage,
+    convert_to_messages,
 )
 from langchain_core.messages.base import get_msg_title_repr
 from langchain_core.prompt_values import ChatPromptValue, PromptValue
@@ -126,7 +127,8 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
                 f"variable {self.variable_name} should be a list of base messages, "
                 f"got {value}"
             )
-        for v in value:
+        value = convert_to_messages(value)
+        for v in value:
             if not isinstance(v, BaseMessage):
                 raise ValueError(
                     f"variable {self.variable_name} should be a list of base messages,"
diff --git a/libs/core/tests/unit_tests/fake/chat_model.py b/libs/core/tests/unit_tests/fake/chat_model.py
index 98f05b6ca6..d0135d7f11 100644
--- a/libs/core/tests/unit_tests/fake/chat_model.py
+++ b/libs/core/tests/unit_tests/fake/chat_model.py
@@ -301,3 +301,24 @@ class GenericFakeChatModel(BaseChatModel):
     @property
     def _llm_type(self) -> str:
         return "generic-fake-chat-model"
+
+
+class ParrotFakeChatModel(BaseChatModel):
+    """A fake chat model that echoes back the last message of its input.
+
+    * Chat model should be usable in both sync and async tests
+    """
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Echo the last input message back as the generation."""
+        return ChatResult(generations=[ChatGeneration(message=messages[-1])])
+
+    @property
+    def _llm_type(self) -> str:
+        return "parrot-fake-chat-model"
diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
index 8700f0751c..6ca6265750 100644
--- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
+++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -5,8 +5,9 @@ from uuid import UUID
 
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.messages.human import HumanMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
-from tests.unit_tests.fake.chat_model import GenericFakeChatModel
+from tests.unit_tests.fake.chat_model import GenericFakeChatModel, ParrotFakeChatModel
 
 
 def test_generic_fake_chat_model_invoke() -> None:
@@ -182,3 +183,11 @@ async def test_callback_handlers() -> None:
         AIMessageChunk(content="goodbye"),
     ]
     assert tokens == ["hello", " ", "goodbye"]
+
+
+def test_chat_model_inputs() -> None:
+    fake = ParrotFakeChatModel()
+
+    assert fake.invoke("hello") == HumanMessage(content="hello")
+    assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
+    assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
diff --git a/libs/core/tests/unit_tests/messages/test_imports.py b/libs/core/tests/unit_tests/messages/test_imports.py
index dba0c84060..628223887a 100644
--- a/libs/core/tests/unit_tests/messages/test_imports.py
+++ b/libs/core/tests/unit_tests/messages/test_imports.py
@@ -16,6 +16,7 @@ EXPECTED_ALL = [
     "SystemMessageChunk",
     "ToolMessage",
     "ToolMessageChunk",
+    "convert_to_messages",
     "get_buffer_string",
     "message_chunk_to_message",
     "messages_from_dict",
diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py
index 2765d030d5..0f3198bf26 100644
--- a/libs/core/tests/unit_tests/prompts/test_chat.py
+++ b/libs/core/tests/unit_tests/prompts/test_chat.py
@@ -369,3 +369,9 @@ def test_messages_placeholder() -> None:
         prompt.format_messages()
     prompt = MessagesPlaceholder("history", optional=True)
     assert prompt.format_messages() == []
+    assert prompt.format_messages(
+        history=[("system", "You are an AI assistant."), "Hello!"]
+    ) == [
+        SystemMessage(content="You are an AI assistant."),
+        HumanMessage(content="Hello!"),
+    ]
diff --git a/libs/core/tests/unit_tests/test_messages.py b/libs/core/tests/unit_tests/test_messages.py
index 95d60a52f2..6f8b97951c 100644
--- a/libs/core/tests/unit_tests/test_messages.py
+++ b/libs/core/tests/unit_tests/test_messages.py
@@ -14,6 +14,7 @@ from langchain_core.messages import (
     HumanMessageChunk,
     SystemMessage,
     ToolMessage,
+    convert_to_messages,
     get_buffer_string,
     message_chunk_to_message,
     messages_from_dict,
@@ -428,3 +429,54 @@ def test_tool_calls_merge() -> None:
             ]
         },
     )
+
+
+def test_convert_to_messages() -> None:
+    # dicts
+    assert convert_to_messages(
+        [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+            {"role": "ai", "content": "Hi!"},
+            {"role": "human", "content": "Hello!", "name": "Jane"},
+            {
+                "role": "assistant",
"content": "Hi!", + "name": "JaneBot", + "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'}, + }, + {"role": "function", "name": "greet", "content": "Hi!"}, + {"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"}, + ] + ) == [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="Hello!"), + AIMessage(content="Hi!"), + HumanMessage(content="Hello!", name="Jane"), + AIMessage( + content="Hi!", + name="JaneBot", + additional_kwargs={ + "function_call": {"name": "greet", "arguments": '{"name": "Jane"}'} + }, + ), + FunctionMessage(name="greet", content="Hi!"), + ToolMessage(tool_call_id="tool_id", content="Hi!"), + ] + + # tuples + assert convert_to_messages( + [ + ("system", "You are a helpful assistant."), + "hello!", + ("ai", "Hi!"), + ("human", "Hello!"), + ("assistant", "Hi!"), + ] + ) == [ + SystemMessage(content="You are a helpful assistant."), + HumanMessage(content="hello!"), + AIMessage(content="Hi!"), + HumanMessage(content="Hello!"), + AIMessage(content="Hi!"), + ]