From 7d90691adbf8adc18e386bf11890ee0f4cff2097 Mon Sep 17 00:00:00 2001 From: Tim Asp <707699+timothyasp@users.noreply.github.com> Date: Wed, 29 Mar 2023 22:13:27 -0700 Subject: [PATCH] Add kwargs to from_* in PrompTemplate (#2161) This will let us use output parsers, etc, while using the `from_*` helper functions --- langchain/prompts/prompt.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/langchain/prompts/prompt.py b/langchain/prompts/prompt.py index f2cf0aed..7210c52f 100644 --- a/langchain/prompts/prompt.py +++ b/langchain/prompts/prompt.py @@ -82,6 +82,7 @@ class PromptTemplate(StringPromptTemplate, BaseModel): input_variables: List[str], example_separator: str = "\n\n", prefix: str = "", + **kwargs: Any, ) -> PromptTemplate: """Take examples in list format with prefix and suffix to create a prompt. @@ -102,11 +103,11 @@ class PromptTemplate(StringPromptTemplate, BaseModel): The final prompt generated. """ template = example_separator.join([prefix, *examples, suffix]) - return cls(input_variables=input_variables, template=template) + return cls(input_variables=input_variables, template=template, **kwargs) @classmethod def from_file( - cls, template_file: Union[str, Path], input_variables: List[str] + cls, template_file: Union[str, Path], input_variables: List[str], **kwargs: Any ) -> PromptTemplate: """Load a prompt from a file. @@ -119,15 +120,17 @@ class PromptTemplate(StringPromptTemplate, BaseModel): """ with open(str(template_file), "r") as f: template = f.read() - return cls(input_variables=input_variables, template=template) + return cls(input_variables=input_variables, template=template, **kwargs) @classmethod - def from_template(cls, template: str) -> PromptTemplate: + def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: """Load a prompt template from a template.""" input_variables = { v for _, v, _, _ in Formatter().parse(template) if v is not None } - return cls(input_variables=list(sorted(input_variables)), template=template) + return cls( + input_variables=list(sorted(input_variables)), template=template, **kwargs + ) # For backwards compatibility.