@ -4,13 +4,15 @@ import asyncio
import functools
import logging
import uuid
from abc import ABC , abstractmethod
from concurrent . futures import ThreadPoolExecutor
from contextlib import asynccontextmanager , contextmanager
from contextvars import Context, copy_context
from contextvars import copy_context
from typing import (
TYPE_CHECKING ,
Any ,
AsyncGenerator ,
Callable ,
Coroutine ,
Dict ,
Generator ,
@ -272,25 +274,14 @@ def handle_event(
# we end up in a deadlock, as we'd have gotten here from a
# running coroutine, which we cannot interrupt to run this one.
# The solution is to create a new loop in a new thread.
with _executor_w_context ( 1 ) as executor :
executor . submit ( _run_coros , coros ) . result ( )
with ThreadPoolExecutor ( 1 ) as executor :
executor . submit (
cast ( Callable , copy_context ( ) . run ) , _run_coros , coros
) . result ( )
else :
_run_coros ( coros )
def _set_context ( context : Context ) - > None :
for var , value in context . items ( ) :
var . set ( value )
def _executor_w_context ( max_workers : Optional [ int ] = None ) - > ThreadPoolExecutor :
return ThreadPoolExecutor (
max_workers = max_workers ,
initializer = _set_context ,
initargs = ( copy_context ( ) , ) ,
)
def _run_coros ( coros : List [ Coroutine [ Any , Any , Any ] ] ) - > None :
if hasattr ( asyncio , " Runner " ) :
# Python 3.11+
@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
async def _ahandle_event_for_handler (
executor : ThreadPoolExecutor ,
handler : BaseCallbackHandler ,
event_name : str ,
ignore_condition_name : Optional [ str ] ,
@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
event ( * args , * * kwargs )
else :
await asyncio . get_event_loop ( ) . run_in_executor (
executor , functools . partial ( event , * args , * * kwargs )
None ,
cast (
Callable ,
functools . partial (
copy_context ( ) . run , event , * args , * * kwargs
) ,
) ,
)
except NotImplementedError as e :
if event_name == " on_chat_model_start " :
message_strings = [ get_buffer_string ( m ) for m in args [ 1 ] ]
await _ahandle_event_for_handler (
executor ,
handler ,
" on_llm_start " ,
" ignore_llm " ,
@ -380,25 +375,23 @@ async def ahandle_event(
* args : The arguments to pass to the event handler
* * kwargs : The keyword arguments to pass to the event handler
"""
with _executor_w_context ( ) as executor :
for handler in [ h for h in handlers if h . run_inline ] :
await _ahandle_event_for_handler (
executor , handler , event_name , ignore_condition_name , * args , * * kwargs
)
await asyncio . gather (
* (
_ahandle_event_for_handler (
executor ,
handler ,
event_name ,
ignore_condition_name ,
* args ,
* * kwargs ,
)
for handler in handlers
if not handler . run_inline
for handler in [ h for h in handlers if h . run_inline ] :
await _ahandle_event_for_handler (
handler , event_name , ignore_condition_name , * args , * * kwargs
)
await asyncio . gather (
* (
_ahandle_event_for_handler (
handler ,
event_name ,
ignore_condition_name ,
* args ,
* * kwargs ,
)
for handler in handlers
if not handler . run_inline
)
)
BRM = TypeVar ( " BRM " , bound = " BaseRunManager " )
@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
return manager
class AsyncRunManager ( BaseRunManager ):
class AsyncRunManager ( BaseRunManager , ABC ):
""" Async Run Manager. """
@abstractmethod
def get_sync ( self ) - > RunManager :
""" Get the equivalent sync RunManager.
Returns :
RunManager : The sync RunManager .
"""
async def on_text (
self ,
text : str ,
@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
class AsyncCallbackManagerForLLMRun ( AsyncRunManager , LLMManagerMixin ) :
""" Async callback manager for LLM run. """
def get_sync ( self ) - > CallbackManagerForLLMRun :
""" Get the equivalent sync RunManager.
Returns :
CallbackManagerForLLMRun : The sync RunManager .
"""
return CallbackManagerForLLMRun (
run_id = self . run_id ,
handlers = self . handlers ,
inheritable_handlers = self . inheritable_handlers ,
parent_run_id = self . parent_run_id ,
tags = self . tags ,
inheritable_tags = self . inheritable_tags ,
metadata = self . metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
async def on_llm_new_token (
self ,
token : str ,
@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun ( AsyncParentRunManager , ChainManagerMixin ) :
""" Async callback manager for chain run. """
def get_sync ( self ) - > CallbackManagerForChainRun :
""" Get the equivalent sync RunManager.
Returns :
CallbackManagerForChainRun : The sync RunManager .
"""
return CallbackManagerForChainRun (
run_id = self . run_id ,
handlers = self . handlers ,
inheritable_handlers = self . inheritable_handlers ,
parent_run_id = self . parent_run_id ,
tags = self . tags ,
inheritable_tags = self . inheritable_tags ,
metadata = self . metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
async def on_chain_end (
self , outputs : Union [ Dict [ str , Any ] , Any ] , * * kwargs : Any
) - > None :
@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun ( AsyncParentRunManager , ToolManagerMixin ) :
""" Async callback manager for tool run. """
def get_sync ( self ) - > CallbackManagerForToolRun :
""" Get the equivalent sync RunManager.
Returns :
CallbackManagerForToolRun : The sync RunManager .
"""
return CallbackManagerForToolRun (
run_id = self . run_id ,
handlers = self . handlers ,
inheritable_handlers = self . inheritable_handlers ,
parent_run_id = self . parent_run_id ,
tags = self . tags ,
inheritable_tags = self . inheritable_tags ,
metadata = self . metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
async def on_tool_end ( self , output : str , * * kwargs : Any ) - > None :
""" Run when tool ends running.
@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
) :
""" Async callback manager for retriever run. """
def get_sync ( self ) - > CallbackManagerForRetrieverRun :
""" Get the equivalent sync RunManager.
Returns :
CallbackManagerForRetrieverRun : The sync RunManager .
"""
return CallbackManagerForRetrieverRun (
run_id = self . run_id ,
handlers = self . handlers ,
inheritable_handlers = self . inheritable_handlers ,
parent_run_id = self . parent_run_id ,
tags = self . tags ,
inheritable_tags = self . inheritable_tags ,
metadata = self . metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
async def on_retriever_end (
self , documents : Sequence [ Document ] , * * kwargs : Any
) - > None :