diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index c61cf69f81..31f87d43b9 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -135,6 +135,12 @@ class PromptTemplate(StringPromptTemplate): v for _, v, _, _ in Formatter().parse(template) if v is not None } + if "partial_variables" in kwargs: + partial_variables = kwargs["partial_variables"] + input_variables = { + var for var in input_variables if var not in partial_variables + } + return cls( input_variables=list(sorted(input_variables)), template=template, **kwargs ) diff --git a/tests/unit_tests/prompts/test_chat.py b/tests/unit_tests/prompts/test_chat.py index 6defde6991..87f64c9599 100644 --- a/tests/unit_tests/prompts/test_chat.py +++ b/tests/unit_tests/prompts/test_chat.py @@ -56,6 +56,30 @@ def create_chat_prompt_template() -> ChatPromptTemplate: ) +def test_create_chat_prompt_template_from_template() -> None: + """Create a chat prompt template.""" + prompt = ChatPromptTemplate.from_template("hi {foo} {bar}") + assert prompt.messages == [ + HumanMessagePromptTemplate.from_template("hi {foo} {bar}") + ] + + +def test_create_chat_prompt_template_from_template_partial() -> None: + """Create a chat prompt template with partials.""" + prompt = ChatPromptTemplate.from_template( + "hi {foo} {bar}", partial_variables={"foo": "jim"} + ) + expected_prompt = PromptTemplate( + template="hi {foo} {bar}", + input_variables=["bar"], + partial_variables={"foo": "jim"}, + ) + assert len(prompt.messages) == 1 + output_prompt = prompt.messages[0] + assert isinstance(output_prompt, HumanMessagePromptTemplate) + assert output_prompt.prompt == expected_prompt + + def test_chat_prompt_template() -> None: """Test chat prompt template.""" prompt_template = create_chat_prompt_template()