allow partials in from_template (#4638)

This commit is contained in:
Harrison Chase 2023-05-13 21:47:20 -07:00 committed by GitHub
parent fbfa49f2c1
commit f2f2aced6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 0 deletions

View File

@ -135,6 +135,12 @@ class PromptTemplate(StringPromptTemplate):
v for _, v, _, _ in Formatter().parse(template) if v is not None v for _, v, _, _ in Formatter().parse(template) if v is not None
} }
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls( return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs input_variables=list(sorted(input_variables)), template=template, **kwargs
) )

View File

@ -56,6 +56,30 @@ def create_chat_prompt_template() -> ChatPromptTemplate:
) )
def test_create_chat_prompt_template_from_template() -> None:
"""Create a chat prompt template."""
prompt = ChatPromptTemplate.from_template("hi {foo} {bar}")
assert prompt.messages == [
HumanMessagePromptTemplate.from_template("hi {foo} {bar}")
]
def test_create_chat_prompt_template_from_template_partial() -> None:
"""Create a chat prompt template with partials."""
prompt = ChatPromptTemplate.from_template(
"hi {foo} {bar}", partial_variables={"foo": "jim"}
)
expected_prompt = PromptTemplate(
template="hi {foo} {bar}",
input_variables=["bar"],
partial_variables={"foo": "jim"},
)
assert len(prompt.messages) == 1
output_prompt = prompt.messages[0]
assert isinstance(output_prompt, HumanMessagePromptTemplate)
assert output_prompt.prompt == expected_prompt
def test_chat_prompt_template() -> None: def test_chat_prompt_template() -> None:
"""Test chat prompt template.""" """Test chat prompt template."""
prompt_template = create_chat_prompt_template() prompt_template = create_chat_prompt_template()