mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Support using async callback handlers with sync callback manager (#10945)
The current behaviour just calls the handler without awaiting the coroutine, which results in exceptions/warnings, and obviously doesn't actually execute whatever the callback handler does <!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
48a04aed75
commit
77ce9ed6f1
@ -5,12 +5,14 @@ import functools
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from contextlib import asynccontextmanager, contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
|
Coroutine,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
List,
|
List,
|
||||||
@ -370,37 +372,84 @@ def _handle_event(
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generic event handler for CallbackManager."""
|
"""Generic event handler for CallbackManager."""
|
||||||
message_strings: Optional[List[str]] = None
|
coros: List[Coroutine[Any, Any, Any]] = []
|
||||||
for handler in handlers:
|
|
||||||
try:
|
try:
|
||||||
if ignore_condition_name is None or not getattr(
|
message_strings: Optional[List[str]] = None
|
||||||
handler, ignore_condition_name
|
for handler in handlers:
|
||||||
):
|
try:
|
||||||
getattr(handler, event_name)(*args, **kwargs)
|
if ignore_condition_name is None or not getattr(
|
||||||
except NotImplementedError as e:
|
handler, ignore_condition_name
|
||||||
if event_name == "on_chat_model_start":
|
):
|
||||||
if message_strings is None:
|
event = getattr(handler, event_name)(*args, **kwargs)
|
||||||
message_strings = [get_buffer_string(m) for m in args[1]]
|
if asyncio.iscoroutine(event):
|
||||||
_handle_event(
|
coros.append(event)
|
||||||
[handler],
|
except NotImplementedError as e:
|
||||||
"on_llm_start",
|
if event_name == "on_chat_model_start":
|
||||||
"ignore_llm",
|
if message_strings is None:
|
||||||
args[0],
|
message_strings = [get_buffer_string(m) for m in args[1]]
|
||||||
message_strings,
|
_handle_event(
|
||||||
*args[2:],
|
[handler],
|
||||||
**kwargs,
|
"on_llm_start",
|
||||||
)
|
"ignore_llm",
|
||||||
else:
|
args[0],
|
||||||
|
message_strings,
|
||||||
|
*args[2:],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
handler_name = handler.__class__.__name__
|
||||||
|
logger.warning(
|
||||||
|
f"NotImplementedError in {handler_name}.{event_name}"
|
||||||
|
f" callback: {e}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"NotImplementedError in {handler.__class__.__name__}.{event_name}"
|
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
||||||
f" callback: {e}"
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
if handler.raise_error:
|
||||||
logger.warning(
|
raise e
|
||||||
f"Error in {handler.__class__.__name__}.{event_name} callback: {e}"
|
finally:
|
||||||
)
|
if coros:
|
||||||
if handler.raise_error:
|
try:
|
||||||
raise e
|
# Raises RuntimeError if there is no current event loop.
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
loop_running = True
|
||||||
|
except RuntimeError:
|
||||||
|
loop_running = False
|
||||||
|
|
||||||
|
if loop_running:
|
||||||
|
# If we try to submit this coroutine to the running loop
|
||||||
|
# we end up in a deadlock, as we'd have gotten here from a
|
||||||
|
# running coroutine, which we cannot interrupt to run this one.
|
||||||
|
# The solution is to create a new loop in a new thread.
|
||||||
|
with ThreadPoolExecutor(1) as executor:
|
||||||
|
executor.submit(_run_coros, coros).result()
|
||||||
|
else:
|
||||||
|
_run_coros(coros)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||||
|
if hasattr(asyncio, "Runner"):
|
||||||
|
# Python 3.11+
|
||||||
|
# Run the coroutines in a new event loop, taking care to
|
||||||
|
# - install signal handlers
|
||||||
|
# - run pending tasks scheduled by `coros`
|
||||||
|
# - close asyncgens and executors
|
||||||
|
# - close the loop
|
||||||
|
with asyncio.Runner() as runner:
|
||||||
|
# Run the coroutine, get the result
|
||||||
|
for coro in coros:
|
||||||
|
runner.run(coro)
|
||||||
|
|
||||||
|
# Run pending tasks scheduled by coros until they are all done
|
||||||
|
while pending := asyncio.all_tasks(runner.get_loop()):
|
||||||
|
runner.run(asyncio.wait(pending))
|
||||||
|
else:
|
||||||
|
# Before Python 3.11 we need to run each coroutine in a new event loop
|
||||||
|
# as the Runner api is not available.
|
||||||
|
for coro in coros:
|
||||||
|
asyncio.run(coro)
|
||||||
|
|
||||||
|
|
||||||
async def _ahandle_event_for_handler(
|
async def _ahandle_event_for_handler(
|
||||||
|
@ -92,6 +92,27 @@ def test_callback_manager() -> None:
|
|||||||
_test_callback_manager(manager, handler1, handler2)
|
_test_callback_manager(manager, handler1, handler2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_callback_manager_with_async() -> None:
|
||||||
|
"""Test the CallbackManager."""
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
handler3 = FakeAsyncCallbackHandler()
|
||||||
|
handler4 = FakeAsyncCallbackHandler()
|
||||||
|
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
|
||||||
|
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_manager_with_async_with_running_loop() -> None:
|
||||||
|
"""Test the CallbackManager."""
|
||||||
|
handler1 = FakeCallbackHandler()
|
||||||
|
handler2 = FakeCallbackHandler()
|
||||||
|
handler3 = FakeAsyncCallbackHandler()
|
||||||
|
handler4 = FakeAsyncCallbackHandler()
|
||||||
|
manager = CallbackManager(handlers=[handler1, handler2, handler3, handler4])
|
||||||
|
_test_callback_manager(manager, handler1, handler2, handler3, handler4)
|
||||||
|
|
||||||
|
|
||||||
def test_ignore_llm() -> None:
|
def test_ignore_llm() -> None:
|
||||||
"""Test ignore llm param for callback handlers."""
|
"""Test ignore llm param for callback handlers."""
|
||||||
handler1 = FakeCallbackHandler(ignore_llm_=True)
|
handler1 = FakeCallbackHandler(ignore_llm_=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user