@ -203,6 +203,21 @@ async def atrace_as_chain_group(
await run_manager . on_chain_end ( { } )
await run_manager . on_chain_end ( { } )
Func = TypeVar ( " Func " , bound = Callable )
def shielded ( func : Func ) - > Func :
"""
Makes so an awaitable method is always shielded from cancellation
"""
@functools.wraps ( func )
async def wrapped ( * args : Any , * * kwargs : Any ) - > Any :
return await asyncio . shield ( func ( * args , * * kwargs ) )
return cast ( Func , wrapped )
def handle_event (
def handle_event (
handlers : List [ BaseCallbackHandler ] ,
handlers : List [ BaseCallbackHandler ] ,
event_name : str ,
event_name : str ,
@ -293,7 +308,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
with asyncio . Runner ( ) as runner :
with asyncio . Runner ( ) as runner :
# Run the coroutine, get the result
# Run the coroutine, get the result
for coro in coros :
for coro in coros :
runner . run ( coro )
try :
runner . run ( coro )
except Exception as e :
logger . warning ( f " Error in callback coroutine: { repr ( e ) } " )
# Run pending tasks scheduled by coros until they are all done
# Run pending tasks scheduled by coros until they are all done
while pending := asyncio . all_tasks ( runner . get_loop ( ) ) :
while pending := asyncio . all_tasks ( runner . get_loop ( ) ) :
@ -302,7 +320,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
# Before Python 3.11 we need to run each coroutine in a new event loop
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
# as the Runner api is not available.
for coro in coros :
for coro in coros :
asyncio . run ( coro )
try :
asyncio . run ( coro )
except Exception as e :
logger . warning ( f " Error in callback coroutine: { repr ( e ) } " )
async def _ahandle_event_for_handler (
async def _ahandle_event_for_handler (
@ -682,6 +703,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
inheritable_metadata = self . inheritable_metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
)
@shielded
async def on_llm_new_token (
async def on_llm_new_token (
self ,
self ,
token : str ,
token : str ,
@ -706,6 +728,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
async def on_llm_end ( self , response : LLMResult , * * kwargs : Any ) - > None :
""" Run when LLM ends running.
""" Run when LLM ends running.
@ -723,6 +746,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_llm_error (
async def on_llm_error (
self ,
self ,
error : BaseException ,
error : BaseException ,
@ -853,6 +877,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
inheritable_metadata = self . inheritable_metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
)
@shielded
async def on_chain_end (
async def on_chain_end (
self , outputs : Union [ Dict [ str , Any ] , Any ] , * * kwargs : Any
self , outputs : Union [ Dict [ str , Any ] , Any ] , * * kwargs : Any
) - > None :
) - > None :
@ -872,6 +897,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_chain_error (
async def on_chain_error (
self ,
self ,
error : BaseException ,
error : BaseException ,
@ -893,6 +919,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_agent_action ( self , action : AgentAction , * * kwargs : Any ) - > Any :
async def on_agent_action ( self , action : AgentAction , * * kwargs : Any ) - > Any :
""" Run when agent action is received.
""" Run when agent action is received.
@ -913,6 +940,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_agent_finish ( self , finish : AgentFinish , * * kwargs : Any ) - > Any :
async def on_agent_finish ( self , finish : AgentFinish , * * kwargs : Any ) - > Any :
""" Run when agent finish is received.
""" Run when agent finish is received.
@ -1000,6 +1028,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
inheritable_metadata = self . inheritable_metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
)
@shielded
async def on_tool_end ( self , output : str , * * kwargs : Any ) - > None :
async def on_tool_end ( self , output : str , * * kwargs : Any ) - > None :
""" Run when tool ends running.
""" Run when tool ends running.
@ -1017,6 +1046,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_tool_error (
async def on_tool_error (
self ,
self ,
error : BaseException ,
error : BaseException ,
@ -1100,6 +1130,7 @@ class AsyncCallbackManagerForRetrieverRun(
inheritable_metadata = self . inheritable_metadata ,
inheritable_metadata = self . inheritable_metadata ,
)
)
@shielded
async def on_retriever_end (
async def on_retriever_end (
self , documents : Sequence [ Document ] , * * kwargs : Any
self , documents : Sequence [ Document ] , * * kwargs : Any
) - > None :
) - > None :
@ -1115,6 +1146,7 @@ class AsyncCallbackManagerForRetrieverRun(
* * kwargs ,
* * kwargs ,
)
)
@shielded
async def on_retriever_error (
async def on_retriever_error (
self ,
self ,
error : BaseException ,
error : BaseException ,