Add async tests and comments

pull/11222/head
Nuno Campos 10 months ago
parent 091d8845d5
commit 3d8aa88e26

@ -205,8 +205,10 @@ class PartialFunctionsJsonOutputParser(BaseCumulativeTransformOutputParser[Any])
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
pass
raise NotImplementedError()
class PartialJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):

@ -340,6 +340,8 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
diff: bool = False
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:

@ -1,5 +1,5 @@
import json
from typing import Any, Iterator, Tuple
from typing import Any, AsyncIterator, Iterator, Tuple
import pytest
@ -492,3 +492,51 @@ def test_partial_functions_json_output_parser_diff() -> None:
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
@pytest.mark.asyncio
async def test_partial_text_json_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | PartialJsonOutputParser()
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.asyncio
async def test_partial_functions_json_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | PartialFunctionsJsonOutputParser()
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.asyncio
async def test_partial_text_json_output_parser_diff_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | PartialJsonOutputParser(diff=True)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
@pytest.mark.asyncio
async def test_partial_functions_json_output_parser_diff_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | PartialFunctionsJsonOutputParser(diff=True)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF

Loading…
Cancel
Save