core:Add optional max_messages to MessagePlaceholder (#16098)

- **Description:** Add optional max_messages to MessagePlaceholder
- **Issue:**
[16096](https://github.com/langchain-ai/langchain/issues/16096)
- **Dependencies:** None
- **Twitter handle:** @davedecaprio

Sometimes it's better to limit the history in the prompt itself rather
than the memory. This is needed if you want different prompts in the
chain to have different history lengths.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
David DeCaprio 2024-06-19 18:39:51 -05:00 committed by GitHub
parent 7193634ae6
commit a4bcb45f65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 2 deletions

View File

@@ -38,7 +38,7 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
@@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
# AIMessage(content="5 + 2 is 7"), # AIMessage(content="5 + 2 is 7"),
# HumanMessage(content="now multiply that by 4"), # HumanMessage(content="now multiply that by 4"),
# ]) # ])
Limiting the number of messages:
.. code-block:: python
from langchain_core.prompts import MessagesPlaceholder
prompt = MessagesPlaceholder("history", n_messages=1)
prompt.format_messages(
history=[
("system", "You are an AI assistant."),
("human", "Hello!"),
]
)
# -> [
# HumanMessage(content="Hello!"),
# ]
""" """
variable_name: str variable_name: str
@@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
list. If False then a named argument with name `variable_name` must be passed list. If False then a named argument with name `variable_name` must be passed
in, even if the value is an empty list.""" in, even if the value is an empty list."""
n_messages: Optional[PositiveInt] = None
"""Maximum number of messages to include. If None, then will include all.
Defaults to None."""
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
@@ -197,7 +219,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
f"variable {self.variable_name} should be a list of base messages, " f"variable {self.variable_name} should be a list of base messages, "
f"got {value}" f"got {value}"
) )
return convert_to_messages(value) value = convert_to_messages(value)
if self.n_messages:
value = value[-self.n_messages :]
return value
@property @property
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:

View File

@@ -655,6 +655,21 @@ def test_messages_placeholder() -> None:
] ]
def test_messages_placeholder_with_max() -> None:
    """n_messages=None keeps the full history; n_messages=k keeps the last k."""
    conversation = [AIMessage(content=str(i)) for i in (1, 2, 3)]

    # Without a limit, the placeholder passes the history through unchanged.
    unlimited = MessagesPlaceholder("history")
    assert unlimited.format_messages(history=conversation) == conversation

    # With n_messages=2, only the two most recent messages are kept.
    limited = MessagesPlaceholder("history", n_messages=2)
    expected = conversation[-2:]
    assert limited.format_messages(history=conversation) == expected
def test_chat_prompt_message_placeholder_partial() -> None: def test_chat_prompt_message_placeholder_partial() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
prompt = prompt.partial(history=[("system", "foo")]) prompt = prompt.partial(history=[("system", "foo")])