From 3d8aa88e26b8f28c32f47ef9ca3c266b7ed15975 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 29 Sep 2023 15:28:46 +0100 Subject: [PATCH] Add async tests and comments --- .../langchain/output_parsers/json.py | 4 +- .../langchain/schema/output_parser.py | 2 + .../unit_tests/output_parsers/test_json.py | 50 ++++++++++++++++++- 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/output_parsers/json.py b/libs/langchain/langchain/output_parsers/json.py index d4423a0d2f..aafaedb67d 100644 --- a/libs/langchain/langchain/output_parsers/json.py +++ b/libs/langchain/langchain/output_parsers/json.py @@ -205,8 +205,10 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any]) def _diff(self, prev: Optional[Any], next: Any) -> Any: return jsonpatch.make_patch(prev, next).patch + # This method would be called by the default implementation of `parse_result` + # but we're overriding that method so it's not needed. def parse(self, text: str) -> Any: - pass + raise NotImplementedError() class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): diff --git a/libs/langchain/langchain/schema/output_parser.py b/libs/langchain/langchain/schema/output_parser.py index 89a065e9ad..6d2e388893 100644 --- a/libs/langchain/langchain/schema/output_parser.py +++ b/libs/langchain/langchain/schema/output_parser.py @@ -340,6 +340,8 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): diff: bool = False def _diff(self, prev: Optional[T], next: T) -> T: + """Convert parsed outputs into a diff format. The semantics of this are + up to the output parser.""" raise NotImplementedError() def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]: diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_json.py b/libs/langchain/tests/unit_tests/output_parsers/test_json.py index 00dbd4d3b3..b9daee1a51 100644 --- a/libs/langchain/tests/unit_tests/output_parsers/test_json.py +++ b/libs/langchain/tests/unit_tests/output_parsers/test_json.py @@ -1,5 +1,5 @@ import json -from typing import Any, Iterator, Tuple +from typing import Any, AsyncIterator, Iterator, Tuple import pytest @@ -492,3 +492,51 @@ def test_partial_functions_json_output_parser_diff() -> None: chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | PartialJsonOutputParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser() + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON + + +@pytest.mark.asyncio +async def test_partial_text_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[str]: + for token in STREAMED_TOKENS: + yield token + + chain = input_iter | PartialJsonOutputParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF + + +@pytest.mark.asyncio +async def test_partial_functions_json_output_parser_diff_async() -> None: + async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]: + for token in STREAMED_TOKENS: + yield AIMessageChunk( + content="", additional_kwargs={"function_call": {"arguments": token}} + ) + + chain = input_iter | PartialFunctionsJsonOutputParser(diff=True) + + assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF