forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Dict, Optional
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.schema import BaseOutputParser
|
|
|
|
|
|
class RegexDictParser(BaseOutputParser, BaseModel):
    """Parse LLM output text into a dictionary, one regex search per key.

    For every ``(output_key, expected_format)`` pair in
    ``output_key_to_format`` the escaped ``expected_format`` is substituted
    into ``regex_pattern`` and the resulting expression is searched for in
    the text. The single captured value is stored under ``output_key``.
    """

    # Template for the per-key expression; ``{}`` is filled with the
    # regex-escaped expected format, and the one capture group is the value.
    # NOTE(review): the character class also lists the quote character, so a
    # captured value stops at the first apostrophe -- confirm this is intended.
    regex_pattern: str = r"{}:\s?([^.'\n']*)\.?"  # : :meta private:
    # Maps each key of the returned dict to the label searched for in the text.
    output_key_to_format: Dict[str, str]
    # When set, a key whose captured value equals this string is left out of
    # the returned dict.
    no_update_value: Optional[str] = None

    @property
    def _type(self) -> str:
        """Return the type key."""
        return "regex_dict_parser"

    def parse(self, text: str) -> Dict[str, str]:
        """Parse the output of an LLM call.

        Args:
            text: The raw text to extract values from.

        Returns:
            A dict mapping each output key to its captured value. Keys whose
            captured value equals ``no_update_value`` are omitted.

        Raises:
            ValueError: If the expression for a key matches nothing, or
                matches more than once, in ``text``.
        """
        result: Dict[str, str] = {}
        for output_key, expected_format in self.output_key_to_format.items():
            # Escape the label so regex metacharacters in it match literally.
            specific_regex = self.regex_pattern.format(re.escape(expected_format))
            matches = re.findall(specific_regex, text)

            # Adjacent f-string literals are used instead of a backslash
            # continuation inside one literal, so source indentation never
            # leaks into the error message.
            if not matches:
                raise ValueError(
                    f"No match found for output key: {output_key} with expected "
                    f"format {expected_format} on text {text}"
                )
            if len(matches) > 1:
                raise ValueError(
                    f"Multiple matches found for output key: {output_key} with "
                    f"expected format {expected_format} on text {text}"
                )

            value = matches[0]
            if self.no_update_value is not None and value == self.no_update_value:
                # Sentinel value: deliberately leave this key out of the result.
                continue
            result[output_key] = value
        return result
|