Add kwargs to from_* in PrompTemplate (#2161)

This will let us use output parsers, etc, while using the `from_*`
helper functions
This commit is contained in:
Tim Asp 2023-03-29 22:13:27 -07:00 committed by GitHub
parent f83c36d8fd
commit 7d90691adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -82,6 +82,7 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
input_variables: List[str],
example_separator: str = "\n\n",
prefix: str = "",
**kwargs: Any,
) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt.
@ -102,11 +103,11 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
The final prompt generated.
"""
template = example_separator.join([prefix, *examples, suffix])
return cls(input_variables=input_variables, template=template)
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_file(
cls, template_file: Union[str, Path], input_variables: List[str]
cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
) -> PromptTemplate:
"""Load a prompt from a file.
@ -119,15 +120,17 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls(input_variables=input_variables, template=template)
return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod
def from_template(cls, template: str) -> PromptTemplate:
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
return cls(input_variables=list(sorted(input_variables)), template=template)
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
# For backwards compatibility.