mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
core[patch]: simple fallback streaming (#16055)
This commit is contained in:
parent
4ef0ed4ddc
commit
1e29b676d5
@ -302,7 +302,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
@ -2,6 +2,8 @@ import asyncio
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
@ -30,6 +32,7 @@ from langchain_core.runnables.utils import (
|
||||
Output,
|
||||
get_unique_config_specs,
|
||||
)
|
||||
from langchain_core.utils.aiter import py_anext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.callbacks.manager import AsyncCallbackManagerForChainRun
|
||||
@ -415,3 +418,118 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
raise sorted_handled_exceptions[0][1]
|
||||
to_return.update(handled_exceptions)
|
||||
return [output for _, output in sorted(to_return.items())] # type: ignore
|
||||
|
||||
def stream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
""""""
|
||||
if self.exception_key is not None and not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then input must be a dictionary."
|
||||
f"However found a type of {type(input)} for input"
|
||||
)
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
stream = runnable.stream(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
chunk = next(stream)
|
||||
except self.exceptions_to_handle as e:
|
||||
first_error = e if first_error is None else first_error
|
||||
last_error = e
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
first_error = None
|
||||
break
|
||||
if first_error:
|
||||
run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
yield chunk
|
||||
output: Optional[Output] = chunk
|
||||
try:
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
try:
|
||||
output = output + chunk # type: ignore
|
||||
except TypeError:
|
||||
output = None
|
||||
except BaseException as e:
|
||||
run_manager.on_chain_error(e)
|
||||
raise e
|
||||
run_manager.on_chain_end(output)
|
||||
|
||||
async def astream(
|
||||
self,
|
||||
input: Input,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
if self.exception_key is not None and not isinstance(input, dict):
|
||||
raise ValueError(
|
||||
"If 'exception_key' is specified then input must be a dictionary."
|
||||
f"However found a type of {type(input)} for input"
|
||||
)
|
||||
# setup callbacks
|
||||
config = ensure_config(config)
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
)
|
||||
first_error = None
|
||||
last_error = None
|
||||
for runnable in self.runnables:
|
||||
try:
|
||||
if self.exception_key and last_error is not None:
|
||||
input[self.exception_key] = last_error
|
||||
stream = runnable.astream(
|
||||
input,
|
||||
patch_config(config, callbacks=run_manager.get_child()),
|
||||
**kwargs,
|
||||
)
|
||||
chunk = await cast(Awaitable[Output], py_anext(stream))
|
||||
except self.exceptions_to_handle as e:
|
||||
first_error = e if first_error is None else first_error
|
||||
last_error = e
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
else:
|
||||
first_error = None
|
||||
break
|
||||
if first_error:
|
||||
await run_manager.on_chain_error(first_error)
|
||||
raise first_error
|
||||
|
||||
yield chunk
|
||||
output: Optional[Output] = chunk
|
||||
try:
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
try:
|
||||
output = output + chunk # type: ignore
|
||||
except TypeError:
|
||||
output = None
|
||||
except BaseException as e:
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(output)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator, Iterator
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
@ -8,6 +8,7 @@ from langchain_core.load import dumps
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
@ -229,3 +230,61 @@ async def test_abatch() -> None:
|
||||
|
||||
expected = ["first", "second", RuntimeError()]
|
||||
_assert_potential_error(actual, expected)
|
||||
|
||||
|
||||
def _generate(input: Iterator) -> Iterator[str]:
|
||||
yield from "foo bar"
|
||||
|
||||
|
||||
def _generate_immediate_error(input: Iterator) -> Iterator[str]:
|
||||
raise ValueError()
|
||||
yield ""
|
||||
|
||||
|
||||
def _generate_delayed_error(input: Iterator) -> Iterator[str]:
|
||||
yield ""
|
||||
raise ValueError()
|
||||
|
||||
|
||||
def test_fallbacks_stream() -> None:
|
||||
runnable = RunnableGenerator(_generate_immediate_error).with_fallbacks(
|
||||
[RunnableGenerator(_generate)]
|
||||
)
|
||||
assert list(runnable.stream({})) == [c for c in "foo bar"]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_generate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_generate)]
|
||||
)
|
||||
list(runnable.stream({}))
|
||||
|
||||
|
||||
async def _agenerate(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
for c in "foo bar":
|
||||
yield c
|
||||
|
||||
|
||||
async def _agenerate_immediate_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
raise ValueError()
|
||||
yield ""
|
||||
|
||||
|
||||
async def _agenerate_delayed_error(input: AsyncIterator) -> AsyncIterator[str]:
|
||||
yield ""
|
||||
raise ValueError()
|
||||
|
||||
|
||||
async def test_fallbacks_astream() -> None:
|
||||
runnable = RunnableGenerator(_agenerate_immediate_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
expected = (c for c in "foo bar")
|
||||
async for c in runnable.astream({}):
|
||||
assert c == next(expected)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runnable = RunnableGenerator(_agenerate_delayed_error).with_fallbacks(
|
||||
[RunnableGenerator(_agenerate)]
|
||||
)
|
||||
async for c in runnable.astream({}):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user