diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index b3cadb55a6..2d92964b19 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio import copy -import math import threading from collections import defaultdict from typing import ( @@ -20,7 +19,6 @@ from typing import ( from uuid import UUID import jsonpatch # type: ignore[import] -from anyio import BrokenResourceError, ClosedResourceError, create_memory_object_stream from typing_extensions import NotRequired, TypedDict from langchain_core.load import dumps @@ -29,6 +27,7 @@ from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.runnables import Runnable, RunnableConfig, ensure_config from langchain_core.runnables.utils import Input, Output from langchain_core.tracers.base import BaseTracer +from langchain_core.tracers.memory_stream import _MemoryStream from langchain_core.tracers.schemas import Run @@ -210,12 +209,11 @@ class LogStreamCallbackHandler(BaseTracer): self.exclude_types = exclude_types self.exclude_tags = exclude_tags - send_stream: Any - receive_stream: Any - send_stream, receive_stream = create_memory_object_stream(math.inf) + loop = asyncio.get_event_loop() + memory_stream = _MemoryStream[RunLogPatch](loop) self.lock = threading.Lock() - self.send_stream = send_stream - self.receive_stream = receive_stream + self.send_stream = memory_stream.get_send_stream() + self.receive_stream = memory_stream.get_receive_stream() self._key_map_by_run_id: Dict[UUID, str] = {} self._counter_map_by_name: Dict[str, int] = defaultdict(int) self.root_id: Optional[UUID] = None @@ -225,11 +223,12 @@ class LogStreamCallbackHandler(BaseTracer): def send(self, *ops: Dict[str, Any]) -> bool: """Send a patch to the stream, return False if the stream is closed.""" - try: - 
self.send_stream.send_nowait(RunLogPatch(*ops)) - return True - except (ClosedResourceError, BrokenResourceError): - return False + # We will likely want to wrap this in try / except at some point + # to handle exceptions that might arise at run time. + # For now we'll let the exception bubble up, and always return + # True on the happy path. + self.send_stream.send_nowait(RunLogPatch(*ops)) + return True async def tap_output_aiter( self, run_id: UUID, output: AsyncIterator[T] diff --git a/libs/core/langchain_core/tracers/memory_stream.py b/libs/core/langchain_core/tracers/memory_stream.py new file mode 100644 index 0000000000..871b5ca932 --- /dev/null +++ b/libs/core/langchain_core/tracers/memory_stream.py @@ -0,0 +1,104 @@ +"""Module implements a memory stream for communication between two co-routines. + +This module provides a way to communicate between two co-routines using a memory +channel. The writer and reader can be in the same event loop or in different event +loops. When they're in different event loops, they will also be in different +threads. + +This is useful in situations when there's a mix of synchronous and asynchronous +code being used. +""" +import asyncio +from asyncio import AbstractEventLoop, Queue +from typing import AsyncIterator, Generic, TypeVar + +T = TypeVar("T") + + +class _SendStream(Generic[T]): + def __init__( + self, reader_loop: AbstractEventLoop, queue: Queue, done: object + ) -> None: + """Create a writer for the queue and done object. + + Args: + reader_loop: The event loop to use for the writer. This loop will be used + to schedule the writes to the queue. + queue: The queue to write to. This is an asyncio queue. + done: Special sentinel object to indicate that the writer is done.
+ """ + self._reader_loop = reader_loop + self._queue = queue + self._done = done + + async def send(self, item: T) -> None: + """Schedule the item to be written to the queue using the original loop.""" + return self.send_nowait(item) + + def send_nowait(self, item: T) -> None: + """Schedule the item to be written to the queue using the original loop.""" + self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, item) + + async def aclose(self) -> None: + """Schedule writing the done object to the queue using the original loop.""" + return self.close() + + def close(self) -> None: + """Schedule writing the done object to the queue using the original loop.""" + self._reader_loop.call_soon_threadsafe(self._queue.put_nowait, self._done) + + +class _ReceiveStream(Generic[T]): + def __init__(self, queue: Queue, done: object) -> None: + """Create a reader for the queue and done object. + + This reader should be used in the same loop as the loop that was passed + to the channel. + """ + self._queue = queue + self._done = done + self._is_closed = False + + async def __aiter__(self) -> AsyncIterator[T]: + while True: + item = await self._queue.get() + if item is self._done: + self._is_closed = True + break + yield item + + +class _MemoryStream(Generic[T]): + """Stream data from a writer to a reader even if they are in different threads. + + Uses asyncio queues to communicate between two co-routines. This implementation + should work even if the writer and reader co-routines belong to two different + event loops (e.g. one running from an event loop in the main thread + and the other running in an event loop in a background thread). + + This implementation is meant to be used with a single writer and a single reader. + + This is an internal implementation to LangChain; please do not use it directly. + """ + + def __init__(self, loop: AbstractEventLoop) -> None: + """Create a channel for the given loop. + + Args: + loop: The event loop to use for the channel.
The reader is assumed + to be running in the same loop as the one passed to this constructor. + This will NOT be validated at run time. + """ + self._loop = loop + self._queue: asyncio.Queue = asyncio.Queue(maxsize=0) + self._done = object() + + def get_send_stream(self) -> _SendStream[T]: + """Get a writer for the channel.""" + return _SendStream[T]( + reader_loop=self._loop, queue=self._queue, done=self._done + ) + + def get_receive_stream(self) -> _ReceiveStream[T]: + """Get a reader for the channel.""" + return _ReceiveStream[T](queue=self._queue, done=self._done) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index d89f8b2657..001ebaa756 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -738,7 +738,7 @@ def test_validation_error_handling_callable() -> None: ], ) def test_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]] + handler: Union[bool, str, Callable[[ValidationError], str]], ) -> None: """Test that validation errors are handled correctly.""" @@ -800,7 +800,7 @@ async def test_async_validation_error_handling_callable() -> None: ], ) async def test_async_validation_error_handling_non_validation_error( - handler: Union[bool, str, Callable[[ValidationError], str]] + handler: Union[bool, str, Callable[[ValidationError], str]], ) -> None: """Test that validation errors are handled correctly.""" diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py new file mode 100644 index 0000000000..85d1698682 --- /dev/null +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -0,0 +1,122 @@ +import asyncio +import math +import time +from concurrent.futures import ThreadPoolExecutor +from typing import AsyncIterator + +from langchain_core.tracers.memory_stream import _MemoryStream + + +async def test_same_event_loop() -> None: 
+ """Test that the memory stream works when the same event loop is used. + + This is the easy case. + """ + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[dict](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + + async def producer() -> None: + """Produce items with slight delay.""" + tic = time.time() + for i in range(3): + await asyncio.sleep(0.10) + toc = time.time() + await writer.send( + { + "item": i, + "produce_time": toc - tic, + } + ) + await writer.aclose() + + async def consumer() -> AsyncIterator[dict]: + tic = time.time() + async for item in reader: + toc = time.time() + yield { + "receive_time": toc - tic, + **item, + } + + asyncio.create_task(producer()) + + items = [item async for item in consumer()] + + for item in items: + delta_time = item["receive_time"] - item["produce_time"] + # Allow a generous 10ms of delay + # The test is meant to verify that the producer and consumer are running in + # parallel; in this test the producer runs as a task on the same event loop. + # abs_tol is used to allow for some delay in the producer and consumer + # due to overhead.
+ # To verify that the producer and consumer are running in parallel, we + # expect the delta_time to be smaller than the sleep delay in the producer + # * # of items = 300 ms + assert ( + math.isclose(delta_time, 0, abs_tol=0.010) is True + ), f"delta_time: {delta_time}" + + +async def test_queue_for_streaming_via_sync_call() -> None: + """Test via async -> sync -> async path.""" + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[dict](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + + async def producer() -> None: + """Produce items with slight delay.""" + tic = time.time() + for i in range(3): + await asyncio.sleep(0.10) + toc = time.time() + await writer.send( + { + "item": i, + "produce_time": toc - tic, + } + ) + await writer.aclose() + + def sync_call() -> None: + """Blocking sync call.""" + asyncio.run(producer()) + + async def consumer() -> AsyncIterator[dict]: + tic = time.time() + async for item in reader: + toc = time.time() + yield { + "receive_time": toc - tic, + **item, + } + + with ThreadPoolExecutor() as executor: + executor.submit(sync_call) + items = [item async for item in consumer()] + + for item in items: + delta_time = item["receive_time"] - item["produce_time"] + # Allow a generous 10ms of delay + # The test is meant to verify that the producer and consumer are running in + # parallel despite the fact that the producer is running from another thread. + # abs_tol is used to allow for some delay in the producer and consumer + # due to overhead.
+ # To verify that the producer and consumer are running in parallel, we + # expect the delta_time to be smaller than the sleep delay in the producer + # * # of items = 300 ms + assert ( + math.isclose(delta_time, 0, abs_tol=0.010) is True + ), f"delta_time: {delta_time}" + + +async def test_closed_stream() -> None: + reader_loop = asyncio.get_event_loop() + channel = _MemoryStream[str](reader_loop) + writer = channel.get_send_stream() + reader = channel.get_receive_stream() + await writer.aclose() + + assert [chunk async for chunk in reader] == []