mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Shield callback methods from cancellation: Fix interrupted runs marked as pending forever (#17010)
<!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
e7b3290d30
commit
f0ffebb944
@ -203,6 +203,21 @@ async def atrace_as_chain_group(
|
||||
await run_manager.on_chain_end({})
|
||||
|
||||
|
||||
Func = TypeVar("Func", bound=Callable)
|
||||
|
||||
|
||||
def shielded(func: Func) -> Func:
|
||||
"""
|
||||
Makes so an awaitable method is always shielded from cancellation
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
return await asyncio.shield(func(*args, **kwargs))
|
||||
|
||||
return cast(Func, wrapped)
|
||||
|
||||
|
||||
def handle_event(
|
||||
handlers: List[BaseCallbackHandler],
|
||||
event_name: str,
|
||||
@ -293,7 +308,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
with asyncio.Runner() as runner:
|
||||
# Run the coroutine, get the result
|
||||
for coro in coros:
|
||||
runner.run(coro)
|
||||
try:
|
||||
runner.run(coro)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in callback coroutine: {repr(e)}")
|
||||
|
||||
# Run pending tasks scheduled by coros until they are all done
|
||||
while pending := asyncio.all_tasks(runner.get_loop()):
|
||||
@ -302,7 +320,10 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
||||
# Before Python 3.11 we need to run each coroutine in a new event loop
|
||||
# as the Runner api is not available.
|
||||
for coro in coros:
|
||||
asyncio.run(coro)
|
||||
try:
|
||||
asyncio.run(coro)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error in callback coroutine: {repr(e)}")
|
||||
|
||||
|
||||
async def _ahandle_event_for_handler(
|
||||
@ -682,6 +703,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
@ -706,6 +728,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Run when LLM ends running.
|
||||
|
||||
@ -723,6 +746,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
@ -853,6 +877,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_chain_end(
|
||||
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
||||
) -> None:
|
||||
@ -872,6 +897,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_chain_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
@ -893,6 +919,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
|
||||
"""Run when agent action is received.
|
||||
|
||||
@ -913,6 +940,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run when agent finish is received.
|
||||
|
||||
@ -1000,6 +1028,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||
"""Run when tool ends running.
|
||||
|
||||
@ -1017,6 +1046,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_tool_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
@ -1100,6 +1130,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_retriever_end(
|
||||
self, documents: Sequence[Document], **kwargs: Any
|
||||
) -> None:
|
||||
@ -1115,6 +1146,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@shielded
|
||||
async def on_retriever_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
|
Loading…
Reference in New Issue
Block a user