forked from Archives/langchain
Nc/combining output parser (#3014)
Co-authored-by: vowelparrot <130414180+vowelparrot@users.noreply.github.com>fix_agent_callbacks
parent
79bb5c4f95
commit
dac32c59e5
@ -0,0 +1,51 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from pydantic import root_validator
|
||||||
|
|
||||||
|
from langchain.schema import BaseOutputParser
|
||||||
|
|
||||||
|
|
||||||
|
class CombiningOutputParser(BaseOutputParser):
    """Combine multiple output parsers into one.

    The format instructions ask the LLM to produce one output per wrapped
    parser, separated by two newline characters. ``parse`` splits the text on
    that separator, hands each section to the parser at the same position and
    merges the resulting dictionaries into a single one.
    """

    # Parsers applied, in order, to consecutive sections of the LLM output.
    parsers: List[BaseOutputParser]

    @root_validator()
    def validate_parsers(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the parsers.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The unchanged ``values`` mapping.

        Raises:
            ValueError: If fewer than two parsers are given, or if any parser
                reports the type key ``"combining"`` or ``"list"``.
        """
        parsers = values.get("parsers")
        if parsers is None:
            # The ``parsers`` field is absent from ``values`` (e.g. it failed
            # its own validation); there is nothing to check here, and
            # indexing the key would surface as an unrelated KeyError.
            return values
        if len(parsers) < 2:
            raise ValueError("Must have at least two parsers")
        for parser in parsers:
            if parser._type == "combining":
                raise ValueError("Cannot nest combining parsers")
            if parser._type == "list":
                raise ValueError("Cannot combine list parsers")
        return values

    @property
    def _type(self) -> str:
        """Return the type key."""
        return "combining"

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted.

        Returns:
            The first parser's instructions, followed by one line per
            remaining parser asking for another output separated by two
            newline characters.
        """
        initial = f"For your first output: {self.parsers[0].get_format_instructions()}"
        subsequent = "\n".join(
            f"Complete that output fully. Then produce another output, separated by two newline characters: {p.get_format_instructions()}"  # noqa: E501
            for p in self.parsers[1:]
        )
        return f"{initial}\n{subsequent}"

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse the output of an LLM call.

        Args:
            text: Raw LLM output containing one section per parser, with
                sections separated by two newline characters.

        Returns:
            The merged dictionaries returned by each parser. On key
            collisions the value from the later parser wins.

        Raises:
            ValueError: If ``text`` contains fewer sections than there are
                parsers.
        """
        sections = text.split("\n\n")
        if len(sections) < len(self.parsers):
            raise ValueError(
                f"Expected at least {len(self.parsers)} outputs separated by "
                f"two newline characters, but got {len(sections)}"
            )
        output: Dict[str, Any] = {}
        # zip stops at the last parser, so any extra trailing sections are
        # ignored, exactly as with positional indexing.
        for section, parser in zip(sections, self.parsers):
            output.update(parser.parse(section.strip()))
        return output
|
@ -0,0 +1,45 @@
|
|||||||
|
"""Test the combining output parser."""
|
||||||
|
from langchain.output_parsers.combining import CombiningOutputParser
|
||||||
|
from langchain.output_parsers.regex import RegexParser
|
||||||
|
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
|
||||||
|
|
||||||
|
# Merged dictionary expected from parsing DEF_README: "answer"/"source" come
# from the JSON section, "confidence"/"explanation" from the second section.
DEF_EXPECTED_RESULT = {
    "answer": "Paris",
    "source": "https://en.wikipedia.org/wiki/France",
    "confidence": "A",
    "explanation": "Paris is the capital of France according to Wikipedia.",
}

# Sample LLM output: a fenced JSON section and a "Confidence/Explanation"
# line, separated by two newline characters.
DEF_README = """```json
{
"answer": "Paris",
"source": "https://en.wikipedia.org/wiki/France"
}
```

//Confidence: A, Explanation: Paris is the capital of France according to Wikipedia."""
|
||||||
|
|
||||||
|
|
||||||
|
def test_combining_dict_result() -> None:
    """Parsing DEF_README with two combined parsers yields the merged dict."""
    # First section: structured (JSON) answer with its source.
    answer_schema = ResponseSchema(
        name="answer", description="answer to the user's question"
    )
    source_schema = ResponseSchema(
        name="source",
        description="source used to answer the user's question",
    )
    structured_parser = StructuredOutputParser(
        response_schemas=[answer_schema, source_schema]
    )

    # Second section: confidence grade and explanation captured by regex.
    confidence_parser = RegexParser(
        regex=r"Confidence: (A|B|C), Explanation: (.*)",
        output_keys=["confidence", "explanation"],
        default_output_key="noConfidence",
    )

    combined = CombiningOutputParser(
        parsers=[structured_parser, confidence_parser]
    )
    parsed = combined.parse(DEF_README)
    assert DEF_EXPECTED_RESULT == parsed
|
Loading…
Reference in New Issue