langchain[patch]: fix `OutputType` of OutputParsers and fix legacy API in OutputParsers (#19792)

# Description

This pull request addresses the ambiguity and error-proneness of the
output types of certain output parsers, as well as the absence of unit
tests for some of them. These issues could lead to runtime errors or
unexpected behavior due to type mismatches, causing confusion for
developers and users. By clarifying the output types, this PR improves
stability and reliability.

Therefore, this pull request

- fixes the `OutputType` of OutputParsers to be the expected type;
  - e.g. the `OutputType` property of `EnumOutputParser` raises a `TypeError`.
This PR introduces logic to derive `OutputType` from the parser's `enum` attribute.
- replaces the legacy API in OutputParsers, such as `LLMChain.run`, with the
modern API, such as `LLMChain.invoke`;
  - Note: for `OutputFixingParser`, `RetryOutputParser` and
`RetryWithErrorOutputParser`, this PR introduces a `legacy` attribute
(defaulting to `True`) in order to keep backward compatibility;
- adds tests for `OutputFixingParser` and `RetryOutputParser`.
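The legacy/modern dispatch described above can be sketched standalone. This is a minimal sketch with stub chains (`LegacyChain` and `ModernChain` are hypothetical names, not langchain classes); it only illustrates the branching pattern, not the actual implementation:

```python
from typing import Any


class LegacyChain:
    """Stub mimicking the old LLMChain interface (exposes .run)."""

    def run(self, **kwargs: Any) -> str:
        return "ran-legacy"


class ModernChain:
    """Stub mimicking a Runnable (exposes .invoke)."""

    def invoke(self, inputs: dict) -> str:
        return "ran-modern"


def retry_once(retry_chain: Any, legacy: bool) -> str:
    # Mirrors the branch added in this PR: prefer .run() only when
    # legacy mode is on AND the chain actually exposes it; otherwise
    # fall back to the modern .invoke() call.
    if legacy and hasattr(retry_chain, "run"):
        return retry_chain.run(completion="...", error="...")
    return retry_chain.invoke(dict(input="...", error="..."))


print(retry_once(LegacyChain(), legacy=True))   # uses .run
print(retry_once(ModernChain(), legacy=True))   # no .run -> .invoke
print(retry_once(ModernChain(), legacy=False))  # legacy off -> .invoke
```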

The following table shows my expected `OutputType` and the actual
`OutputType` of each OutputParser.
I used this table to guide the fixes to the `OutputType` of OutputParsers.

| Class Name of OutputParser | My Expected `OutputType` (after this PR) | Actual `OutputType` [evidence](#evidence) (before this PR) | Fix Required |
|---------|--------------|---------|--------|
| BooleanOutputParser | `<class 'bool'>` | `<class 'bool'>` | NO |
| CombiningOutputParser | `typing.Dict[str, Any]` | `TypeError` is raised | YES |
| DatetimeOutputParser | `<class 'datetime.datetime'>` | `<class 'datetime.datetime'>` | NO |
| EnumOutputParser(enum=MyEnum) | `MyEnum` | `TypeError` is raised | YES |
| OutputFixingParser | The same type as `self.parser.OutputType` | `~T` | YES |
| CommaSeparatedListOutputParser | `typing.List[str]` | `typing.List[str]` | NO |
| MarkdownListOutputParser | `typing.List[str]` | `typing.List[str]` | NO |
| NumberedListOutputParser | `typing.List[str]` | `typing.List[str]` | NO |
| JsonOutputKeyToolsParser | `typing.Any` | `typing.Any` | NO |
| JsonOutputToolsParser | `typing.Any` | `typing.Any` | NO |
| PydanticToolsParser | `typing.Any` | `typing.Any` | NO |
| PandasDataFrameOutputParser | `typing.Dict[str, Any]` | `TypeError` is raised | YES |
| PydanticOutputParser(pydantic_object=MyModel) | `<class '__main__.MyModel'>` | `<class '__main__.MyModel'>` | NO |
| RegexParser | `typing.Dict[str, str]` | `TypeError` is raised | YES |
| RegexDictParser | `typing.Dict[str, str]` | `TypeError` is raised | YES |
| RetryOutputParser | The same type as `self.parser.OutputType` | `~T` | YES |
| RetryWithErrorOutputParser | The same type as `self.parser.OutputType` | `~T` | YES |
| StructuredOutputParser | `typing.Dict[str, Any]` | `TypeError` is raised | YES |
| YamlOutputParser(pydantic_object=MyModel) | `MyModel` | `~T` | YES |

NOTE: In the "Fix Required" column, "YES" means the parser is fixed in this
PR, while "NO" means no fix is needed.
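As a self-contained illustration of the fix pattern used for parsers like `EnumOutputParser` (a simplified stand-in, `MiniEnumParser` is a hypothetical name, not the real class), `OutputType` is derived from an instance attribute instead of the generic type parameter:

```python
from enum import Enum
from typing import Type


class Colors(Enum):
    RED = "red"
    BLUE = "blue"


class MiniEnumParser:
    """Simplified stand-in for EnumOutputParser."""

    def __init__(self, enum: Type[Enum]) -> None:
        self.enum = enum

    @property
    def OutputType(self) -> Type[Enum]:
        # Derive the output type from the configured enum instead of
        # trying to infer it from a generic type parameter (which
        # raised TypeError before this PR).
        return self.enum

    def parse(self, response: str) -> Enum:
        return self.enum(response.strip())


parser = MiniEnumParser(enum=Colors)
print(parser.OutputType)      # <enum 'Colors'>
print(parser.parse(" red "))  # Colors.RED
```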

# Issue

No issues for this PR.

# Twitter handle

- [hmdev3](https://twitter.com/hmdev3)

# Questions

1. Are tests required for the legacy API (`LLMChain.run`) in the following scripts?
   - libs/langchain/tests/unit_tests/output_parsers/test_fix.py
   - libs/langchain/tests/unit_tests/output_parsers/test_retry.py

2. Is there a more appropriate expected output type than the ones listed in the
above table?
   - e.g. should the `OutputType` of `CombiningOutputParser` be
SOMETHING...

# Actual outputs (before this PR)

<div id='evidence'></div>

<details><summary>Actual outputs</summary>

## Requirements

- Python==3.9.13
- langchain==0.1.13

```python
Python 3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import langchain
>>> langchain.__version__
'0.1.13'
>>> from langchain import output_parsers
```

### `BooleanOutputParser`

```python
>>> output_parsers.BooleanOutputParser().OutputType
<class 'bool'>
```

### `CombiningOutputParser`

```python
>>> output_parsers.CombiningOutputParser(parsers=[output_parsers.DatetimeOutputParser(), output_parsers.CommaSeparatedListOutputParser()]).OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable CombiningOutputParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `DatetimeOutputParser`

```python
>>> output_parsers.DatetimeOutputParser().OutputType
<class 'datetime.datetime'>
```

### `EnumOutputParser`

```python
>>> from enum import Enum
>>> class MyEnum(Enum):
...     a = 'a'
...     b = 'b'
...
>>> output_parsers.EnumOutputParser(enum=MyEnum).OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable EnumOutputParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `OutputFixingParser`

```python
>>> output_parsers.OutputFixingParser(parser=output_parsers.DatetimeOutputParser()).OutputType
~T
```

### `CommaSeparatedListOutputParser`

```python
>>> output_parsers.CommaSeparatedListOutputParser().OutputType
typing.List[str]
```

### `MarkdownListOutputParser`

```python
>>> output_parsers.MarkdownListOutputParser().OutputType
typing.List[str]
```

### `NumberedListOutputParser`

```python
>>> output_parsers.NumberedListOutputParser().OutputType
typing.List[str]
```

### `JsonOutputKeyToolsParser`

```python
>>> output_parsers.JsonOutputKeyToolsParser(key_name='tool').OutputType
typing.Any
```

### `JsonOutputToolsParser`

```python
>>> output_parsers.JsonOutputToolsParser().OutputType
typing.Any
```

### `PydanticToolsParser`

```python
>>> from langchain.pydantic_v1 import BaseModel
>>> class MyModel(BaseModel):
...     a: int
...
>>> output_parsers.PydanticToolsParser(tools=[MyModel, MyModel]).OutputType
typing.Any
```

### `PandasDataFrameOutputParser`

```python
>>> output_parsers.PandasDataFrameOutputParser().OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable PandasDataFrameOutputParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `PydanticOutputParser`

```python
>>> output_parsers.PydanticOutputParser(pydantic_object=MyModel).OutputType
<class '__main__.MyModel'>
```

### `RegexParser`

```python
>>> output_parsers.RegexParser(regex='$', output_keys=['a']).OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable RegexParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `RegexDictParser`

```python
>>> output_parsers.RegexDictParser(output_key_to_format={'a':'a'}).OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable RegexDictParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `RetryOutputParser`

```python
>>> output_parsers.RetryOutputParser(parser=output_parsers.DatetimeOutputParser()).OutputType
~T
```

### `RetryWithErrorOutputParser`

```python
>>> output_parsers.RetryWithErrorOutputParser(parser=output_parsers.DatetimeOutputParser()).OutputType
~T
```

### `StructuredOutputParser`

```python
>>> from langchain.output_parsers.structured import ResponseSchema
>>> response_schemas = [ResponseSchema(name="foo",description="a list of strings",type="List[string]"),ResponseSchema(name="bar",description="a string",type="string"), ]
>>> output_parsers.StructuredOutputParser.from_response_schemas(response_schemas).OutputType
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "D:\workspace\venv\lib\site-packages\langchain_core\output_parsers\base.py", line 160, in OutputType
    raise TypeError(
TypeError: Runnable StructuredOutputParser doesn't have an inferable OutputType. Override the OutputType property to specify the output type.
```

### `YamlOutputParser`

```python
>>> output_parsers.YamlOutputParser(pydantic_object=MyModel).OutputType
~T
```


</details>

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Commit ff0c06b1e5 (parent e271f75bee) by hmasdev, committed via GitHub.

@ -6,7 +6,7 @@ from langchain_core.output_parsers import BaseOutputParser
from langchain_core.pydantic_v1 import root_validator
class CombiningOutputParser(BaseOutputParser):
class CombiningOutputParser(BaseOutputParser[Dict[str, Any]]):
"""Combine multiple output parsers into one."""
parsers: List[BaseOutputParser]

@ -1,12 +1,12 @@
from enum import Enum
from typing import Any, Dict, List, Type
from typing import Dict, List, Type
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.pydantic_v1 import root_validator
class EnumOutputParser(BaseOutputParser):
class EnumOutputParser(BaseOutputParser[Enum]):
"""Parse an output that is one of a set of values."""
enum: Type[Enum]
@ -23,7 +23,7 @@ class EnumOutputParser(BaseOutputParser):
def _valid_values(self) -> List[str]:
return [e.value for e in self.enum]
def parse(self, response: str) -> Any:
def parse(self, response: str) -> Enum:
try:
return self.enum(response.strip())
except ValueError:
@ -34,3 +34,7 @@ class EnumOutputParser(BaseOutputParser):
def get_format_instructions(self) -> str:
return f"Select one of the following options: {', '.join(self._valid_values)}"
@property
def OutputType(self) -> Type[Enum]:
return self.enum

@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Any, TypeVar
from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import RunnableSerializable
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
@ -22,10 +23,12 @@ class OutputFixingParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
retry_chain: Union[RunnableSerializable, Any]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
legacy: bool = True
"""Whether to use the run or arun method of the retry_chain."""
@classmethod
def from_llm(
@ -46,9 +49,7 @@ class OutputFixingParser(BaseOutputParser[T]):
Returns:
OutputFixingParser
"""
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
chain = prompt | llm
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse(self, completion: str) -> T:
@ -62,11 +63,29 @@ class OutputFixingParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
else:
try:
completion = self.retry_chain.invoke(
dict(
instructions=self.parser.get_format_instructions(), # noqa: E501
input=completion,
error=repr(e),
)
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions # noqa: E501
completion = self.retry_chain.invoke(
dict(
input=completion,
error=repr(e),
)
)
raise OutputParserException("Failed to parse")
@ -81,11 +100,29 @@ class OutputFixingParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(),
completion=completion,
error=repr(e),
)
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
instructions=self.parser.get_format_instructions(), # noqa: E501
completion=completion,
error=repr(e),
)
else:
try:
completion = await self.retry_chain.ainvoke(
dict(
instructions=self.parser.get_format_instructions(), # noqa: E501
input=completion,
error=repr(e),
)
)
except (NotImplementedError, AttributeError):
# Case: self.parser does not have get_format_instructions # noqa: E501
completion = await self.retry_chain.ainvoke(
dict(
input=completion,
error=repr(e),
)
)
raise OutputParserException("Failed to parse")
@ -95,3 +132,7 @@ class OutputFixingParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return "output_fixing"
@property
def OutputType(self) -> type[T]:
return self.parser.OutputType
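The delegation shown in the hunk above, where a wrapping parser exposes the wrapped parser's `OutputType` instead of the unresolved `~T`, can be illustrated standalone (`InnerParser`/`WrappingParser` are hypothetical names for this sketch):

```python
class InnerParser:
    """Stand-in for a concrete parser such as BooleanOutputParser."""

    @property
    def OutputType(self) -> type:
        return bool


class WrappingParser:
    """Stand-in for OutputFixingParser/RetryOutputParser delegation."""

    def __init__(self, parser: InnerParser) -> None:
        self.parser = parser

    @property
    def OutputType(self) -> type:
        # Forward to the wrapped parser rather than reporting ~T.
        return self.parser.OutputType


print(WrappingParser(InnerParser()).OutputType)  # <class 'bool'>
```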

@ -10,7 +10,7 @@ from langchain.output_parsers.format_instructions import (
)
class PandasDataFrameOutputParser(BaseOutputParser):
class PandasDataFrameOutputParser(BaseOutputParser[Dict[str, Any]]):
"""Parse an output using Pandas DataFrame format."""
"""The Pandas DataFrame to parse."""

@ -6,7 +6,7 @@ from typing import Dict, List, Optional
from langchain_core.output_parsers import BaseOutputParser
class RegexParser(BaseOutputParser):
class RegexParser(BaseOutputParser[Dict[str, str]]):
"""Parse the output of an LLM call using a regex."""
@classmethod

@ -6,7 +6,7 @@ from typing import Dict, Optional
from langchain_core.output_parsers import BaseOutputParser
class RegexDictParser(BaseOutputParser):
class RegexDictParser(BaseOutputParser[Dict[str, str]]):
"""Parse the output of an LLM call into a Dictionary using a regex."""
regex_pattern: str = r"{}:\s?([^.'\n']*)\.?" # : :meta private:

@ -1,12 +1,13 @@
from __future__ import annotations
from typing import Any, TypeVar
from typing import Any, TypeVar, Union
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.runnables import RunnableSerializable
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}
@ -43,10 +44,12 @@ class RetryOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
retry_chain: Union[RunnableSerializable, Any]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
legacy: bool = True
"""Whether to use the run or arun method of the retry_chain."""
@classmethod
def from_llm(
@ -67,9 +70,7 @@ class RetryOutputParser(BaseOutputParser[T]):
Returns:
RetryOutputParser
"""
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
chain = prompt | llm
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
@ -92,9 +93,19 @@ class RetryOutputParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(), completion=completion
)
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = self.retry_chain.invoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
)
)
raise OutputParserException("Failed to parse")
@ -118,9 +129,19 @@ class RetryOutputParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(), completion=completion
)
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
)
)
raise OutputParserException("Failed to parse")
@ -136,6 +157,10 @@ class RetryOutputParser(BaseOutputParser[T]):
def _type(self) -> str:
return "retry"
@property
def OutputType(self) -> type[T]:
return self.parser.OutputType
class RetryWithErrorOutputParser(BaseOutputParser[T]):
"""Wrap a parser and try to fix parsing errors.
@ -149,11 +174,13 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
parser: BaseOutputParser[T]
"""The parser to use to parse the output."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains
retry_chain: Any
"""The LLMChain to use to retry the completion."""
# Should be an LLMChain but we want to avoid top-level imports from langchain.chains # noqa: E501
retry_chain: Union[RunnableSerializable, Any]
"""The RunnableSerializable to use to retry the completion (Legacy: LLMChain)."""
max_retries: int = 1
"""The maximum number of times to retry the parse."""
legacy: bool = True
"""Whether to use the run or arun method of the retry_chain."""
@classmethod
def from_llm(
@ -174,12 +201,10 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
Returns:
A RetryWithErrorOutputParser.
"""
from langchain.chains.llm import LLMChain
chain = LLMChain(llm=llm, prompt=prompt)
chain = prompt | llm
return cls(parser=parser, retry_chain=chain, max_retries=max_retries)
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T:
def parse_with_prompt(self, completion: str, prompt_value: PromptValue) -> T: # noqa: E501
retries = 0
while retries <= self.max_retries:
@ -190,11 +215,20 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
if self.legacy and hasattr(self.retry_chain, "run"):
completion = self.retry_chain.run(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = self.retry_chain.invoke(
dict(
input=completion,
prompt=prompt_value.to_string(),
error=repr(e),
)
)
raise OutputParserException("Failed to parse")
@ -209,11 +243,20 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
raise e
else:
retries += 1
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
if self.legacy and hasattr(self.retry_chain, "arun"):
completion = await self.retry_chain.arun(
prompt=prompt_value.to_string(),
completion=completion,
error=repr(e),
)
else:
completion = await self.retry_chain.ainvoke(
dict(
prompt=prompt_value.to_string(),
input=completion,
error=repr(e),
)
)
raise OutputParserException("Failed to parse")
@ -228,3 +271,7 @@ class RetryWithErrorOutputParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return "retry_with_error"
@property
def OutputType(self) -> type[T]:
return self.parser.OutputType
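The retry-count semantics exercised by the new tests (`max_retries=n` allows up to n+1 parse attempts in total) can be sketched with plain Python; `parse_with_retries` and `flaky` are hypothetical helpers, with ValueError standing in for OutputParserException:

```python
def parse_with_retries(parse_fn, max_retries: int) -> str:
    """Call parse_fn up to max_retries + 1 times, as in the retry parsers."""
    retries = 0
    while retries <= max_retries:
        try:
            return parse_fn()
        except ValueError:
            retries += 1
    raise RuntimeError("Failed to parse")


calls = {"n": 0}


def flaky() -> str:
    calls["n"] += 1
    if calls["n"] <= 5:  # fail on the first 5 attempts
        raise ValueError("error")
    return "parsed"


print(parse_with_retries(flaky, max_retries=5))  # 'parsed' on the 6th call
print(calls["n"])                                # 6
```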

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, List
from typing import Any, Dict, List
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.output_parsers.json import parse_and_check_json_markdown
@ -31,7 +31,7 @@ def _get_sub_string(schema: ResponseSchema) -> str:
)
class StructuredOutputParser(BaseOutputParser):
class StructuredOutputParser(BaseOutputParser[Dict[str, Any]]):
"""Parse the output of an LLM call to a structured output."""
response_schemas: List[ResponseSchema]
@ -92,7 +92,7 @@ class StructuredOutputParser(BaseOutputParser):
else:
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> Any:
def parse(self, text: str) -> Dict[str, Any]:
expected_keys = [rs.name for rs in self.response_schemas]
return parse_and_check_json_markdown(text, expected_keys)

@ -60,3 +60,7 @@ class YamlOutputParser(BaseOutputParser[T]):
@property
def _type(self) -> str:
return "yaml"
@property
def OutputType(self) -> Type[T]:
return self.pydantic_object

@ -39,3 +39,8 @@ def test_boolean_output_parser_parse() -> None:
# Bad input
with pytest.raises(ValueError):
parser.parse("BOOM")
def test_boolean_output_parser_output_type() -> None:
"""Test the output type of the boolean output parser is a boolean."""
assert BooleanOutputParser().OutputType == bool

@ -1,4 +1,6 @@
"""Test in memory docstore."""
from typing import Any, Dict
from langchain.output_parsers.combining import CombiningOutputParser
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
@ -43,3 +45,27 @@ def test_combining_dict_result() -> None:
combining_parser = CombiningOutputParser(parsers=parsers)
result_dict = combining_parser.parse(DEF_README)
assert DEF_EXPECTED_RESULT == result_dict
def test_combining_output_parser_output_type() -> None:
"""Test combining output parser output type is Dict[str, Any]."""
parsers = [
StructuredOutputParser(
response_schemas=[
ResponseSchema(
name="answer", description="answer to the user's question"
),
ResponseSchema(
name="source",
description="source used to answer the user's question",
),
]
),
RegexParser(
regex=r"Confidence: (A|B|C), Explanation: (.*)",
output_keys=["confidence", "explanation"],
default_output_key="noConfidence",
),
]
combining_parser = CombiningOutputParser(parsers=parsers)
assert combining_parser.OutputType is Dict[str, Any]
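The `is` comparison in the test above works because the `typing` module caches subscripted generics, so repeated subscriptions like `Dict[str, Any]` return the same object; a quick standalone check:

```python
from typing import Any, Dict, List

# typing caches parameterized generics, so identity comparison is stable.
a = Dict[str, Any]
b = Dict[str, Any]
print(a is b)                  # True

print(List[str] is List[str])  # True
```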

@ -30,3 +30,8 @@ def test_enum_output_parser_parse() -> None:
assert False, "Should have raised OutputParserException"
except OutputParserException:
pass
def test_enum_output_parser_output_type() -> None:
"""Test the output type of the enum output parser is the expected enum."""
assert EnumOutputParser(enum=Colors).OutputType is Colors

@ -0,0 +1,121 @@
from typing import Any
import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.runnables import RunnablePassthrough
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
from langchain.output_parsers.fix import BaseOutputParser, OutputFixingParser
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
parse_count: int = 0 # Number of times parse has been called
attemp_count_before_success: int # Number of times to fail before succeeding # noqa
def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success:
raise OutputParserException("error")
return "parsed"
class SuccessfulParseAfterRetriesWithGetFormatInstructions(SuccessfulParseAfterRetries): # noqa
def get_format_instructions(self) -> str:
return "instructions"
@pytest.mark.parametrize(
"base_parser",
[
SuccessfulParseAfterRetries(attemp_count_before_success=5),
SuccessfulParseAfterRetriesWithGetFormatInstructions(
attemp_count_before_success=5
), # noqa: E501
],
)
def test_output_fixing_parser_parse(
base_parser: SuccessfulParseAfterRetries,
) -> None:
# preparation
n: int = (
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
legacy=False,
)
# test
assert parser.parse("completion") == "parsed"
assert base_parser.parse_count == n + 1
# TODO: test whether "instructions" is passed to the retry_chain
@pytest.mark.parametrize(
"base_parser",
[
SuccessfulParseAfterRetries(attemp_count_before_success=5),
SuccessfulParseAfterRetriesWithGetFormatInstructions(
attemp_count_before_success=5
), # noqa: E501
],
)
async def test_output_fixing_parser_aparse(
base_parser: SuccessfulParseAfterRetries,
) -> None:
n: int = (
base_parser.attemp_count_before_success
) # Success on the (n+1)-th attempt # noqa
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser=base_parser,
max_retries=n, # n times to retry, that is, (n+1) times call
retry_chain=RunnablePassthrough(),
legacy=False,
)
assert (await parser.aparse("completion")) == "parsed"
assert base_parser.parse_count == n + 1
# TODO: test whether "instructions" is passed to the retry_chain
def test_output_fixing_parser_parse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
legacy=False,
)
with pytest.raises(OutputParserException):
parser.parse("completion")
assert base_parser.parse_count == n
async def test_output_fixing_parser_aparse_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = OutputFixingParser(
parser=base_parser,
max_retries=n - 1, # n-1 times to retry, that is, n times call
retry_chain=RunnablePassthrough(),
legacy=False,
)
with pytest.raises(OutputParserException):
await parser.aparse("completion")
assert base_parser.parse_count == n
@pytest.mark.parametrize(
"base_parser",
[
BooleanOutputParser(),
DatetimeOutputParser(),
],
)
def test_output_fixing_parser_output_type(base_parser: BaseOutputParser) -> None: # noqa: E501
parser = OutputFixingParser(parser=base_parser, retry_chain=RunnablePassthrough()) # noqa: E501
assert parser.OutputType is base_parser.OutputType

@ -1,4 +1,6 @@
"""Test PandasDataframeParser"""
from typing import Any, Dict
import pandas as pd
from langchain_core.exceptions import OutputParserException
@ -108,3 +110,8 @@ def test_pandas_output_parser_invalid_special_op() -> None:
assert False, "Should have raised OutputParserException"
except OutputParserException:
assert True
def test_pandas_output_parser_output_type() -> None:
"""Test the output type of the pandas dataframe output parser is Dict[str, Any].""" # noqa: E501
assert parser.OutputType is Dict[str, Any]

@ -60,6 +60,7 @@ def test_pydantic_output_parser() -> None:
result = pydantic_parser.parse(DEF_RESULT)
print("parse_result:", result) # noqa: T201
assert DEF_EXPECTED_RESULT == result
assert pydantic_parser.OutputType is TestModel
def test_pydantic_output_parser_fail() -> None:

@ -0,0 +1,38 @@
from typing import Dict
from langchain.output_parsers.regex import RegexParser
# NOTE: The almost same constant variables in ./test_combining_parser.py
DEF_EXPECTED_RESULT = {
"confidence": "A",
"explanation": "Paris is the capital of France according to Wikipedia.",
}
DEF_README = """```json
{
"answer": "Paris",
"source": "https://en.wikipedia.org/wiki/France"
}
```
//Confidence: A, Explanation: Paris is the capital of France according to Wikipedia."""
def test_regex_parser_parse() -> None:
"""Test regex parser parse."""
parser = RegexParser(
regex=r"Confidence: (A|B|C), Explanation: (.*)",
output_keys=["confidence", "explanation"],
default_output_key="noConfidence",
)
assert DEF_EXPECTED_RESULT == parser.parse(DEF_README)
def test_regex_parser_output_type() -> None:
"""Test regex parser output type is Dict[str, str]."""
parser = RegexParser(
regex=r"Confidence: (A|B|C), Explanation: (.*)",
output_keys=["confidence", "explanation"],
default_output_key="noConfidence",
)
assert parser.OutputType is Dict[str, str]

@ -1,5 +1,7 @@
"""Test in memory docstore."""
from typing import Dict
from langchain.output_parsers.regex_dict import RegexDictParser
DEF_EXPECTED_RESULT = {"action": "Search", "action_input": "How to use this class?"}
@ -36,3 +38,11 @@ def test_regex_dict_result() -> None:
result_dict = regex_dict_parser.parse(DEF_README)
print("parse_result:", result_dict) # noqa: T201
assert DEF_EXPECTED_RESULT == result_dict
def test_regex_dict_output_type() -> None:
"""Test regex dict output type."""
regex_dict_parser = RegexDictParser(
output_key_to_format=DEF_OUTPUT_KEY_TO_FORMAT, no_update_value="N/A"
)
assert regex_dict_parser.OutputType is Dict[str, str]

@ -0,0 +1,196 @@
from typing import Any
import pytest
from langchain_core.prompt_values import StringPromptValue
from langchain_core.runnables import RunnablePassthrough
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
from langchain.output_parsers.retry import (
BaseOutputParser,
OutputParserException,
RetryOutputParser,
RetryWithErrorOutputParser,
)
class SuccessfulParseAfterRetries(BaseOutputParser[str]):
parse_count: int = 0 # Number of times parse has been called
attemp_count_before_success: int # Number of times to fail before succeeding # noqa
error_msg: str = "error"
def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success:
raise OutputParserException(self.error_msg)
return "parsed"
def test_retry_output_parser_parse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call
legacy=False,
)
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
assert actual == "parsed"
assert base_parser.parse_count == n + 1
def test_retry_output_parser_parse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call
legacy=False,
)
with pytest.raises(OutputParserException):
parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
assert base_parser.parse_count == n
async def test_retry_output_parser_aparse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
max_retries=n, # n times to retry, that is, (n+1) times call
legacy=False,
)
actual = await parser.aparse_with_prompt(
"completion", StringPromptValue(text="dummy")
)
assert actual == "parsed"
assert base_parser.parse_count == n + 1
async def test_retry_output_parser_aparse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
max_retries=n - 1, # n-1 times to retry, that is, n times call
legacy=False,
)
with pytest.raises(OutputParserException):
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
assert base_parser.parse_count == n
@pytest.mark.parametrize(
"base_parser",
[
BooleanOutputParser(),
DatetimeOutputParser(),
],
)
def test_retry_output_parser_output_type(base_parser: BaseOutputParser) -> None:
parser = RetryOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
legacy=False,
)
assert parser.OutputType is base_parser.OutputType


def test_retry_output_parser_parse_is_not_implemented() -> None:
parser = RetryOutputParser(
parser=BooleanOutputParser(),
retry_chain=RunnablePassthrough(),
legacy=False,
)
with pytest.raises(NotImplementedError):
parser.parse("completion")


def test_retry_with_error_output_parser_parse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
        max_retries=n,  # retry n times, i.e. (n+1) parse calls in total
legacy=False,
)
actual = parser.parse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
assert actual == "parsed"
assert base_parser.parse_count == n + 1


def test_retry_with_error_output_parser_parse_with_prompt_fail() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
        max_retries=n - 1,  # retry n-1 times, i.e. n parse calls in total
legacy=False,
)
with pytest.raises(OutputParserException):
parser.parse_with_prompt("completion", StringPromptValue(text="dummy"))
assert base_parser.parse_count == n


async def test_retry_with_error_output_parser_aparse_with_prompt() -> None:
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
        max_retries=n,  # retry n times, i.e. (n+1) parse calls in total
legacy=False,
)
actual = await parser.aparse_with_prompt(
"completion", StringPromptValue(text="dummy")
)
assert actual == "parsed"
assert base_parser.parse_count == n + 1


async def test_retry_with_error_output_parser_aparse_with_prompt_fail() -> None:  # noqa: E501
n: int = 5 # Success on the (n+1)-th attempt
base_parser = SuccessfulParseAfterRetries(attemp_count_before_success=n)
parser = RetryWithErrorOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
        max_retries=n - 1,  # retry n-1 times, i.e. n parse calls in total
legacy=False,
)
with pytest.raises(OutputParserException):
await parser.aparse_with_prompt("completion", StringPromptValue(text="dummy")) # noqa: E501
assert base_parser.parse_count == n


@pytest.mark.parametrize(
"base_parser",
[
BooleanOutputParser(),
DatetimeOutputParser(),
],
)
def test_retry_with_error_output_parser_output_type(
base_parser: BaseOutputParser,
) -> None:
parser = RetryWithErrorOutputParser(
parser=base_parser,
retry_chain=RunnablePassthrough(),
legacy=False,
)
assert parser.OutputType is base_parser.OutputType
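The assertion above exercises the fix described in the PR body: wrapping parsers such as `OutputFixingParser`, `RetryOutputParser`, and `RetryWithErrorOutputParser` now report the same `OutputType` as the parser they wrap, instead of the unresolved type variable `~T`. A minimal sketch of that delegation pattern — illustrative classes only, not langchain's actual implementation:

```python
from typing import Any


class WrappedParserSketch:
    """Wraps another parser and forwards its declared output type."""

    def __init__(self, parser: Any) -> None:
        self.parser = parser

    @property
    def OutputType(self) -> Any:
        # Delegate to the wrapped parser instead of trying to resolve the
        # wrapper's own generic parameter (which yielded `~T` before the fix).
        return self.parser.OutputType


class BoolParserSketch:
    """Stand-in for a base parser whose OutputType is bool."""

    @property
    def OutputType(self) -> type:
        return bool


assert WrappedParserSketch(BoolParserSketch()).OutputType is bool
```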


def test_retry_with_error_output_parser_parse_is_not_implemented() -> None:
parser = RetryWithErrorOutputParser(
parser=BooleanOutputParser(),
retry_chain=RunnablePassthrough(),
legacy=False,
)
with pytest.raises(NotImplementedError):
parser.parse("completion")
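The `SuccessfulParseAfterRetries` helper that these tests construct is defined earlier in the test module, outside this hunk. As a reading aid, here is a self-contained sketch of the behavior the tests imply: it fails the first `attemp_count_before_success` calls, then returns `"parsed"`, counting every call in `parse_count`. The real helper would subclass `BaseOutputParser[str]` and raise langchain's `OutputParserException`; this plain-class version is an assumption made so the sketch runs standalone.

```python
class OutputParserException(Exception):
    """Stand-in for langchain_core.exceptions.OutputParserException."""


class SuccessfulParseAfterRetries:
    """Succeeds only after a configurable number of failed parse calls."""

    def __init__(self, attemp_count_before_success: int) -> None:
        # Number of calls that must fail before parse() succeeds.
        self.attemp_count_before_success = attemp_count_before_success
        # Total number of parse() calls so far (asserted on by the tests).
        self.parse_count = 0

    def parse(self, completion: str) -> str:
        self.parse_count += 1
        if self.parse_count <= self.attemp_count_before_success:
            raise OutputParserException("parse failed")
        return "parsed"
```

With `max_retries=n`, a retry parser may call `parse` up to n+1 times, so the helper above succeeds; with `max_retries=n - 1` it runs out of attempts and the exception propagates, matching the two test variants.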

@@ -1,9 +1,12 @@
from typing import Any, Dict

from langchain_core.exceptions import OutputParserException
from langchain.output_parsers import ResponseSchema, StructuredOutputParser


def test_parse() -> None:
"""Test parsing structured output."""
response_schemas = [
ResponseSchema(name="name", description="desc"),
ResponseSchema(name="age", description="desc"),
@@ -24,3 +27,13 @@ def test_parse() -> None:
pass # Test passes if OutputParserException is raised
else:
assert False, f"Expected OutputParserException, but got {parser.parse(text)}"


def test_output_type() -> None:
"""Test the output type of the structured output parser is Dict[str, Any]."""
response_schemas = [
ResponseSchema(name="name", description="desc"),
ResponseSchema(name="age", description="desc"),
]
parser = StructuredOutputParser.from_response_schemas(response_schemas)
assert parser.OutputType == Dict[str, Any]

@@ -93,3 +93,9 @@ def test_yaml_output_parser_fail() -> None:
assert "Failed to parse TestModel from completion" in str(e)
else:
assert False, "Expected OutputParserException"


def test_yaml_output_parser_output_type() -> None:
"""Test YamlOutputParser OutputType."""
yaml_parser = YamlOutputParser(pydantic_object=TestModel)
assert yaml_parser.OutputType is TestModel
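Per the table in the PR description, `EnumOutputParser(enum=MyEnum).OutputType` raised `TypeError` before this change and now reports the enum itself. A minimal sketch of the general idea — deriving `OutputType` from an instance attribute instead of resolving a generic type variable at runtime; the class and enum names here are illustrative, not langchain's code:

```python
from enum import Enum
from typing import Type


class Colors(Enum):
    RED = "red"
    BLUE = "blue"


class EnumParserSketch:
    """Derives OutputType from the enum passed at construction time."""

    def __init__(self, enum: Type[Enum]) -> None:
        self.enum = enum

    @property
    def OutputType(self) -> Type[Enum]:
        # Return the concrete enum class supplied by the caller, rather
        # than introspecting the class's generic parameter (the approach
        # that raised TypeError before this PR).
        return self.enum


assert EnumParserSketch(enum=Colors).OutputType is Colors
```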
