From a4bcb45f659d29e4c4ce715f70813413db9bad03 Mon Sep 17 00:00:00 2001 From: David DeCaprio Date: Wed, 19 Jun 2024 18:39:51 -0500 Subject: [PATCH] core:Add optional max_messages to MessagePlaceholder (#16098) - **Description:** Add optional max_messages to MessagePlaceholder - **Issue:** [16096](https://github.com/langchain-ai/langchain/issues/16096) - **Dependencies:** None - **Twitter handle:** @davedecaprio Sometimes it's better to limit the history in the prompt itself rather than the memory. This is needed if you want different prompts in the chain to have different history lengths. --------- Co-authored-by: Harrison Chase --- libs/core/langchain_core/prompts/chat.py | 29 +++++++++++++++++-- .../tests/unit_tests/prompts/test_chat.py | 15 ++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index b43c4b07e1..c52d3bfb9d 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -38,7 +38,7 @@ from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import StringPromptTemplate, get_template_variables -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator from langchain_core.utils import get_colored_text from langchain_core.utils.interactive_env import is_interactive_env @@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): # AIMessage(content="5 + 2 is 7"), # HumanMessage(content="now multiply that by 4"), # ]) + + Limiting the number of messages: + + .. code-block:: python + + from langchain_core.prompts import MessagesPlaceholder + + prompt = MessagesPlaceholder("history", n_messages=1) + + prompt.format_messages( + history=[ + ("system", "You are an AI assistant."), + ("human", "Hello!"), + ] + ) + # -> [ + # HumanMessage(content="Hello!"), + # ] """ variable_name: str @@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): list. If False then a named argument with name `variable_name` must be passed in, even if the value is an empty list.""" + n_messages: Optional[PositiveInt] = None + """Maximum number of messages to include. If None, then will include all. + Defaults to None.""" + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" @@ -197,7 +219,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): f"variable {self.variable_name} should be a list of base messages, " f"got {value}" ) - return convert_to_messages(value) + value = convert_to_messages(value) + if self.n_messages: + value = value[-self.n_messages :] + return value @property def input_variables(self) -> List[str]: diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index e5fddbb485..5811480f04 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -655,6 +655,21 @@ def test_messages_placeholder() -> None: ] +def test_messages_placeholder_with_max() -> None: + history = [ + AIMessage(content="1"), + AIMessage(content="2"), + AIMessage(content="3"), + ] + prompt = MessagesPlaceholder("history") + assert prompt.format_messages(history=history) == history + prompt = MessagesPlaceholder("history", n_messages=2) + assert prompt.format_messages(history=history) == [ + AIMessage(content="2"), + AIMessage(content="3"), + ] + + def test_chat_prompt_message_placeholder_partial() -> None: prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) prompt = prompt.partial(history=[("system", "foo")])