diff --git a/langchain/prompts/base.py b/langchain/prompts/base.py index 5221ff36..60cc7804 100644 --- a/langchain/prompts/base.py +++ b/langchain/prompts/base.py @@ -69,14 +69,21 @@ class RegexParser(BaseOutputParser, BaseModel): regex: str output_keys: List[str] + default_output_key: Optional[str] = None def parse(self, text: str) -> Dict[str, str]: """Parse the output of an LLM call.""" match = re.search(self.regex, text) if match: - return {key: match.group(i) for i, key in enumerate(self.output_keys)} + return {key: match.group(i + 1) for i, key in enumerate(self.output_keys)} else: - raise ValueError(f"Could not parse output: {text}") + if self.default_output_key is None: + raise ValueError(f"Could not parse output: {text}") + else: + return { + key: text if key == self.default_output_key else "" + for key in self.output_keys + } class BasePromptTemplate(BaseModel, ABC):