|
|
|
@ -2,10 +2,10 @@
|
|
|
|
|
import json
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Dict, List, Union
|
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
|
|
|
|
|
|
import yaml
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
from pydantic import BaseModel, Extra, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.formatting import formatter
|
|
|
|
|
|
|
|
|
@ -32,11 +32,35 @@ def check_valid_template(
|
|
|
|
|
raise ValueError("Invalid prompt schema.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseOutputParser(ABC):
|
|
|
|
|
"""Class to parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def parse(self, text: str) -> Union[str, List[str], Dict[str, str]]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ListOutputParser(ABC):
|
|
|
|
|
"""Class to parse the output of an LLM call to a list."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def parse(self, text: str) -> List[str]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasePromptTemplate(BaseModel, ABC):
|
|
|
|
|
"""Base prompt should expose the format method, returning a prompt."""
|
|
|
|
|
|
|
|
|
|
input_variables: List[str]
|
|
|
|
|
"""A list of the names of the variables the prompt template expects."""
|
|
|
|
|
output_parser: Optional[BaseOutputParser] = None
|
|
|
|
|
"""How to parse the output of calling an LLM on this formatted prompt."""
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
|
|
|
|
|
|
extra = Extra.forbid
|
|
|
|
|
arbitrary_types_allowed = True
|
|
|
|
|
|
|
|
|
|
@root_validator()
|
|
|
|
|
def validate_variable_names(cls, values: Dict) -> Dict:
|
|
|
|
|