Add a streaming json parser (#11193)

<img width="1728" alt="Screenshot 2023-09-28 at 20 15 01"
src="https://github.com/langchain-ai/langchain/assets/56902/ed0644c3-6db7-41b9-9543-e34fce46d3e5">


<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
pull/11235/head
Nuno Campos 1 year ago committed by GitHub
commit 1ddf9f74b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -75,14 +75,16 @@ class OpenAIFunctionsAgentOutputParser(AgentOutputParser):
return_values={"output": message.content}, log=message.content
)
def parse_result(self, result: List[Generation]) -> Union[AgentAction, AgentFinish]:
def parse_result(
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
if not isinstance(result[0], ChatGeneration):
raise ValueError("This output parser only works on ChatGeneration output")
message = result[0].message
return self._parse_ai_message(message)
async def aparse_result(
self, result: List[Generation]
self, result: List[Generation], *, partial: bool = False
) -> Union[AgentAction, AgentFinish]:
return await asyncio.get_running_loop().run_in_executor(
None, self.parse_result, result

@ -3,9 +3,14 @@ from __future__ import annotations
import json
import re
from json import JSONDecodeError
from typing import Any, List
from typing import Any, Callable, List, Optional
from langchain.schema import BaseOutputParser, OutputParserException
import jsonpatch
from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
OutputParserException,
)
def _replace_new_line(match: re.Match[str]) -> str:
@ -38,7 +43,70 @@ def _custom_parser(multiline_string: str) -> str:
return multiline_string
def parse_json_markdown(json_string: str) -> dict:
# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass
# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False
# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None
# Append the processed character to the new string.
new_s += char
# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'
# Close any remaining open structures in the reverse order that they were opened.
for closing_char in reversed(stack):
new_s += closing_char
# Attempt to parse the modified string as JSON.
try:
return json.loads(new_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON, return None to indicate failure.
return None
def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
"""
Parse a JSON string from a Markdown string.
@ -65,7 +133,7 @@ def parse_json_markdown(json_string: str) -> dict:
json_str = _custom_parser(json_str)
# Parse the JSON string into a Python dictionary
parsed = json.loads(json_str)
parsed = parser(json_str)
return parsed
@ -95,13 +163,23 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj
class SimpleJsonOutputParser(BaseOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object."""
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing
all the keys that have been returned so far.
In streaming, if `diff` is set to `True`, yields JSONPatch operations
describing the difference between the previous and the current object.
"""
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse(self, text: str) -> Any:
text = text.strip()
try:
return json.loads(text)
return parse_json_markdown(text.strip(), parser=parse_partial_json)
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

@ -1,14 +1,20 @@
import copy
import json
from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Optional, Type, Union
import jsonpatch
from langchain.output_parsers.json import parse_partial_json
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.schema import (
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
from langchain.schema.output_parser import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
@ -17,7 +23,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
@ -34,7 +40,7 @@ class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
return func_call
class JsonOutputFunctionsParser(OutputFunctionsParser):
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""
strict: bool = False
@ -45,25 +51,72 @@ class JsonOutputFunctionsParser(OutputFunctionsParser):
Useful when the parsed output may include unicode characters or new lines.
"""
def parse_result(self, result: List[Generation]) -> Any:
function_call_info = super().parse_result(result)
if self.args_only:
try:
return json.loads(function_call_info, strict=self.strict)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
function_call_info["arguments"] = json.loads(
function_call_info["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
return function_call_info
args_only: bool = True
"""Whether to only return the arguments to the function call."""
def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if partial:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
else:
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None
# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
raise NotImplementedError()
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
@ -72,7 +125,7 @@ class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
"""The name of the key to return."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result)
return res[self.key_name]
@ -97,7 +150,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
)
return values
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
@ -114,6 +167,6 @@ class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
"""The name of the attribute to return."""
def parse_result(self, result: List[Generation]) -> Any:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

@ -17,8 +17,13 @@ from typing import (
from typing_extensions import get_args
from langchain.load.serializable import Serializable
from langchain.schema.messages import AnyMessage, BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.messages import AnyMessage, BaseMessage, BaseMessageChunk
from langchain.schema.output import (
ChatGeneration,
ChatGenerationChunk,
Generation,
GenerationChunk,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
@ -29,7 +34,7 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -40,7 +45,9 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
Args:
@ -200,7 +207,7 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
run_type="parser",
)
def parse_result(self, result: List[Generation]) -> T:
def parse_result(self, result: List[Generation], *, partial: bool = False) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -226,7 +233,9 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
async def aparse_result(
self, result: List[Generation], *, partial: bool = False
) -> T:
"""Parse a list of candidate model Generations into a specific format.
The return value is parsed from only the first Generation in the result, which
@ -329,6 +338,74 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
yield chunk
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
"""In streaming mode, whether to yield diffs between the previous and current
parsed output, or just the current parsed output.
"""
def _diff(self, prev: Optional[T], next: T) -> T:
"""Convert parsed outputs into a diff format. The semantics of this are
up to the output parser."""
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
)
else:
chunk_gen = GenerationChunk(text=chunk)
if acc_gen is None:
acc_gen = chunk_gen
else:
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
class StrOutputParser(BaseTransformOutputParser[str]):
"""OutputParser that parses LLMResult into the top likely string."""

@ -1,55 +0,0 @@
"""Test the BaseOutputParser class and its sub-classes."""
from abc import ABC
from collections import defaultdict
from typing import List, Optional, Set, Type
import pytest
from langchain.schema import BaseOutputParser
def non_abstract_subclasses(
cls: Type[ABC], to_skip: Optional[Set] = None
) -> List[Type]:
"""Recursively find all non-abstract subclasses of a class."""
_to_skip = to_skip or set()
subclasses = []
for subclass in cls.__subclasses__():
if not getattr(subclass, "__abstractmethods__", None):
if subclass.__name__ not in _to_skip:
subclasses.append(subclass)
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
return subclasses
# parsers defined not in the output_parsers module:
_PARSERS_TO_SKIP = {
"FakeOutputParser",
"BaseOutputParser",
"FinishedOutputParser",
"RouterOutputParser",
"TrajectoryRunEvalOutputParser",
}
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
)
@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
try:
cls._type
except NotImplementedError:
pytest.fail(f"_type property is not implemented in class {cls.__name__}")
def test_all_subclasses_implement_unique_type() -> None:
types = defaultdict(list)
for cls in _NON_ABSTRACT_PARSERS:
try:
types[cls._type].append(cls.__name__)
except NotImplementedError:
# This is handled in the previous test
pass
dups = {t: names for t, names in types.items() if len(names) > 1}
assert not dups, f"Duplicate types: {dups}"

@ -1,6 +1,15 @@
import json
from typing import Any, AsyncIterator, Iterator, Tuple
import pytest
from langchain.output_parsers.json import parse_json_markdown
from langchain.output_parsers.json import (
SimpleJsonOutputParser,
parse_json_markdown,
parse_partial_json,
)
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.schema.messages import AIMessageChunk
GOOD_JSON = """```json
{
@ -183,3 +192,351 @@ def test_parse_json_with_python_dict() -> None:
"action": "Final Answer",
"action_input": {"foo": "bar", "bar": "foo"},
}
TEST_CASES_PARTIAL = [
('{"foo": "bar", "bar": "foo"}', '{"foo": "bar", "bar": "foo"}'),
('{"foo": "bar", "bar": "foo', '{"foo": "bar", "bar": "foo"}'),
('{"foo": "bar", "bar": "foo}', '{"foo": "bar", "bar": "foo}"}'),
('{"foo": "bar", "bar": "foo[', '{"foo": "bar", "bar": "foo["}'),
('{"foo": "bar", "bar": "foo\\"', '{"foo": "bar", "bar": "foo\\""}'),
]
@pytest.mark.parametrize("json_strings", TEST_CASES_PARTIAL)
def test_parse_partial_json(json_strings: Tuple[str, str]) -> None:
case, expected = json_strings
parsed = parse_partial_json(case)
assert parsed == json.loads(expected)
STREAMED_TOKENS = """
{
"
setup
":
"
Why
did
the
bears
start
a
band
called
Bears
Bears
Bears
?
"
,
"
punchline
":
"
Because
they
wanted
to
play
bear
-y
good
music
!
"
,
"
audience
":
[
"
Haha
"
,
"
So
funny
"
]
}
""".splitlines()
EXPECTED_STREAMED_JSON = [
{},
{"setup": ""},
{"setup": "Why"},
{"setup": "Why did"},
{"setup": "Why did the"},
{"setup": "Why did the bears"},
{"setup": "Why did the bears start"},
{"setup": "Why did the bears start a"},
{"setup": "Why did the bears start a band"},
{"setup": "Why did the bears start a band called"},
{"setup": "Why did the bears start a band called Bears"},
{"setup": "Why did the bears start a band called Bears Bears"},
{"setup": "Why did the bears start a band called Bears Bears Bears"},
{"setup": "Why did the bears start a band called Bears Bears Bears ?"},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play bear",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play bear -y",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play bear -y good",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play bear -y good music",
},
{
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"punchline": "Because they wanted to play bear -y good music !",
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": [],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": [""],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha"],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", ""],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", "So"],
},
{
"punchline": "Because they wanted to play bear -y good music !",
"setup": "Why did the bears start a band called Bears Bears Bears ?",
"audience": ["Haha", "So funny"],
},
]
EXPECTED_STREAMED_JSON_DIFF = [
[{"op": "replace", "path": "", "value": {}}],
[{"op": "add", "path": "/setup", "value": ""}],
[{"op": "replace", "path": "/setup", "value": "Why"}],
[{"op": "replace", "path": "/setup", "value": "Why did"}],
[{"op": "replace", "path": "/setup", "value": "Why did the"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a"}],
[{"op": "replace", "path": "/setup", "value": "Why did the bears start a band"}],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears Bears",
}
],
[
{
"op": "replace",
"path": "/setup",
"value": "Why did the bears start a band called Bears Bears Bears ?",
}
],
[{"op": "add", "path": "/punchline", "value": ""}],
[{"op": "replace", "path": "/punchline", "value": "Because"}],
[{"op": "replace", "path": "/punchline", "value": "Because they"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to"}],
[{"op": "replace", "path": "/punchline", "value": "Because they wanted to play"}],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good music",
}
],
[
{
"op": "replace",
"path": "/punchline",
"value": "Because they wanted to play bear -y good music !",
}
],
[{"op": "add", "path": "/audience", "value": []}],
[{"op": "add", "path": "/audience/0", "value": ""}],
[{"op": "replace", "path": "/audience/0", "value": "Haha"}],
[{"op": "add", "path": "/audience/1", "value": ""}],
[{"op": "replace", "path": "/audience/1", "value": "So"}],
[{"op": "replace", "path": "/audience/1", "value": "So funny"}],
]
def test_partial_text_json_output_parser() -> None:
def input_iter(_: Any) -> Iterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | SimpleJsonOutputParser()
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_functions_json_output_parser() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | JsonOutputFunctionsParser()
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
def test_partial_text_json_output_parser_diff() -> None:
def input_iter(_: Any) -> Iterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | SimpleJsonOutputParser(diff=True)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
def test_partial_functions_json_output_parser_diff() -> None:
def input_iter(_: Any) -> Iterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | JsonOutputFunctionsParser(diff=True)
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON_DIFF
@pytest.mark.asyncio
async def test_partial_text_json_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | SimpleJsonOutputParser()
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.asyncio
async def test_partial_functions_json_output_parser_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | JsonOutputFunctionsParser()
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.asyncio
async def test_partial_text_json_output_parser_diff_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[str]:
for token in STREAMED_TOKENS:
yield token
chain = input_iter | SimpleJsonOutputParser(diff=True)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF
@pytest.mark.asyncio
async def test_partial_functions_json_output_parser_diff_async() -> None:
async def input_iter(_: Any) -> AsyncIterator[AIMessageChunk]:
for token in STREAMED_TOKENS:
yield AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": token}}
)
chain = input_iter | JsonOutputFunctionsParser(diff=True)
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON_DIFF

@ -582,7 +582,9 @@ async def test_with_config(mocker: MockerFixture) -> None:
) == [5, 7]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
for i, call in enumerate(
sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)
):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5

Loading…
Cancel
Save