diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index b43c4b07e1..c52d3bfb9d 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -38,7 +38,7 @@ from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import StringPromptTemplate, get_template_variables -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator from langchain_core.utils import get_colored_text from langchain_core.utils.interactive_env import is_interactive_env @@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): # AIMessage(content="5 + 2 is 7"), # HumanMessage(content="now multiply that by 4"), # ]) + + Limiting the number of messages: + + .. code-block:: python + + from langchain_core.prompts import MessagesPlaceholder + + prompt = MessagesPlaceholder("history", n_messages=1) + + prompt.format_messages( + history=[ + ("system", "You are an AI assistant."), + ("human", "Hello!"), + ] + ) + # -> [ + # HumanMessage(content="Hello!"), + # ] """ variable_name: str @@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): list. If False then a named argument with name `variable_name` must be passed in, even if the value is an empty list.""" + n_messages: Optional[PositiveInt] = None + """Maximum number of messages to include. If None, then will include all. + Defaults to None.""" + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" @@ -197,7 +219,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): f"variable {self.variable_name} should be a list of base messages, " f"got {value}" ) - return convert_to_messages(value) + value = convert_to_messages(value) + if self.n_messages: + value = value[-self.n_messages :] + return value @property def input_variables(self) -> List[str]: diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index e5fddbb485..5811480f04 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -655,6 +655,21 @@ def test_messages_placeholder() -> None: ] +def test_messages_placeholder_with_max() -> None: + history = [ + AIMessage(content="1"), + AIMessage(content="2"), + AIMessage(content="3"), + ] + prompt = MessagesPlaceholder("history") + assert prompt.format_messages(history=history) == history + prompt = MessagesPlaceholder("history", n_messages=2) + assert prompt.format_messages(history=history) == [ + AIMessage(content="2"), + AIMessage(content="3"), + ] + + def test_chat_prompt_message_placeholder_partial() -> None: prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) prompt = prompt.partial(history=[("system", "foo")])