Support using async callback handlers with sync callback manager (#10945)

The current behaviour just calls the handler without awaiting the
coroutine, which results in exceptions/warnings, and obviously doesn't
actually execute whatever the callback handler does

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-09-28 10:39:01 +01:00 committed by GitHub
parent 48a04aed75
commit 77ce9ed6f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 99 additions and 29 deletions

View File

@ -5,12 +5,14 @@ import functools
import logging import logging
import os import os
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar from contextvars import ContextVar
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
AsyncGenerator, AsyncGenerator,
Coroutine,
Dict, Dict,
Generator, Generator,
List, List,
@ -370,37 +372,84 @@ def _handle_event(
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Generic event handler for CallbackManager.""" """Generic event handler for CallbackManager."""
message_strings: Optional[List[str]] = None coros: List[Coroutine[Any, Any, Any]] = []
for handler in handlers:
try: try:
if ignore_condition_name is None or not getattr( message_strings: Optional[List[str]] = None
handler, ignore_condition_name for handler in handlers:
): try:
getattr(handler, event_name)(*args, **kwargs) if ignore_condition_name is None or not getattr(
except NotImplementedError as e: handler, ignore_condition_name
if event_name == "on_chat_model_start": ):
if message_strings is None: event = getattr(handler, event_name)(*args, **kwargs)
message_strings = [get_buffer_string(m) for m in args[1]] if asyncio.iscoroutine(event):
_handle_event( coros.append(event)
[handler], except NotImplementedError as e:
"on_llm_start", if event_name == "on_chat_model_start":
"ignore_llm", if message_strings is None:
args[0], message_strings = [get_buffer_string(m) for m in args[1]]
message_strings, _handle_event(
*args[2:], [handler],
**kwargs, "on_llm_start",
) "ignore_llm",
else: args[0],
message_strings,
*args[2:],
**kwargs,
)
else:
handler_name = handler.__class__.__name__
logger.warning(
f"NotImplementedError in {handler_name}.{event_name}"
f" callback: {e}"
)
except Exception as e:
logger.warning( logger.warning(
f"NotImplementedError in {handler.__class__.__name__}.{event_name}" f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
f" callback: {e}"
) )
except Exception as e: if handler.raise_error:
logger.warning( raise e
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}" finally:
) if coros:
if handler.raise_error: try:
raise e # Raises RuntimeError if there is no current event loop.
asyncio.get_running_loop()
loop_running = True
except RuntimeError:
loop_running = False
if loop_running:
# If we try to submit this coroutine to the running loop
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with ThreadPoolExecutor(1) as executor:
executor.submit(_run_coros, coros).result()
else:
_run_coros(coros)
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
if hasattr(asyncio, "Runner"):
# Python 3.11+
# Run the coroutines in a new event loop, taking care to
# - install signal handlers
# - run pending tasks scheduled by `coros`
# - close asyncgens and executors
# - close the loop
with asyncio.Runner() as runner:
# Run the coroutine, get the result
for coro in coros:
runner.run(coro)
# Run pending tasks scheduled by coros until they are all done
while pending := asyncio.all_tasks(runner.get_loop()):
runner.run(asyncio.wait(pending))
else:
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
for coro in coros:
asyncio.run(coro)
async def _ahandle_event_for_handler( async def _ahandle_event_for_handler(

View File

@ -92,6 +92,27 @@ def test_callback_manager() -> None:
_test_callback_manager(manager, handler1, handler2) _test_callback_manager(manager, handler1, handler2)
def test_callback_manager_with_async() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
handler3 = FakeAsyncCallbackHandler()
handler4 = FakeAsyncCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
@pytest.mark.asyncio
async def test_callback_manager_with_async_with_running_loop() -> None:
"""Test the CallbackManager."""
handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler()
handler3 = FakeAsyncCallbackHandler()
handler4 = FakeAsyncCallbackHandler()
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
def test_ignore_llm() -> None: def test_ignore_llm() -> None:
"""Test ignore llm param for callback handlers.""" """Test ignore llm param for callback handlers."""
handler1 = FakeCallbackHandler(ignore_llm_=True) handler1 = FakeCallbackHandler(ignore_llm_=True)