Shield callback methods from cancellation: Fix interrupted runs marked as pending forever (#17010)

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2024-02-05 12:09:47 -08:00 committed by GitHub
parent e7b3290d30
commit f0ffebb944
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -203,6 +203,21 @@ async def atrace_as_chain_group(
await run_manager.on_chain_end({})
Func = TypeVar("Func", bound=Callable)
def shielded(func: Func) -> Func:
"""
Makes so an awaitable method is always shielded from cancellation
"""
@functools.wraps(func)
async def wrapped(*args: Any, **kwargs: Any) -> Any:
return await asyncio.shield(func(*args, **kwargs))
return cast(Func, wrapped)
def handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
@ -293,7 +308,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
with asyncio.Runner() as runner:
# Run the coroutine, get the result
for coro in coros:
runner.run(coro)
try:
runner.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")
# Run pending tasks scheduled by coros until they are all done
while pending := asyncio.all_tasks(runner.get_loop()):
@ -302,7 +320,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
# Before Python 3.11 we need to run each coroutine in a new event loop
# as the Runner api is not available.
for coro in coros:
asyncio.run(coro)
try:
asyncio.run(coro)
except Exception as e:
logger.warning(f"Error in callback coroutine: {repr(e)}")
async def _ahandle_event_for_handler(
@ -682,6 +703,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_llm_new_token(
self,
token: str,
@ -706,6 +728,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
**kwargs,
)
@shielded
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.
@ -723,6 +746,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
**kwargs,
)
@shielded
async def on_llm_error(
self,
error: BaseException,
@ -853,6 +877,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
@ -872,6 +897,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
@shielded
async def on_chain_error(
self,
error: BaseException,
@ -893,6 +919,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
@shielded
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
"""Run when agent action is received.
@ -913,6 +940,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
**kwargs,
)
@shielded
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run when agent finish is received.
@ -1000,6 +1028,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.
@ -1017,6 +1046,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
**kwargs,
)
@shielded
async def on_tool_error(
self,
error: BaseException,
@ -1100,6 +1130,7 @@ class AsyncCallbackManagerForRetrieverRun(
inheritable_metadata=self.inheritable_metadata,
)
@shielded
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
@ -1115,6 +1146,7 @@ class AsyncCallbackManagerForRetrieverRun(
**kwargs,
)
@shielded
async def on_retriever_error(
self,
error: BaseException,