forked from Archives/langchain
feat: add MultiStrategy output parser
- A type of parser where many strategies can be tried before exception - A strategy is a tuple like class of (parser, predicate, name=None) - Strategies are tried if predicate is True - Strategies are tried in order, allows for fallbacks - Base interface allows existing parsers to use multiple strategies - New strategies can be added for new output errors and covered by tests
This commit is contained in:
parent
a9108c1809
commit
55e0e2d6ac
@ -2,14 +2,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import List
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain.schema import OutputParserException
|
from langchain.schema import OutputParserException
|
||||||
|
|
||||||
|
REGEXES = {
|
||||||
|
"json_markdown": r"```(json)?(.*?)```",
|
||||||
|
# must use greedy matching to match the outermost code block
|
||||||
|
"nested_json_md_code_block": r"```(json)?(.*)```",
|
||||||
|
}
|
||||||
|
|
||||||
def parse_json_markdown(json_string: str) -> dict:
|
|
||||||
|
def parse_json_markdown(json_string: str, regex: Optional[str] = None) -> dict:
|
||||||
# Try to find JSON string within triple backticks
|
# Try to find JSON string within triple backticks
|
||||||
match = re.search(r"```(json)?(.*?)```", json_string, re.DOTALL)
|
if regex is None:
|
||||||
|
regex = REGEXES["json_markdown"]
|
||||||
|
match = re.search(regex, json_string, re.DOTALL)
|
||||||
|
|
||||||
# If no match found, assume the entire string is a JSON string
|
# If no match found, assume the entire string is a JSON string
|
||||||
if match is None:
|
if match is None:
|
||||||
@ -27,6 +35,56 @@ def parse_json_markdown(json_string: str) -> dict:
|
|||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def fix_code_in_json(text: str) -> str:
|
||||||
|
"""Fixes nested code block in json markdown"""
|
||||||
|
# Extract the code block and replace it with a placeholder
|
||||||
|
pattern = r"```([^`]*?)```"
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if match:
|
||||||
|
code_block = match.group(1)
|
||||||
|
text = re.sub(pattern, "CODE_BLOCK_PLACEHOLDER", text, count=1)
|
||||||
|
|
||||||
|
# Escape the special characters in the code block
|
||||||
|
escaped_code_block = (
|
||||||
|
code_block.replace("\n", "\\n").replace("\t", "\\t").replace('"', '\\"')
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add backtick pairs to escaped code block
|
||||||
|
escaped_code_block = "[BEGIN_CODE]" + escaped_code_block + "[END_CODE]"
|
||||||
|
|
||||||
|
# Replace the placeholder in the original text with the escaped code block
|
||||||
|
text = text.replace("CODE_BLOCK_PLACEHOLDER", escaped_code_block)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def fix_json_with_embedded_code_block(text: str, max_loop: int = 20) -> dict:
|
||||||
|
"""Try to fix json with embedded code block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: JSON string with embedded code block
|
||||||
|
max_loop: Maximum number of loops to try fixing the JSON string
|
||||||
|
"""
|
||||||
|
loop = 0
|
||||||
|
while True:
|
||||||
|
if loop > max_loop:
|
||||||
|
raise ValueError("Max loop reached")
|
||||||
|
try:
|
||||||
|
text = fix_code_in_json(text)
|
||||||
|
json.loads(text)
|
||||||
|
break
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
if text[e.pos] == "\n":
|
||||||
|
text = text[: e.pos] + "\\n" + text[e.pos + 1 :]
|
||||||
|
text = text.replace("[BEGIN_CODE]", "```")
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
loop += 1
|
||||||
|
final_text = text.replace("[END_CODE]", "```")
|
||||||
|
return json.loads(final_text)
|
||||||
|
|
||||||
|
|
||||||
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
|
||||||
try:
|
try:
|
||||||
json_obj = parse_json_markdown(text)
|
json_obj = parse_json_markdown(text)
|
||||||
|
0
langchain/output_parsers/multi_strategy/__init__.py
Normal file
0
langchain/output_parsers/multi_strategy/__init__.py
Normal file
38
langchain/output_parsers/multi_strategy/agent.py
Normal file
38
langchain/output_parsers/multi_strategy/agent.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Multi strategy parser that implements AgentOutputParser."""
|
||||||
|
from typing import Any, Sequence, Union
|
||||||
|
|
||||||
|
from langchain.agents.agent import AgentOutputParser
|
||||||
|
from langchain.agents.conversational_chat.prompt import FORMAT_INSTRUCTIONS
|
||||||
|
from langchain.output_parsers.multi_strategy import strategies
|
||||||
|
from langchain.output_parsers.multi_strategy.base import (
|
||||||
|
MultiStrategyParser,
|
||||||
|
ParseStrategy,
|
||||||
|
)
|
||||||
|
from langchain.schema import (
|
||||||
|
AgentAction,
|
||||||
|
AgentFinish,
|
||||||
|
)
|
||||||
|
|
||||||
|
U = Union[AgentAction, AgentFinish]
|
||||||
|
TReactAgentOutput = U
|
||||||
|
|
||||||
|
|
||||||
|
class ConvMultiStrategyParser(MultiStrategyParser[U, dict], AgentOutputParser):
|
||||||
|
"""Multi strategy parser that implements AgentOutputParser."""
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
return FORMAT_INSTRUCTIONS
|
||||||
|
|
||||||
|
def __init__(self, strategies: Sequence[ParseStrategy[dict]],
|
||||||
|
**kwargs: dict) -> None:
|
||||||
|
super().__init__(strategies=strategies, **kwargs)
|
||||||
|
|
||||||
|
def final_parse(self, text: str, parsed: dict) -> U:
|
||||||
|
action, action_input = parsed["action"], parsed["action_input"]
|
||||||
|
if action == "Final Answer":
|
||||||
|
return AgentFinish({"output": action_input}, text)
|
||||||
|
else:
|
||||||
|
return AgentAction(action, action_input, text)
|
||||||
|
|
||||||
|
|
||||||
|
default_parser = ConvMultiStrategyParser(strategies.json_react_strategies)
|
117
langchain/output_parsers/multi_strategy/base.py
Normal file
117
langchain/output_parsers/multi_strategy/base.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
"""Multi strategy output parser."""
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Callable, Generic, Iterator, Sequence, TypeVar, Union, Optional
|
||||||
|
|
||||||
|
from langchain.schema import (
|
||||||
|
BaseOutputParser,
|
||||||
|
OutputParserException,
|
||||||
|
)
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
S = TypeVar("S")
|
||||||
|
|
||||||
|
TPredicate = Callable[[str], bool]
|
||||||
|
TParser = Callable[[str], S]
|
||||||
|
|
||||||
|
|
||||||
|
class ParseStrategy(Generic[S]):
|
||||||
|
"""A strategy is a pair of (parser, predicate).
|
||||||
|
|
||||||
|
This class behave like a tuple for easy definition of multiple strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, parser: TParser[S], predicate: TPredicate, name: Optional[str] = None
|
||||||
|
):
|
||||||
|
assert callable(parser), "first argument <parser> must be callable"
|
||||||
|
self.parser = parser
|
||||||
|
assert callable(predicate), "second argument <predicate> must be callable"
|
||||||
|
self.predicate = predicate
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.name is None:
|
||||||
|
return f"ParseStrategy(parser={self.parser}," "predicate={self.predicate})"
|
||||||
|
return (
|
||||||
|
f"ParseStrategy[{self.name}](parser={self.parser},"
|
||||||
|
"predicate={self.predicate})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, index: int) -> Union[TParser[S], TPredicate]:
|
||||||
|
"""Behaves like a tuple."""
|
||||||
|
if index == 0:
|
||||||
|
return self.parser
|
||||||
|
elif index == 1:
|
||||||
|
return self.predicate
|
||||||
|
else:
|
||||||
|
raise IndexError("tuple index out of range")
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[Any]:
|
||||||
|
"""Implement tuple unpacking."""
|
||||||
|
yield self.parser
|
||||||
|
yield self.predicate
|
||||||
|
|
||||||
|
|
||||||
|
class MultiStrategyParser(BaseOutputParser[T], ABC, Generic[T, S]):
|
||||||
|
"""Try multiple strategies to parse the output.
|
||||||
|
|
||||||
|
A strategy is a tuple of (parser, predicate). The parser takes the some
|
||||||
|
text as input and returns some type S. The parser is only called if the
|
||||||
|
predicate returns True.
|
||||||
|
|
||||||
|
When the `parse` method is called, all registered strategies are tried
|
||||||
|
in order and the first one that succeeds returns its result.
|
||||||
|
|
||||||
|
The returned value of type `S` is then passed to the final_parse method to
|
||||||
|
produce the final result compatible with the inhertited output parser
|
||||||
|
interface.
|
||||||
|
|
||||||
|
Appending a strategy to the end makes it a fallback strategy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
strategies: Sequence[ParseStrategy[S]]
|
||||||
|
"""List of strategies to try. The first one that succeeds is returned."""
|
||||||
|
|
||||||
|
def add_strategy(self, *strategy: ParseStrategy[S]) -> None:
|
||||||
|
"""Register a new strategy.
|
||||||
|
|
||||||
|
A strategy is a callbale that takes in text as `str` and returns
|
||||||
|
some type `S`.
|
||||||
|
"""
|
||||||
|
self.strategies = [*self.strategies, *strategy]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def final_parse(self, text: str, parsed: S) -> T:
|
||||||
|
"""Parse the output of a strategy."""
|
||||||
|
|
||||||
|
def parse(self, text: str) -> T:
|
||||||
|
"""Try the registered strategies in order.
|
||||||
|
|
||||||
|
Returns the output of the first succeeding strategy."""
|
||||||
|
|
||||||
|
if len(self.strategies) == 0:
|
||||||
|
raise OutputParserException("No strategy available")
|
||||||
|
for strategy, predicate in self.strategies:
|
||||||
|
log.debug(f"trying strategy {strategy}")
|
||||||
|
if not predicate(text):
|
||||||
|
log.debug(f"Skipping strategy {strategy}")
|
||||||
|
if predicate(text):
|
||||||
|
try:
|
||||||
|
parsed = strategy(text)
|
||||||
|
result = self.final_parse(text, parsed)
|
||||||
|
log.debug(f"Strategy {strategy} succeeded")
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise OutputParserException(f"Could not parse output: {text}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
return "multi_strategy"
|
49
langchain/output_parsers/multi_strategy/strategies.py
Normal file
49
langchain/output_parsers/multi_strategy/strategies.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
"""Strategies used with MultiStrategyParser parsers."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
from langchain.output_parsers.json import (
|
||||||
|
REGEXES,
|
||||||
|
fix_json_with_embedded_code_block,
|
||||||
|
parse_json_markdown,
|
||||||
|
)
|
||||||
|
from langchain.output_parsers.multi_strategy.base import ParseStrategy
|
||||||
|
|
||||||
|
|
||||||
|
def is_bare_json(text: str) -> dict:
|
||||||
|
"""Tries to load as bare json"""
|
||||||
|
return json.loads(text.strip())
|
||||||
|
|
||||||
|
|
||||||
|
def json_markdown(text: str) -> dict:
|
||||||
|
"""Extract a json object from markdown markup"""
|
||||||
|
return parse_json_markdown(text)
|
||||||
|
|
||||||
|
|
||||||
|
def json_nested_md_code_block(text: str) -> dict:
|
||||||
|
"""Extract the outermost code block. Can accomodate nested code blocks."""
|
||||||
|
return parse_json_markdown(text, regex=REGEXES["nested_json_md_code_block"])
|
||||||
|
|
||||||
|
|
||||||
|
def fallback(text: str) -> dict:
|
||||||
|
"""Example fallback strategy."""
|
||||||
|
return {"action": "Final Answer", "action_input": text}
|
||||||
|
|
||||||
|
|
||||||
|
# The order of the strategies is important
|
||||||
|
# They are tried in order and the first one that matches is used
|
||||||
|
json_react_strategies = (
|
||||||
|
ParseStrategy(is_bare_json, lambda text: text.startswith("{"), name="bare_json"),
|
||||||
|
ParseStrategy(json_markdown, lambda text: text.find("```") != -1),
|
||||||
|
ParseStrategy(
|
||||||
|
json_nested_md_code_block,
|
||||||
|
lambda text: text.find("```") != -1,
|
||||||
|
name="nested_code_block",
|
||||||
|
),
|
||||||
|
ParseStrategy(
|
||||||
|
fix_json_with_embedded_code_block,
|
||||||
|
lambda text: text.find("```") != -1,
|
||||||
|
name="fix_embedded_code_block",
|
||||||
|
),
|
||||||
|
# this is where a fallback would go
|
||||||
|
# ParseStrategy(fallback, lambda _: True),
|
||||||
|
)
|
4
tests/unit_tests/data/llm_outputs/bare_json
Normal file
4
tests/unit_tests/data/llm_outputs/bare_json
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "To implement a Singleton class in Python, you can define a class with a private constructor, a class variable to store the instance and a static method to get the instance. Here's an example:\n\n```python\nclass Singleton:\n __instance = None\n\n def __init__(self):\n if Singleton.__instance != None:\n raise Exception('You cannot create more than one instance of Singleton class.')\n else:\n Singleton.__instance = self\n\n @staticmethod \n def getInstance():\n if Singleton.__instance == None:\n Singleton()\n return Singleton.__instance\n```"
|
||||||
|
}
|
25
tests/unit_tests/data/llm_outputs/bare_json_embed_code_block
Normal file
25
tests/unit_tests/data/llm_outputs/bare_json_embed_code_block
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "Sure, here is a simple pseudo code representation of the proof of work algorithm:
|
||||||
|
|
||||||
|
```
|
||||||
|
function proofOfWork(block, difficulty):
|
||||||
|
target = "0" * difficulty
|
||||||
|
nonce = 0
|
||||||
|
while True:
|
||||||
|
hash = calculateHash(block, nonce)
|
||||||
|
if hash.startswith(target):
|
||||||
|
return nonce
|
||||||
|
nonce += 1
|
||||||
|
|
||||||
|
block = getBlockData()
|
||||||
|
difficulty = getDifficulty()
|
||||||
|
nonce = proofOfWork(block, difficulty)
|
||||||
|
```
|
||||||
|
|
||||||
|
In this pseudo code, the `proofOfWork` function takes a `block` and a `difficulty` as input. It initializes a `target` string with the desired number of leading zeros based on the difficulty. The function then starts a loop and calculates the hash of the `block` with an incremented `nonce` value. If the hash starts with the required number of zeros, the function returns the `nonce`. Otherwise, it increments the `nonce` and continues the loop until a valid solution is found.
|
||||||
|
|
||||||
|
To use the proof of work algorithm, you would need to provide the `block` data and the desired `difficulty` level. The algorithm will return the `nonce` value that satisfies the proof of work requirements.
|
||||||
|
|
||||||
|
Please note that this is a simplified representation of the algorithm and actual implementations may have additional complexities and optimizations."
|
||||||
|
}
|
@ -0,0 +1,16 @@
|
|||||||
|
Here is an example implementation of a singleton class in Python:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class Singleton:
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
print("Creating new instance")
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
else:
|
||||||
|
print("Using existing instance")
|
||||||
|
return cls._instance
|
||||||
|
```
|
||||||
|
|
||||||
|
In this implementation, the `_instance` variable keeps track of whether an instance of the class has already been created. The `__new__` method is called when an instance of the class is requested. If an instance has already been created, it returns that instance. Otherwise, it creates a new instance and returns that.
|
10
tests/unit_tests/data/llm_outputs/json_nested_code_block
Normal file
10
tests/unit_tests/data/llm_outputs/json_nested_code_block
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
I apologize for the previous incomplete response. Here's the response in the required format:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"action": "Final Answer",
|
||||||
|
"action_input": "To implement a singleton class in Python, you can use a decorator or a metaclass. Here's an example of using a decorator:\n\n```python\nfrom functools import wraps\n\ndef singleton(cls):\n instances = {}\n\n @wraps(cls)\n def get_instance(*args, **kwargs):\n if cls not in instances:\n instances[cls] = cls(*args, **kwargs)\n return instances[cls]\n\n return get_instance\n\n@singleton\nclass MyClass:\n pass\n```"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
I hope this helps! Let me know if you have any other questions.
|
56
tests/unit_tests/output_parsers/test_multi_strategy.py
Normal file
56
tests/unit_tests/output_parsers/test_multi_strategy.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from typing import List, Tuple, Any
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.output_parsers.multi_strategy.base import MultiStrategyParser
|
||||||
|
from langchain.output_parsers.multi_strategy.agent import ConvMultiStrategyParser
|
||||||
|
from langchain.output_parsers.multi_strategy import strategies
|
||||||
|
|
||||||
|
# How the test works:
|
||||||
|
# it loads all llm output files from the ../data/llm_outputs directory
|
||||||
|
# For each file it tries a MultiStrategyParser with the strategies to test.
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_outputs() -> List[Tuple[str, str]]:
|
||||||
|
outputs = []
|
||||||
|
for path in (Path(__file__).parent.parent / "data/llm_outputs/").glob("*"):
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
outputs.append((f.read(), path.name))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
llm_outputs = prepare_outputs()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("output, name", llm_outputs, ids=[x[1] for x in llm_outputs])
|
||||||
|
def test_json_react_strategies(
|
||||||
|
output: str, name: str, parser: MultiStrategyParser[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
# the ignored test is for the fallback strategy
|
||||||
|
if name != "ignored_format_instructions":
|
||||||
|
_test_json_react_strategy(output, name, parser)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_json_react_strategy(
|
||||||
|
output: str, name: str, parser: MultiStrategyParser[Any, Any]
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
parser.parse(output)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error parsing output entry: {name}.")
|
||||||
|
|
||||||
|
|
||||||
|
def test_fix_json_with_embedded_code_block() -> None:
|
||||||
|
path = Path(__file__).parent.parent / "data/llm_outputs/bare_json_embed_code_block"
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
output = f.read()
|
||||||
|
res = strategies.fix_json_with_embedded_code_block(output)
|
||||||
|
assert type(res) == dict
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
res = strategies.fix_json_with_embedded_code_block(output, max_loop=1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(name="parser")
|
||||||
|
def conv_multi_strategy_parser() -> Any:
|
||||||
|
return ConvMultiStrategyParser(strategies.json_react_strategies)
|
Loading…
Reference in New Issue
Block a user