|
|
|
@ -1,8 +1,8 @@
|
|
|
|
|
"""BasePrompt schema definition."""
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
from typing import Any, Dict, List, Union
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
from pydantic import BaseModel, Field, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.formatting import formatter
|
|
|
|
|
|
|
|
|
@ -29,11 +29,29 @@ def check_valid_template(
|
|
|
|
|
raise ValueError("Invalid prompt schema.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OutputParser(ABC):
|
|
|
|
|
"""Class to parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DefaultParser(OutputParser):
|
|
|
|
|
"""Just return the text."""
|
|
|
|
|
|
|
|
|
|
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasePromptTemplate(BaseModel, ABC):
|
|
|
|
|
"""Base prompt should expose the format method, returning a prompt."""
|
|
|
|
|
|
|
|
|
|
input_variables: List[str]
|
|
|
|
|
"""A list of the names of the variables the prompt template expects."""
|
|
|
|
|
output_parser: OutputParser = Field(default_factory=DefaultParser)
|
|
|
|
|
"""How to parse the output of calling an LLM on this formatted prompt."""
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
|
|
|
|