From c29e9b641216acf30b5050d399eaa05ae43e6f22 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Fri, 2 Feb 2024 10:23:37 -0800 Subject: [PATCH] core[patch]: fix chat prompt partial messages placeholder var (#16918) --- libs/core/langchain_core/prompts/base.py | 7 ++---- libs/core/langchain_core/prompts/chat.py | 22 +++++++++---------- .../tests/unit_tests/prompts/test_chat.py | 20 ++++++++++++++++- .../__snapshots__/test_runnable.ambr | 6 +++-- .../load/__snapshots__/test_dump.ambr | 3 ++- 5 files changed, 38 insertions(+), 20 deletions(-) diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index 07ac722255..9a65e5e32d 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -47,9 +47,7 @@ class BasePromptTemplate( If not provided, all variables are assumed to be strings.""" output_parser: Optional[BaseOutputParser] = None """How to parse the output of calling an LLM on this formatted prompt.""" - partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( - default_factory=dict - ) + partial_variables: Mapping[str, Any] = Field(default_factory=dict) @classmethod def get_lc_namespace(cls) -> List[str]: @@ -143,8 +141,7 @@ class BasePromptTemplate( def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]: # Get partial params: partial_kwargs = { - k: v if isinstance(v, str) else v() - for k, v in self.partial_variables.items() + k: v if not callable(v) else v() for k, v in self.partial_variables.items() } return {**partial_kwargs, **kwargs} diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index cfb3192cfa..4fbd0666b1 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -5,7 +5,6 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import ( Any, - Callable, Dict, List, Optional, @@ -130,13 +129,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): f"variable {self.variable_name} should be a list of base messages, " f"got {value}" ) - for v in convert_to_messages(value): - if not isinstance(v, BaseMessage): - raise ValueError( - f"variable {self.variable_name} should be a list of base messages," - f" got {value}" - ) - return value + return convert_to_messages(value) @property def input_variables(self) -> List[str]: @@ -755,13 +748,20 @@ class ChatPromptTemplate(BaseChatPromptTemplate): # Automatically infer input variables from messages input_vars: Set[str] = set() + partial_vars: Dict[str, Any] = {} for _message in _messages: - if isinstance( + if isinstance(_message, MessagesPlaceholder) and _message.optional: + partial_vars[_message.variable_name] = [] + elif isinstance( _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) ): input_vars.update(_message.input_variables) - return cls(input_variables=sorted(input_vars), messages=_messages) + return cls( + input_variables=sorted(input_vars), + messages=_messages, + partial_variables=partial_vars, + ) def format(self, **kwargs: Any) -> str: """Format the chat template into a string. @@ -799,7 +799,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): raise ValueError(f"Unexpected input: {message_template}") return result - def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate: + def partial(self, **kwargs: Any) -> ChatPromptTemplate: """Get a new ChatPromptTemplate with some input variables already filled in. Args: diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 029244afe8..0f14605f40 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -503,9 +503,27 @@ def test_messages_placeholder() -> None: prompt.format_messages() prompt = MessagesPlaceholder("history", optional=True) assert prompt.format_messages() == [] - prompt.format_messages( + assert prompt.format_messages( history=[("system", "You are an AI assistant."), "Hello!"] ) == [ SystemMessage(content="You are an AI assistant."), HumanMessage(content="Hello!"), ] + + +def test_chat_prompt_message_placeholder_partial() -> None: + prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) + prompt = prompt.partial(history=[("system", "foo")]) + assert prompt.format_messages() == [SystemMessage(content="foo")] + assert prompt.format_messages(history=[("system", "bar")]) == [ + SystemMessage(content="bar") + ] + + prompt = ChatPromptTemplate.from_messages( + [ + MessagesPlaceholder("history", optional=True), + ] + ) + assert prompt.format_messages() == [] + prompt = prompt.partial(history=[("system", "foo")]) + assert prompt.format_messages() == [SystemMessage(content="foo")] diff --git a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr index 051520c045..b149765e88 100644 --- a/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr +++ b/libs/core/tests/unit_tests/runnables/__snapshots__/test_runnable.ambr @@ -1544,7 +1544,8 @@ } } } - ] + ], + "partial_variables": {} } }, "middle": [], @@ -1617,7 +1618,8 @@ } } } - ] + ], + "partial_variables": {} } }, "middle": [], diff --git a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr index 0540760d4c..826746ac24 100644 --- a/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr +++ b/libs/langchain/tests/unit_tests/load/__snapshots__/test_dump.ambr @@ -191,7 +191,8 @@ } } } - ] + ], + "partial_variables": {} } } }