Add kwargs to from_* in PrompTemplate (#2161)

This will let us use output parsers, etc, while using the `from_*`
helper functions
This commit is contained in:
Tim Asp 2023-03-29 22:13:27 -07:00 committed by GitHub
parent f83c36d8fd
commit 7d90691adb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -82,6 +82,7 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
input_variables: List[str], input_variables: List[str],
example_separator: str = "\n\n", example_separator: str = "\n\n",
prefix: str = "", prefix: str = "",
**kwargs: Any,
) -> PromptTemplate: ) -> PromptTemplate:
"""Take examples in list format with prefix and suffix to create a prompt. """Take examples in list format with prefix and suffix to create a prompt.
@ -102,11 +103,11 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
The final prompt generated. The final prompt generated.
""" """
template = example_separator.join([prefix, *examples, suffix]) template = example_separator.join([prefix, *examples, suffix])
return cls(input_variables=input_variables, template=template) return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod @classmethod
def from_file( def from_file(
cls, template_file: Union[str, Path], input_variables: List[str] cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any
) -> PromptTemplate: ) -> PromptTemplate:
"""Load a prompt from a file. """Load a prompt from a file.
@ -119,15 +120,17 @@ class PromptTemplate(StringPromptTemplate, BaseModel):
""" """
with open(str(template_file), "r") as f: with open(str(template_file), "r") as f:
template = f.read() template = f.read()
return cls(input_variables=input_variables, template=template) return cls(input_variables=input_variables, template=template, **kwargs)
@classmethod @classmethod
def from_template(cls, template: str) -> PromptTemplate: def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template.""" """Load a prompt template from a template."""
input_variables = { input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None v for _, v, _, _ in Formatter().parse(template) if v is not None
} }
return cls(input_variables=list(sorted(input_variables)), template=template) return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
# For backwards compatibility. # For backwards compatibility.