mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
allow partials in from_template (#4638)
This commit is contained in:
parent
fbfa49f2c1
commit
f2f2aced6d
@ -135,6 +135,12 @@ class PromptTemplate(StringPromptTemplate):
|
||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||
}
|
||||
|
||||
if "partial_variables" in kwargs:
|
||||
partial_variables = kwargs["partial_variables"]
|
||||
input_variables = {
|
||||
var for var in input_variables if var not in partial_variables
|
||||
}
|
||||
|
||||
return cls(
|
||||
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||
)
|
||||
|
@ -56,6 +56,30 @@ def create_chat_prompt_template() -> ChatPromptTemplate:
|
||||
)
|
||||
|
||||
|
||||
def test_create_chat_prompt_template_from_template() -> None:
|
||||
"""Create a chat prompt template."""
|
||||
prompt = ChatPromptTemplate.from_template("hi {foo} {bar}")
|
||||
assert prompt.messages == [
|
||||
HumanMessagePromptTemplate.from_template("hi {foo} {bar}")
|
||||
]
|
||||
|
||||
|
||||
def test_create_chat_prompt_template_from_template_partial() -> None:
|
||||
"""Create a chat prompt template with partials."""
|
||||
prompt = ChatPromptTemplate.from_template(
|
||||
"hi {foo} {bar}", partial_variables={"foo": "jim"}
|
||||
)
|
||||
expected_prompt = PromptTemplate(
|
||||
template="hi {foo} {bar}",
|
||||
input_variables=["bar"],
|
||||
partial_variables={"foo": "jim"},
|
||||
)
|
||||
assert len(prompt.messages) == 1
|
||||
output_prompt = prompt.messages[0]
|
||||
assert isinstance(output_prompt, HumanMessagePromptTemplate)
|
||||
assert output_prompt.prompt == expected_prompt
|
||||
|
||||
|
||||
def test_chat_prompt_template() -> None:
|
||||
"""Test chat prompt template."""
|
||||
prompt_template = create_chat_prompt_template()
|
||||
|
Loading…
Reference in New Issue
Block a user