diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index bc814158c3..7be1390b35 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -168,6 +168,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): for message in messages: if isinstance(message, BaseMessagePromptTemplate): input_vars.update(message.input_variables) + if "partial_variables" in values: + input_vars = input_vars - set(values["partial_variables"]) if "input_variables" in values: if input_vars != set(values["input_variables"]): raise ValueError( diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index d1faac582a..17c114e640 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -162,3 +162,31 @@ def test_infer_variables() -> None: messages = [HumanMessagePromptTemplate.from_template("{foo}")] prompt = ChatPromptTemplate(messages=messages) assert prompt.input_variables == ["foo"] + + +def test_chat_valid_with_partial_variables() -> None: + messages = [ + HumanMessagePromptTemplate.from_template( + "Do something with {question} using {context} giving it like {formatins}" + ) + ] + prompt = ChatPromptTemplate( + messages=messages, + input_variables=["question", "context"], + partial_variables={"formatins": "some structure"}, + ) + assert set(prompt.input_variables) == set(["question", "context"]) + assert prompt.partial_variables == {"formatins": "some structure"} + + +def test_chat_valid_infer_variables() -> None: + messages = [ + HumanMessagePromptTemplate.from_template( + "Do something with {question} using {context} giving it like {formatins}" + ) + ] + prompt = ChatPromptTemplate( + messages=messages, partial_variables={"formatins": "some structure"} + ) + assert set(prompt.input_variables) == set(["question", "context"]) + assert prompt.partial_variables == {"formatins": "some structure"}