|
|
|
@ -1,4 +1,5 @@
|
|
|
|
|
"""Module that contains tests for runnable.astream_events API."""
|
|
|
|
|
import asyncio
|
|
|
|
|
import sys
|
|
|
|
|
import uuid
|
|
|
|
|
from itertools import cycle
|
|
|
|
@ -38,6 +39,7 @@ from langchain_core.runnables import (
|
|
|
|
|
RunnableConfig,
|
|
|
|
|
RunnableGenerator,
|
|
|
|
|
RunnableLambda,
|
|
|
|
|
chain,
|
|
|
|
|
ensure_config,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.runnables.config import get_callback_manager_for_config
|
|
|
|
@ -45,6 +47,7 @@ from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
|
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
|
|
|
from langchain_core.runnables.utils import Input, Output
|
|
|
|
|
from langchain_core.tools import tool
|
|
|
|
|
from langchain_core.utils.aiter import aclosing
|
|
|
|
|
from tests.unit_tests.stubs import AnyStr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -2195,3 +2198,140 @@ async def test_with_explicit_config() -> None:
|
|
|
|
|
for event in events
|
|
|
|
|
if event["event"] == "on_chat_model_stream"
|
|
|
|
|
] == ["hello", " ", "world"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_break_astream_events() -> None:
|
|
|
|
|
class AwhileMaker:
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
async def __call__(self, input: Any) -> Any:
|
|
|
|
|
self.started = True
|
|
|
|
|
try:
|
|
|
|
|
await asyncio.sleep(0.5)
|
|
|
|
|
return input
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
self.cancelled = True
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def reset(self) -> None:
|
|
|
|
|
self.started = False
|
|
|
|
|
self.cancelled = False
|
|
|
|
|
|
|
|
|
|
alittlewhile = AwhileMaker()
|
|
|
|
|
awhile = AwhileMaker()
|
|
|
|
|
anotherwhile = AwhileMaker()
|
|
|
|
|
|
|
|
|
|
outer_cancelled = False
|
|
|
|
|
|
|
|
|
|
@chain
|
|
|
|
|
async def sequence(input: Any) -> Any:
|
|
|
|
|
try:
|
|
|
|
|
yield await alittlewhile(input)
|
|
|
|
|
yield await awhile(input)
|
|
|
|
|
yield await anotherwhile(input)
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
nonlocal outer_cancelled
|
|
|
|
|
outer_cancelled = True
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
# test interrupting astream_events v2
|
|
|
|
|
|
|
|
|
|
got_event = False
|
|
|
|
|
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
|
|
|
|
async with aclosing(
|
|
|
|
|
sequence.astream_events({"value": 1}, thread2, version="v2")
|
|
|
|
|
) as stream:
|
|
|
|
|
async for chunk in stream:
|
|
|
|
|
if chunk["event"] == "on_chain_stream":
|
|
|
|
|
got_event = True
|
|
|
|
|
assert chunk["data"]["chunk"] == {"value": 1}
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
# did break
|
|
|
|
|
assert got_event
|
|
|
|
|
# did cancel outer chain
|
|
|
|
|
assert outer_cancelled
|
|
|
|
|
|
|
|
|
|
# node "alittlewhile" starts, not cancelled
|
|
|
|
|
assert alittlewhile.started is True
|
|
|
|
|
assert alittlewhile.cancelled is False
|
|
|
|
|
|
|
|
|
|
# node "awhile" starts but is cancelled
|
|
|
|
|
assert awhile.started is True
|
|
|
|
|
assert awhile.cancelled is True
|
|
|
|
|
|
|
|
|
|
# node "anotherwhile" should never start
|
|
|
|
|
assert anotherwhile.started is False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def test_cancel_astream_events() -> None:
|
|
|
|
|
class AwhileMaker:
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.reset()
|
|
|
|
|
|
|
|
|
|
async def __call__(self, input: Any) -> Any:
|
|
|
|
|
self.started = True
|
|
|
|
|
try:
|
|
|
|
|
await asyncio.sleep(0.5)
|
|
|
|
|
return input
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
self.cancelled = True
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
def reset(self) -> None:
|
|
|
|
|
self.started = False
|
|
|
|
|
self.cancelled = False
|
|
|
|
|
|
|
|
|
|
alittlewhile = AwhileMaker()
|
|
|
|
|
awhile = AwhileMaker()
|
|
|
|
|
anotherwhile = AwhileMaker()
|
|
|
|
|
|
|
|
|
|
outer_cancelled = False
|
|
|
|
|
|
|
|
|
|
@chain
|
|
|
|
|
async def sequence(input: Any) -> Any:
|
|
|
|
|
try:
|
|
|
|
|
yield await alittlewhile(input)
|
|
|
|
|
yield await awhile(input)
|
|
|
|
|
yield await anotherwhile(input)
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
nonlocal outer_cancelled
|
|
|
|
|
outer_cancelled = True
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
got_event = False
|
|
|
|
|
|
|
|
|
|
async def aconsume(stream: AsyncIterator[Any]) -> None:
|
|
|
|
|
nonlocal got_event
|
|
|
|
|
# here we don't need aclosing as cancelling the task is propagated
|
|
|
|
|
# to the async generator being consumed
|
|
|
|
|
async for chunk in stream:
|
|
|
|
|
if chunk["event"] == "on_chain_stream":
|
|
|
|
|
got_event = True
|
|
|
|
|
assert chunk["data"]["chunk"] == {"value": 1}
|
|
|
|
|
task.cancel()
|
|
|
|
|
|
|
|
|
|
thread2: RunnableConfig = {"configurable": {"thread_id": 2}}
|
|
|
|
|
task = asyncio.create_task(
|
|
|
|
|
aconsume(sequence.astream_events({"value": 1}, thread2, version="v2"))
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
with pytest.raises(asyncio.CancelledError):
|
|
|
|
|
await task
|
|
|
|
|
|
|
|
|
|
# did break
|
|
|
|
|
assert got_event
|
|
|
|
|
# did cancel outer chain
|
|
|
|
|
assert outer_cancelled
|
|
|
|
|
|
|
|
|
|
# node "alittlewhile" starts, not cancelled
|
|
|
|
|
assert alittlewhile.started is True
|
|
|
|
|
assert alittlewhile.cancelled is False
|
|
|
|
|
|
|
|
|
|
# node "awhile" starts but is cancelled
|
|
|
|
|
assert awhile.started is True
|
|
|
|
|
assert awhile.cancelled is True
|
|
|
|
|
|
|
|
|
|
# node "anotherwhile" should never start
|
|
|
|
|
assert anotherwhile.started is False
|
|
|
|
|