core:Add optional max_messages to MessagePlaceholder (#16098)

- **Description:** Add optional max_messages to MessagePlaceholder
- **Issue:**
[16096](https://github.com/langchain-ai/langchain/issues/16096)
- **Dependencies:** None
- **Twitter handle:** @davedecaprio

Sometimes it's better to limit the history in the prompt itself rather
than the memory. This is needed if you want different prompts in the
chain to have different history lengths.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/23200/head
David DeCaprio 2 months ago committed by GitHub
parent 7193634ae6
commit a4bcb45f65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -38,7 +38,7 @@ from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
# AIMessage(content="5 + 2 is 7"),
# HumanMessage(content="now multiply that by 4"),
# ])
Limiting the number of messages:
.. code-block:: python
from langchain_core.prompts import MessagesPlaceholder
prompt = MessagesPlaceholder("history", n_messages=1)
prompt.format_messages(
history=[
("system", "You are an AI assistant."),
("human", "Hello!"),
]
)
# -> [
# HumanMessage(content="Hello!"),
# ]
"""
variable_name: str
@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
list. If False then a named argument with name `variable_name` must be passed
in, even if the value is an empty list."""
n_messages: Optional[PositiveInt] = None
"""Maximum number of messages to include. If None, then will include all.
Defaults to None."""
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
@ -197,7 +219,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
return convert_to_messages(value)
value = convert_to_messages(value)
if self.n_messages:
value = value[-self.n_messages :]
return value
@property
def input_variables(self) -> List[str]:

@ -655,6 +655,21 @@ def test_messages_placeholder() -> None:
]
def test_messages_placeholder_with_max() -> None:
history = [
AIMessage(content="1"),
AIMessage(content="2"),
AIMessage(content="3"),
]
prompt = MessagesPlaceholder("history")
assert prompt.format_messages(history=history) == history
prompt = MessagesPlaceholder("history", n_messages=2)
assert prompt.format_messages(history=history) == [
AIMessage(content="2"),
AIMessage(content="3"),
]
def test_chat_prompt_message_placeholder_partial() -> None:
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
prompt = prompt.partial(history=[("system", "foo")])

Loading…
Cancel
Save