From 93472ee9e62d13f3f8bf3adfa9b732d1536edb14 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 12 Feb 2024 21:57:38 -0500 Subject: [PATCH] core[patch]: Replace memory stream implementation used by LogStreamCallbackHandler (#17185) This PR replaces the memory stream implementation used by the LogStreamCallbackHandler. This implementation resolves an issue in which streamed logs and streamed events originating from sync code would arrive only after the entire sync code had finished executing (rather than arriving in real time as they're generated). One example is streaming tokens from an LLM within a tool. If the tool was an async tool, but the LLM was invoked via stream (sync variant) rather than astream (async variant), then the tokens would fail to stream in real time and would all arrive bunched up after the tool invocation completed. --- .../core/langchain_core/tracers/log_stream.py | 23 ++-- .../langchain_core/tracers/memory_stream.py | 104 +++++++++++++++ libs/core/tests/unit_tests/test_tools.py | 4 +- .../unit_tests/tracers/test_memory_stream.py | 122 ++++++++++++++++++ 4 files changed, 239 insertions(+), 14 deletions(-) create mode 100644 libs/core/langchain_core/tracers/memory_stream.py create mode 100644 libs/core/tests/unit_tests/tracers/test_memory_stream.py diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index b3cadb55a6..2d92964b19 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio import copy -import math import threading from collections import defaultdict from typing import ( @@ -20,7 +19,6 @@ from typing import ( from uuid import UUID import jsonpatch # type: ignore[import] -from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream from typing_extensions import NotRequired, TypedDict from langchain_core.load import 
dumps @@ -29,6 +27,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import Input, Output from langchain_core.tracers.base import BaseTracer +from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.schemas import Run @@ -210,12 +209,11 @@ class LogStreamCallbackHandler(BaseTracer): self.exclude_types = exclude_types self.exclude_tags = exclude_tags - send_stream: Any - receive_stream: Any - send_stream, receive_stream = create_memory_object_stream(math.inf) + loop = asyncio.get_event_loop() + memory_stream = _MemoryStream[RunLogPatch](loop) self.lock = threading.Lock() - self.send_stream = send_stream - self.receive_stream = receive_stream + self.send_stream = memory_stream.get_send_stream() + self.receive_stream = memory_stream.get_receive_stream() self._key_map_by_run_id: Dict[UUID, str] = {} self._counter_map_by_name: Dict[str, int] = defaultdict(int) self.root_id: Optional[UUID] = None @@ -225,11 +223,12 @@ class LogStreamCallbackHandler(BaseTracer): def send(self, *ops: Dict[str, Any]) -> bool: """Send a patch to the stream, return False if the stream is closed.""" - try: - self.send_stream.send_nowait(RunLogPatch(*ops)) - return True - except (ClosedResourceError, BrokenResourceError): - return False + # We will likely want to wrap this in try / except at some point + # to handle exceptions that might arise at run time. + # For now we'll let the exception bubble up, and always return + # True on the happy path. 
+ self.send_stream.send_nowait(RunLogPatch(*ops)) + return True async def tap_output_aiter( self, run_id: UUID, output: AsyncIterator[T] diff --git a/libs/core/langchain_core/tracers/memory_stream.py b/libs/core/langchain_core/tracers/memory_stream.py new file mode 100644 index 0000000000..871b5ca932 --- /dev/null +++ b/libs/core/langchain_core/tracers/memory_stream.py @@ -0,0 +1,104 @@ +"""Module implements a memory stream for communication between two co-routines. + +This module provides a way to communicate between two co-routines using a memory +channel. The writer and reader can be in the same event loop or in different event +loops. When they're in different event loops, they will also be in different +threads. + +This is useful in situations when there's a mix of synchronous and asynchronous +code. +""" +import asyncio +from asyncio import AbstractEventLoop, Queue +from typing import AsyncIterator, Generic, TypeVar + +T = TypeVar("T") + + +class _SendStream(Generic[T]): + def __init__( + self, reader_loop: AbstractEventLoop, queue: Queue, done: object + ) -> None: + """Create a writer for the queue and done object. + + Args: + reader_loop: The event loop to use for the writer. This loop will be used + to schedule the writes to the queue. + queue: The queue to write to. This is an asyncio queue. + done: Special sentinel object to indicate that the writer is done. 
+ """ + self._reader_loop = reader_loop + self._queue = queue + self._done = done + + async def send(self, item: T) -> None: + """Schedule the item to be written to the queue using the original loop.""" + return self.send_nowait(item) + + def send_nowait(self, item: T) -> None: + """Schedule the item to be written to the queue using the original loop.""" + self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, item) + + async def aclose(self) -> None: + """Schedule the done object to be written to the queue using the original loop.""" + return self.close() + + def close(self) -> None: + """Schedule the done object to be written to the queue using the original loop.""" + self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, self._done) + + +class _ReceiveStream(Generic[T]): + def __init__(self, queue: Queue, done: object) -> None: + """Create a reader for the queue and done object. + + This reader should be used in the same loop as the loop that was passed + to the channel. + """ + self._queue = queue + self._done = done + self._is_closed = False + + async def __aiter__(self) -> AsyncIterator[T]: + while True: + item = await self._queue.get() + if item is self._done: + self._is_closed = True + break + yield item + + +class _MemoryStream(Generic[T]): + """Stream data from a writer to a reader even if they are in different threads. + + Uses asyncio queues to communicate between two co-routines. This implementation + should work even if the writer and reader co-routines belong to two different + event loops (e.g. one running from an event loop in the main thread + and the other running in an event loop in a background thread). + + This implementation is meant to be used with a single writer and a single reader. + + This is an internal implementation detail of LangChain; please do not use it directly. + """ + + def __init__(self, loop: AbstractEventLoop) -> None: + """Create a channel for the given loop. + + Args: + loop: The event loop to use for the channel. 
The reader is assumed + to be running in the same loop as the one passed to this constructor. + This will NOT be validated at run time. + """ + self._loop = loop + self._queue: asyncio.Queue = asyncio.Queue(maxsize=0) + self._done = object() + + def get_send_stream(self) -> _SendStream[T]: + """Get a writer for the channel.""" + return _SendStream[T]( + reader_loop=self._loop, queue=self._queue, done=self._done + ) + + def get_receive_stream(self) -> _ReceiveStream[T]: + """Get a reader for the channel.""" + return _ReceiveStream[T](queue=self._queue, done=self._done) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index d89f8b2657..001ebaa756 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -738,7 +738,7 @@ def test_validation_error_handling_callable() -> None: ], ) def test_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]] + handler: Union[bool, str, Callable[[ValidationError], str]], ) -> None: """Test that validation errors are handled correctly.""" @@ -800,7 +800,7 @@ async def test_async_validation_error_handling_callable() -> None: ], ) async def test_async_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]] + handler: Union[bool, str, Callable[[ValidationError], str]], ) -> None: """Test that validation errors are handled correctly.""" diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py new file mode 100644 index 0000000000..85d1698682 --- /dev/null +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -0,0 +1,122 @@ +import asyncio +import math +import time +from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterator + +from langchain_core.tracers.memory_stream import _MemoryStream + + +async def test_same_event_loop() -> None: 
+ """Test that the memory stream works when the same event loop is used. + + This is the easy case. + """ + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[dict](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + + async def producer() -> None: + """Produce items with slight delay.""" + tic = time.time() + for i in range(3): + await asyncio.sleep(0.10) + toc = time.time() + await writer.send( + { + "item": i, + "produce_time": toc - tic, + } + ) + await writer.aclose() + + async def consumer() -> AsyncIterator[dict]: + tic = time.time() + async for item in reader: + toc = time.time() + yield { + "receive_time": toc - tic, + **item, + } + + asyncio.create_task(producer()) + + items = [item async for item in consumer()] + + for item in items: + delta_time = item["receive_time"] - item["produce_time"] + # Allow a generous 10ms of delay + # The test is meant to verify that the producer and consumer are running in + # parallel; in this test the producer is a task on the same event loop. + # abs_tol is used to allow for some delay in the producer and consumer + # due to overhead. 
+ # To verify that the producer and consumer are running in parallel, we + # expect the delta_time to be smaller than the sleep delay in the producer + # * # of items = 300 ms + assert ( + math.isclose(delta_time, 0, abs_tol=0.010) is True + ), f"delta_time: {delta_time}" + + +async def test_queue_for_streaming_via_sync_call() -> None: + """Test via async -> sync -> async path.""" + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[dict](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + + async def producer() -> None: + """Produce items with slight delay.""" + tic = time.time() + for i in range(3): + await asyncio.sleep(0.10) + toc = time.time() + await writer.send( + { + "item": i, + "produce_time": toc - tic, + } + ) + await writer.aclose() + + def sync_call() -> None: + """Blocking sync call.""" + asyncio.run(producer()) + + async def consumer() -> AsyncIterator[dict]: + tic = time.time() + async for item in reader: + toc = time.time() + yield { + "receive_time": toc - tic, + **item, + } + + with ThreadPoolExecutor() as executor: + executor.submit(sync_call) + items = [item async for item in consumer()] + + for item in items: + delta_time = item["receive_time"] - item["produce_time"] + # Allow a generous 10ms of delay + # The test is meant to verify that the producer and consumer are running in + # parallel despite the fact that the producer is running from another thread. + # abs_tol is used to allow for some delay in the producer and consumer + # due to overhead. 
+ # To verify that the producer and consumer are running in parallel, we + # expect the delta_time to be smaller than the sleep delay in the producer + # * # of items = 300 ms + assert ( + math.isclose(delta_time, 0, abs_tol=0.010) is True + ), f"delta_time: {delta_time}" + + +async def test_closed_stream() -> None: + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[str](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + await writer.aclose() + + assert [chunk async for chunk in reader] == []