|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
"""BasePrompt schema definition."""
|
|
|
|
|
import json
|
|
|
|
|
import re
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Union
|
|
|
|
@ -55,7 +56,7 @@ class BaseOutputParser(ABC):
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ListOutputParser(ABC):
|
|
|
|
|
class ListOutputParser(BaseOutputParser):
|
|
|
|
|
"""Class to parse the output of an LLM call to a list."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
@ -63,6 +64,21 @@ class ListOutputParser(ABC):
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RegexParser(BaseOutputParser, BaseModel):
|
|
|
|
|
"""Class to parse the output into a dictionary."""
|
|
|
|
|
|
|
|
|
|
regex: str
|
|
|
|
|
output_keys: List[str]
|
|
|
|
|
|
|
|
|
|
def parse(self, text: str) -> Dict[str, str]:
|
|
|
|
|
"""Parse the output of an LLM call."""
|
|
|
|
|
match = re.search(self.regex, text)
|
|
|
|
|
if match:
|
|
|
|
|
return {key: match.group(i) for i, key in enumerate(self.output_keys)}
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Could not parse output: {text}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BasePromptTemplate(BaseModel, ABC):
|
|
|
|
|
"""Base prompt should expose the format method, returning a prompt."""
|
|
|
|
|
|
|
|
|
|