core: In astream_events v2 propagate cancel/break to the inner astream call (#22865)

- previous behavior was for the inner astream to continue running with
no interruption
- also propagate break in core runnable methods
pull/22868/head
Nuno Campos 4 months ago committed by GitHub
parent a766815a99
commit bae82e966a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -13,6 +13,7 @@ from operator import itemgetter
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
@ -79,7 +80,7 @@ from langchain_core.runnables.utils import (
is_async_callable,
is_async_generator,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
if TYPE_CHECKING:
@ -1141,8 +1142,9 @@ class Runnable(Generic[Input, Output], ABC):
'Only versions "v1" and "v2" of the schema is currently supported.'
)
async for event in event_stream:
yield event
async with aclosing(event_stream):
async for event in event_stream:
yield event
def transform(
self,
@ -1948,7 +1950,7 @@ class Runnable(Generic[Input, Output], ABC):
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
if stream_handler := next(
(
@ -1960,7 +1962,11 @@ class Runnable(Generic[Input, Output], ABC):
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator)
iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_
)
else:
iterator = iterator_
try:
while True:
if accepts_context(asyncio.create_task):
@ -2001,6 +2007,9 @@ class Runnable(Generic[Input, Output], ABC):
raise
else:
await run_manager.on_chain_end(final_output, inputs=final_input)
finally:
if hasattr(iterator_, "aclose"):
await iterator_.aclose()
class RunnableSerializable(Serializable, Runnable[Input, Output]):
@ -3907,23 +3916,29 @@ class RunnableLambda(Runnable[Input, Output]):
if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
acall_func_with_variable_args(
cast(Callable, afunc),
input,
config,
run_manager,
**kwargs,
),
):
if output is None:
output = chunk
else:
try:
output = output + chunk # type: ignore[operator]
except TypeError:
async with aclosing(
cast(
AsyncGenerator[Any, Any],
acall_func_with_variable_args(
cast(Callable, afunc),
input,
config,
run_manager,
**kwargs,
),
)
) as stream:
async for chunk in cast(
AsyncIterator[Output],
stream,
):
if output is None:
output = chunk
else:
try:
output = output + chunk # type: ignore[operator]
except TypeError:
output = chunk
else:
output = await acall_func_with_variable_args(
cast(Callable, afunc), input, config, run_manager, **kwargs

@ -37,7 +37,7 @@ from langchain_core.runnables.utils import (
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import py_anext
from langchain_core.utils.aiter import aclosing, py_anext
if TYPE_CHECKING:
from langchain_core.documents import Document
@ -903,11 +903,10 @@ async def _astream_events_implementation_v2(
async def consume_astream() -> None:
try:
# if astream also calls tap_output_aiter this will be a no-op
async for _ in event_streamer.tap_output_aiter(
run_id, runnable.astream(input, config, **kwargs)
):
# All the content will be picked up
pass
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
async for _ in event_streamer.tap_output_aiter(run_id, stream):
# All the content will be picked up
pass
finally:
await event_streamer.send_stream.aclose()
@ -942,7 +941,8 @@ async def _astream_events_implementation_v2(
yield event
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass
if task.cancel():
try:
await task
except asyncio.CancelledError:
pass

@ -5,6 +5,8 @@ MIT License
"""
from collections import deque
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
@ -18,6 +20,7 @@ from typing import (
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
@ -207,3 +210,40 @@ class Tee(Generic[T]):
atee = Tee
class aclosing(AbstractAsyncContextManager):
    """Async context manager that finalizes an async iterator on exit.

    Entering the context returns the wrapped object unchanged. Leaving the
    context (normally or because of an exception) awaits the object's
    ``aclose()`` method when it has one; objects without ``aclose`` are left
    untouched. Exceptions raised inside the block are never suppressed.

    Code like this:

        async with aclosing(<module>.fetch(<arguments>)) as agen:
            <block>

    is equivalent to this:

        agen = <module>.fetch(<arguments>)
        try:
            <block>
        finally:
            await agen.aclose()

    """

    def __init__(
        self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]
    ) -> None:
        """Store the async generator/iterator to be closed on exit.

        Args:
            thing: The object handed back by ``__aenter__`` and closed by
                ``__aexit__``.
        """
        self.thing = thing

    async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]:
        """Return the wrapped object itself (no copy, no wrapping)."""
        return self.thing

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """Await ``aclose()`` on the wrapped object if it provides one.

        The exception details are accepted but ignored; returning ``None``
        lets any in-flight exception propagate to the caller.
        """
        # Plain async iterators are not required to define ``aclose``,
        # so only call it when present.
        if hasattr(self.thing, "aclose"):
            await self.thing.aclose()

@ -1,4 +1,5 @@
"""Module that contains tests for runnable.astream_events API."""
import asyncio
import sys
import uuid
from itertools import cycle
@ -38,6 +39,7 @@ from langchain_core.runnables import (
RunnableConfig,
RunnableGenerator,
RunnableLambda,
chain,
ensure_config,
)
from langchain_core.runnables.config import get_callback_manager_for_config
@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.stubs import AnyStr
@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
for event in events
if event["event"] == "on_chat_model_stream"
] == ["hello", " ", "world"]
async def test_break_astream_events() -> None:
    """Breaking out of ``astream_events`` (v2) must interrupt the inner chain.

    A three-step chain is streamed; the consumer breaks after the first
    ``on_chain_stream`` event. The assertions check that the first step ran
    to completion, the second was cancelled mid-flight, the third never
    started, and the chain body itself observed the cancellation.
    """

    class AwhileMaker:
        """Awaitable step that records whether it started and was cancelled."""

        def __init__(self) -> None:
            self.reset()

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Remember the cancellation, then let it propagate.
                self.cancelled = True
                raise
            return input

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

    alittlewhile = AwhileMaker()
    awhile = AwhileMaker()
    anotherwhile = AwhileMaker()
    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            for step in (alittlewhile, awhile, anotherwhile):
                yield await step(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    # Interrupt astream_events v2 as soon as the first streamed chunk arrives.
    got_event = False
    config: RunnableConfig = {"configurable": {"thread_id": 2}}
    events = sequence.astream_events({"value": 1}, config, version="v2")
    async with aclosing(events) as stream:
        async for event in stream:
            if event["event"] != "on_chain_stream":
                continue
            got_event = True
            assert event["data"]["chunk"] == {"value": 1}
            break

    # The consumer saw a stream event and broke out of the loop.
    assert got_event
    # The chain body received the cancellation.
    assert outer_cancelled

    # First step: started and finished without being cancelled.
    assert alittlewhile.started is True
    assert alittlewhile.cancelled is False

    # Second step: started, then cancelled.
    assert awhile.started is True
    assert awhile.cancelled is True

    # Third step: never reached.
    assert anotherwhile.started is False
async def test_cancel_astream_events() -> None:
    """Cancelling the consuming task must interrupt the inner chain.

    A three-step chain is streamed through ``astream_events`` (v2) inside a
    task that cancels itself on the first ``on_chain_stream`` event. The
    assertions check that the first step completed, the second was cancelled,
    the third never started, and the chain body observed the cancellation.
    """

    class AwhileMaker:
        """Awaitable step that records whether it started and was cancelled."""

        def __init__(self) -> None:
            self.reset()

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Remember the cancellation, then let it propagate.
                self.cancelled = True
                raise
            return input

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

    alittlewhile = AwhileMaker()
    awhile = AwhileMaker()
    anotherwhile = AwhileMaker()
    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            for step in (alittlewhile, awhile, anotherwhile):
                yield await step(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    got_event = False

    async def aconsume(stream: AsyncIterator[Any]) -> None:
        nonlocal got_event
        # No aclosing needed here: cancelling the task is propagated to the
        # async generator being consumed.
        async for event in stream:
            if event["event"] != "on_chain_stream":
                continue
            got_event = True
            assert event["data"]["chunk"] == {"value": 1}
            task.cancel()

    config: RunnableConfig = {"configurable": {"thread_id": 2}}
    events = sequence.astream_events({"value": 1}, config, version="v2")
    task = asyncio.create_task(aconsume(events))

    with pytest.raises(asyncio.CancelledError):
        await task

    # The consumer saw a stream event before being cancelled.
    assert got_event
    # The chain body received the cancellation.
    assert outer_cancelled

    # First step: started and finished without being cancelled.
    assert alittlewhile.started is True
    assert alittlewhile.cancelled is False

    # Second step: started, then cancelled.
    assert awhile.started is True
    assert awhile.cancelled is True

    # Third step: never reached.
    assert anotherwhile.started is False

Loading…
Cancel
Save