|
|
|
@ -48,13 +48,24 @@ def check_valid_template(
|
|
|
|
|
raise ValueError("Invalid prompt schema.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseOutputParser(ABC):
|
|
|
|
|
class BaseOutputParser(BaseModel, ABC):
|
|
|
|
|
"""Class to parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _type(self) -> str:
|
|
|
|
|
"""Return the type key."""
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def dict(self, **kwargs: Any) -> Dict:
|
|
|
|
|
"""Return dictionary representation of output parser."""
|
|
|
|
|
output_parser_dict = super().dict()
|
|
|
|
|
output_parser_dict["_type"] = self._type
|
|
|
|
|
return output_parser_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ListOutputParser(BaseOutputParser):
|
|
|
|
|
"""Class to parse the output of an LLM call to a list."""
|
|
|
|
@ -79,6 +90,11 @@ class RegexParser(BaseOutputParser, BaseModel):
|
|
|
|
|
output_keys: List[str]
|
|
|
|
|
default_output_key: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _type(self) -> str:
|
|
|
|
|
"""Return the type key."""
|
|
|
|
|
return "regex_parser"
|
|
|
|
|
|
|
|
|
|
def parse(self, text: str) -> Dict[str, str]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
match = re.search(self.regex, text)
|
|
|
|
@ -142,7 +158,7 @@ class BasePromptTemplate(BaseModel, ABC):
|
|
|
|
|
|
|
|
|
|
def dict(self, **kwargs: Any) -> Dict:
|
|
|
|
|
"""Return dictionary representation of prompt."""
|
|
|
|
|
prompt_dict = super().dict()
|
|
|
|
|
prompt_dict = super().dict(**kwargs)
|
|
|
|
|
prompt_dict["_type"] = self._prompt_type
|
|
|
|
|
return prompt_dict
|
|
|
|
|
|
|
|
|
|