From 77ce9ed6f1711005aa07317fced56ec1b614852c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Thu, 28 Sep 2023 10:39:01 +0100 Subject: [PATCH] Support using async callback handlers with sync callback manager (#10945) The current behaviour just calls the handler without awaiting the coroutine, which results in exceptions/warnings, and obviously doesn't actually execute whatever the callback handler does --- libs/langchain/langchain/callbacks/manager.py | 107 +++++++++++++----- .../callbacks/test_callback_manager.py | 21 ++++ 2 files changed, 99 insertions(+), 29 deletions(-) diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index d0e984dee3..19f041005a 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -5,12 +5,14 @@ import functools import logging import os import uuid +from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar from typing import ( TYPE_CHECKING, Any, AsyncGenerator, + Coroutine, Dict, Generator, List, @@ -370,37 +372,84 @@ def _handle_event( **kwargs: Any, ) -> None: """Generic event handler for CallbackManager.""" - message_strings: Optional[List[str]] = None - for handler in handlers: - try: - if ignore_condition_name is None or not getattr( - handler, ignore_condition_name - ): - getattr(handler, event_name)(*args, **kwargs) - except NotImplementedError as e: - if event_name == "on_chat_model_start": - if message_strings is None: - message_strings = [get_buffer_string(m) for m in args[1]] - _handle_event( - [handler], - "on_llm_start", - "ignore_llm", - args[0], - message_strings, - *args[2:], - **kwargs, - ) - else: + coros: List[Coroutine[Any, Any, Any]] = [] + + try: + message_strings: Optional[List[str]] = None + for handler in handlers: + try: + if ignore_condition_name is None or not getattr( + handler, ignore_condition_name + ): + event = getattr(handler, event_name)(*args, **kwargs) + if asyncio.iscoroutine(event): + coros.append(event) + except NotImplementedError as e: + if event_name == "on_chat_model_start": + if message_strings is None: + message_strings = [get_buffer_string(m) for m in args[1]] + _handle_event( + [handler], + "on_llm_start", + "ignore_llm", + args[0], + message_strings, + *args[2:], + **kwargs, + ) + else: + handler_name = handler.__class__.__name__ + logger.warning( + f"NotImplementedError in {handler_name}.{event_name}" + f" callback: {e}" + ) + except Exception as e: logger.warning( - f"NotImplementedError in {handler.__class__.__name__}.{event_name}" - f" callback: {e}" + f"Error in {handler.__class__.__name__}.{event_name} callback: {e}" ) - except Exception as e: - logger.warning( - f"Error in {handler.__class__.__name__}.{event_name} callback: {e}" - ) - if handler.raise_error: - raise e + if handler.raise_error: + raise e + finally: + if coros: + try: + # Raises RuntimeError if there is no current event loop. + asyncio.get_running_loop() + loop_running = True + except RuntimeError: + loop_running = False + + if loop_running: + # If we try to submit this coroutine to the running loop + # we end up in a deadlock, as we'd have gotten here from a + # running coroutine, which we cannot interrupt to run this one. + # The solution is to create a new loop in a new thread. + with ThreadPoolExecutor(1) as executor: + executor.submit(_run_coros, coros).result() + else: + _run_coros(coros) + + +def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: + if hasattr(asyncio, "Runner"): + # Python 3.11+ + # Run the coroutines in a new event loop, taking care to + # - install signal handlers + # - run pending tasks scheduled by `coros` + # - close asyncgens and executors + # - close the loop + with asyncio.Runner() as runner: + # Run the coroutine, get the result + for coro in coros: + runner.run(coro) + + # Run pending tasks scheduled by coros until they are all done + while pending := asyncio.all_tasks(runner.get_loop()): + runner.run(asyncio.wait(pending)) + else: + # Before Python 3.11 we need to run each coroutine in a new event loop + # as the Runner api is not available. + for coro in coros: + asyncio.run(coro) async def _ahandle_event_for_handler( diff --git a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py index 428d41f8c5..c25b9b3c77 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py @@ -92,6 +92,27 @@ def test_callback_manager() -> None: _test_callback_manager(manager, handler1, handler2) +def test_callback_manager_with_async() -> None: + """Test the CallbackManager.""" + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() + handler3 = FakeAsyncCallbackHandler() + handler4 = FakeAsyncCallbackHandler() + manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4]) + _test_callback_manager(manager, handler1, handler2, handler3, handler4) + + +@pytest.mark.asyncio +async def test_callback_manager_with_async_with_running_loop() -> None: + """Test the CallbackManager.""" + handler1 = FakeCallbackHandler() + handler2 = FakeCallbackHandler() + handler3 = FakeAsyncCallbackHandler() + handler4 = FakeAsyncCallbackHandler() + manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4]) + _test_callback_manager(manager, handler1, handler2, handler3, handler4) + + def test_ignore_llm() -> None: """Test ignore llm param for callback handlers.""" handler1 = FakeCallbackHandler(ignore_llm_=True)