From 6efd5fa2b9d46c7b4db6ad638097f010b745f0cc Mon Sep 17 00:00:00 2001 From: Avinash Raj Date: Tue, 20 Jun 2023 10:38:15 +0530 Subject: [PATCH] Fix for #6431 - chatprompt template with partial variables giing validation error (#6456) W.r.t recent changes, ChatPromptTemplate does not accepting partial variables. This PR should fix that issue. Fixes #6431 #### Who can review? @hwchase17 --------- Co-authored-by: Harrison Chase --- langchain/prompts/chat.py | 2 ++ tests/unit_tests/prompts/test_chat.py | 28 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index bc814158..7be1390b 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -168,6 +168,8 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): for message in messages: if isinstance(message, BaseMessagePromptTemplate): input_vars.update(message.input_variables) + if "partial_variables" in values: + input_vars = input_vars - set(values["partial_variables"]) if "input_variables" in values: if input_vars != set(values["input_variables"]): raise ValueError( diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index d1faac58..17c114e6 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -162,3 +162,31 @@ def test_infer_variables() -> None: messages = [HumanMessagePromptTemplate.from_template("{foo}")] prompt = ChatPromptTemplate(messages=messages) assert prompt.input_variables == ["foo"] + + +def test_chat_valid_with_partial_variables() -> None: + messages = [ + HumanMessagePromptTemplate.from_template( + "Do something with {question} using {context} giving it like {formatins}" + ) + ] + prompt = ChatPromptTemplate( + messages=messages, + input_variables=["question", "context"], + partial_variables={"formatins": "some structure"}, + ) + assert set(prompt.input_variables) == set(["question", "context"]) + assert prompt.partial_variables == {"formatins": "some structure"} + + +def test_chat_valid_infer_variables() -> None: + messages = [ + HumanMessagePromptTemplate.from_template( + "Do something with {question} using {context} giving it like {formatins}" + ) + ] + prompt = ChatPromptTemplate( + messages=messages, partial_variables={"formatins": "some structure"} + ) + assert set(prompt.input_variables) == set(["question", "context"]) + assert prompt.partial_variables == {"formatins": "some structure"}