diff --git a/libs/core/langchain_core/callbacks/manager.py b/libs/core/langchain_core/callbacks/manager.py index 56a8e04db0..e216f47983 100644 --- a/libs/core/langchain_core/callbacks/manager.py +++ b/libs/core/langchain_core/callbacks/manager.py @@ -203,6 +203,21 @@ async def atrace_as_chain_group( await run_manager.on_chain_end({}) +Func = TypeVar("Func", bound=Callable) + + +def shielded(func: Func) -> Func: + """ + Makes so an awaitable method is always shielded from cancellation + """ + + @functools.wraps(func) + async def wrapped(*args: Any, **kwargs: Any) -> Any: + return await asyncio.shield(func(*args, **kwargs)) + + return cast(Func, wrapped) + + def handle_event( handlers: List[BaseCallbackHandler], event_name: str, @@ -293,7 +308,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: with asyncio.Runner() as runner: # Run the coroutine, get the result for coro in coros: - runner.run(coro) + try: + runner.run(coro) + except Exception as e: + logger.warning(f"Error in callback coroutine: {repr(e)}") # Run pending tasks scheduled by coros until they are all done while pending := asyncio.all_tasks(runner.get_loop()): @@ -302,7 +320,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None: # Before Python 3.11 we need to run each coroutine in a new event loop # as the Runner api is not available. for coro in coros: - asyncio.run(coro) + try: + asyncio.run(coro) + except Exception as e: + logger.warning(f"Error in callback coroutine: {repr(e)}") async def _ahandle_event_for_handler( @@ -682,6 +703,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): inheritable_metadata=self.inheritable_metadata, ) + @shielded async def on_llm_new_token( self, token: str, @@ -706,6 +728,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): **kwargs, ) + @shielded async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running. @@ -723,6 +746,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): **kwargs, ) + @shielded async def on_llm_error( self, error: BaseException, @@ -853,6 +877,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): inheritable_metadata=self.inheritable_metadata, ) + @shielded async def on_chain_end( self, outputs: Union[Dict[str, Any], Any], **kwargs: Any ) -> None: @@ -872,6 +897,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) + @shielded async def on_chain_error( self, error: BaseException, @@ -893,6 +919,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) + @shielded async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any: """Run when agent action is received. @@ -913,6 +940,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): **kwargs, ) + @shielded async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any: """Run when agent finish is received. @@ -1000,6 +1028,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): inheritable_metadata=self.inheritable_metadata, ) + @shielded async def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running. @@ -1017,6 +1046,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): **kwargs, ) + @shielded async def on_tool_error( self, error: BaseException, @@ -1100,6 +1130,7 @@ class AsyncCallbackManagerForRetrieverRun( inheritable_metadata=self.inheritable_metadata, ) + @shielded async def on_retriever_end( self, documents: Sequence[Document], **kwargs: Any ) -> None: @@ -1115,6 +1146,7 @@ class AsyncCallbackManagerForRetrieverRun( **kwargs, ) + @shielded async def on_retriever_error( self, error: BaseException,