mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
core:Add optional max_messages to MessagePlaceholder (#16098)
- **Description:** Add optional max_messages to MessagePlaceholder - **Issue:** [16096](https://github.com/langchain-ai/langchain/issues/16096) - **Dependencies:** None - **Twitter handle:** @davedecaprio Sometimes it's better to limit the history in the prompt itself rather than the memory. This is needed if you want different prompts in the chain to have different history lengths. --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
7193634ae6
commit
a4bcb45f65
@ -38,7 +38,7 @@ from langchain_core.prompts.base import BasePromptTemplate
|
|||||||
from langchain_core.prompts.image import ImagePromptTemplate
|
from langchain_core.prompts.image import ImagePromptTemplate
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
|
||||||
from langchain_core.pydantic_v1 import Field, root_validator
|
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
|
||||||
from langchain_core.utils import get_colored_text
|
from langchain_core.utils import get_colored_text
|
||||||
from langchain_core.utils.interactive_env import is_interactive_env
|
from langchain_core.utils.interactive_env import is_interactive_env
|
||||||
|
|
||||||
@ -160,6 +160,24 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
# AIMessage(content="5 + 2 is 7"),
|
# AIMessage(content="5 + 2 is 7"),
|
||||||
# HumanMessage(content="now multiply that by 4"),
|
# HumanMessage(content="now multiply that by 4"),
|
||||||
# ])
|
# ])
|
||||||
|
|
||||||
|
Limiting the number of messages:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.prompts import MessagesPlaceholder
|
||||||
|
|
||||||
|
prompt = MessagesPlaceholder("history", n_messages=1)
|
||||||
|
|
||||||
|
prompt.format_messages(
|
||||||
|
history=[
|
||||||
|
("system", "You are an AI assistant."),
|
||||||
|
("human", "Hello!"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# -> [
|
||||||
|
# HumanMessage(content="Hello!"),
|
||||||
|
# ]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
variable_name: str
|
variable_name: str
|
||||||
@ -170,6 +188,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
list. If False then a named argument with name `variable_name` must be passed
|
list. If False then a named argument with name `variable_name` must be passed
|
||||||
in, even if the value is an empty list."""
|
in, even if the value is an empty list."""
|
||||||
|
|
||||||
|
n_messages: Optional[PositiveInt] = None
|
||||||
|
"""Maximum number of messages to include. If None, then will include all.
|
||||||
|
Defaults to None."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
"""Get the namespace of the langchain object."""
|
"""Get the namespace of the langchain object."""
|
||||||
@ -197,7 +219,10 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
f"variable {self.variable_name} should be a list of base messages, "
|
f"variable {self.variable_name} should be a list of base messages, "
|
||||||
f"got {value}"
|
f"got {value}"
|
||||||
)
|
)
|
||||||
return convert_to_messages(value)
|
value = convert_to_messages(value)
|
||||||
|
if self.n_messages:
|
||||||
|
value = value[-self.n_messages :]
|
||||||
|
return value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self) -> List[str]:
|
def input_variables(self) -> List[str]:
|
||||||
|
@ -655,6 +655,21 @@ def test_messages_placeholder() -> None:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_messages_placeholder_with_max() -> None:
|
||||||
|
history = [
|
||||||
|
AIMessage(content="1"),
|
||||||
|
AIMessage(content="2"),
|
||||||
|
AIMessage(content="3"),
|
||||||
|
]
|
||||||
|
prompt = MessagesPlaceholder("history")
|
||||||
|
assert prompt.format_messages(history=history) == history
|
||||||
|
prompt = MessagesPlaceholder("history", n_messages=2)
|
||||||
|
assert prompt.format_messages(history=history) == [
|
||||||
|
AIMessage(content="2"),
|
||||||
|
AIMessage(content="3"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_chat_prompt_message_placeholder_partial() -> None:
|
def test_chat_prompt_message_placeholder_partial() -> None:
|
||||||
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
||||||
prompt = prompt.partial(history=[("system", "foo")])
|
prompt = prompt.partial(history=[("system", "foo")])
|
||||||
|
Loading…
Reference in New Issue
Block a user