Catch Base Exception (#10607)

Currently the on_*_error isn't called for CancellationError's. This is
because in python 3.8, the inheritance changed from Exception to
BaseException


https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError
pull/10740/head
William FH 1 year ago committed by GitHub
parent 39c1c94272
commit 287c81db89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -282,7 +282,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
return self._call_next() return self._call_next()
except StopIteration: except StopIteration:
raise raise
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
if self.run_manager: if self.run_manager:
self.run_manager.on_chain_error(e) self.run_manager.on_chain_error(e)
raise raise
@ -304,7 +304,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
await self.timeout_manager.__aexit__(None, None, None) await self.timeout_manager.__aexit__(None, None, None)
self.timeout_manager = None self.timeout_manager = None
return await self._astop() return await self._astop()
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
if self.run_manager: if self.run_manager:
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun) assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
await self.run_manager.on_chain_error(e) await self.run_manager.on_chain_error(e)

@ -1,5 +1,5 @@
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -255,9 +255,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.step += 1 self.step += 1
self.llm_streams += 1 self.llm_streams += 1
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -296,9 +294,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
) )
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -329,9 +325,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self._run.track(aim.Text(output), name="on_tool_end", context=resp) self._run.track(aim.Text(output), name="on_tool_end", context=resp)
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1

@ -1,6 +1,6 @@
import os import os
import warnings import warnings
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from packaging.version import parse from packaging.version import parse
@ -236,9 +236,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Push the records to Argilla # Push the records to Argilla
self.dataset.push_to_argilla() self.dataset.push_to_argilla()
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error.""" """Do nothing when LLM outputs an error."""
pass pass
@ -313,9 +311,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Push the records to Argilla # Push the records to Argilla
self.dataset.push_to_argilla() self.dataset.push_to_argilla()
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error.""" """Do nothing when LLM chain outputs an error."""
pass pass
@ -342,9 +338,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends.""" """Do nothing when tool ends."""
pass pass
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error.""" """Do nothing when tool outputs an error."""
pass pass

@ -1,5 +1,5 @@
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas from langchain.callbacks.utils import import_pandas
@ -163,9 +163,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
else: else:
print(f'❌ Logging failed "{response_from_arize.text}"') print(f'❌ Logging failed "{response_from_arize.text}"')
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
pass pass
@ -178,9 +176,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
"""Do nothing.""" """Do nothing."""
pass pass
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
pass pass
@ -205,9 +201,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
pass pass
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass pass
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:

@ -6,7 +6,7 @@ import uuid
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from time import time from time import time
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
import numpy as np import numpy as np
@ -257,17 +257,13 @@ class ArthurCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""On chain end, do nothing.""" """On chain end, do nothing."""
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error.""" """Do nothing when LLM outputs an error."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""On new token, pass.""" """On new token, pass."""
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error.""" """Do nothing when LLM chain outputs an error."""
def on_tool_start( def on_tool_start(
@ -290,9 +286,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
"""Do nothing when tool ends.""" """Do nothing when tool ends."""
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error.""" """Do nothing when tool outputs an error."""
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:

@ -18,7 +18,7 @@ class RetrieverManagerMixin:
def on_retriever_error( def on_retriever_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -69,7 +69,7 @@ class LLMManagerMixin:
def on_llm_error( def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -93,7 +93,7 @@ class ChainManagerMixin:
def on_chain_error( def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -137,7 +137,7 @@ class ToolManagerMixin:
def on_tool_error( def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -344,7 +344,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_error( async def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -379,7 +379,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chain_error( async def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -414,7 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_tool_error( async def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
@ -492,7 +492,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_retriever_error( async def on_retriever_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,

@ -1,7 +1,7 @@
import tempfile import tempfile
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import ( from langchain.callbacks.utils import (
@ -155,9 +155,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.logger.report_text(generation_resp) self.logger.report_text(generation_resp)
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -210,9 +208,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.logger.report_text(resp) self.logger.report_text(resp)
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -250,9 +246,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.logger.report_text(resp) self.logger.report_text(resp)
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1

@ -1,7 +1,7 @@
import tempfile import tempfile
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union from typing import Any, Callable, Dict, List, Optional, Sequence
import langchain import langchain
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
@ -223,9 +223,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self._log_text_metrics(output_complexity_metrics, step=self.step) self._log_text_metrics(output_complexity_metrics, step=self.step)
self._log_text_metrics(output_custom_metrics, step=self.step) self._log_text_metrics(output_custom_metrics, step=self.step)
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -280,9 +278,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
f"Output Value for {chain_output_key} will not be logged" f"Output Value for {chain_output_key} will not be logged"
) )
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -320,9 +316,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
resp.update({"output": output}) resp.update({"output": output})
self.action_records.append(resp) self.action_records.append(resp)
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1

@ -128,9 +128,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
callbacks.""" callbacks."""
) )
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error.""" """Do nothing when LLM outputs an error."""
pass pass
@ -144,9 +142,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Do nothing when chain ends.""" """Do nothing when chain ends."""
pass pass
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error.""" """Do nothing when LLM chain outputs an error."""
pass pass
@ -173,9 +169,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends.""" """Do nothing when tool ends."""
pass pass
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error.""" """Do nothing when tool outputs an error."""
pass pass

@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from copy import deepcopy from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import ( from langchain.callbacks.utils import (
@ -221,9 +221,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
) )
self.deck.append(self.markdown_renderer().to_html(generation.text)) self.deck.append(self.markdown_renderer().to_html(generation.text))
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -266,9 +264,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
) )
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -306,9 +302,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
) )
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1

@ -1,5 +1,5 @@
import time import time
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -113,9 +113,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
for generation in generations: for generation in generations:
self._send_to_infino("prompt_response", generation.text, is_ts=False) self._send_to_infino("prompt_response", generation.text, is_ts=False)
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Set the error flag.""" """Set the error flag."""
self.error = 1 self.error = 1
@ -129,9 +127,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
"""Do nothing when LLM chain ends.""" """Do nothing when LLM chain ends."""
pass pass
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Need to log the error.""" """Need to log the error."""
pass pass
@ -158,9 +154,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends.""" """Do nothing when tool ends."""
pass pass
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error.""" """Do nothing when tool outputs an error."""
pass pass

@ -334,9 +334,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
# Pop current run from `self.runs` # Pop current run from `self.runs`
self.payload.pop(run_id) self.payload.pop(run_id)
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM outputs an error.""" """Do nothing when LLM outputs an error."""
pass pass
@ -348,9 +346,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass pass
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when LLM chain outputs an error.""" """Do nothing when LLM chain outputs an error."""
pass pass
@ -377,9 +373,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends.""" """Do nothing when tool ends."""
pass pass
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing when tool outputs an error.""" """Do nothing when tool outputs an error."""
pass pass

@ -406,7 +406,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_chain_error( def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Union[UUID, None] = None, parent_run_id: Union[UUID, None] = None,
@ -423,7 +423,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_tool_error( def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Union[UUID, None] = None, parent_run_id: Union[UUID, None] = None,
@ -440,7 +440,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_llm_error( def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Union[UUID, None] = None, parent_run_id: Union[UUID, None] = None,

@ -707,7 +707,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_error( def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when LLM errors. """Run when LLM errors.
@ -773,7 +773,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_error( async def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when LLM errors. """Run when LLM errors.
@ -985,7 +985,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
def on_tool_error( def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool errors. """Run when tool errors.
@ -1027,7 +1027,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
async def on_tool_error( async def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool errors. """Run when tool errors.
@ -1069,7 +1069,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
def on_retriever_error( def on_retriever_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when retriever errors.""" """Run when retriever errors."""
@ -1108,7 +1108,7 @@ class AsyncCallbackManagerForRetrieverRun(
async def on_retriever_error( async def on_retriever_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when retriever errors.""" """Run when retriever errors."""

@ -384,9 +384,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text)) self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
self.mlflg.html(entities, "ent-" + hash_string(generation.text)) self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1
@ -434,9 +432,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp) self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}") self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1
@ -480,9 +476,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp) self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}") self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1

@ -3,7 +3,7 @@ import os
import shutil import shutil
import tempfile import tempfile
from copy import deepcopy from copy import deepcopy
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import ( from langchain.callbacks.utils import (
@ -121,9 +121,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
f"llm_end_{llm_ends}_generation_{idx}", f"llm_end_{llm_ends}_generation_{idx}",
) )
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1
@ -164,9 +162,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}") self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1
@ -202,9 +198,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}") self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.metrics["step"] += 1 self.metrics["step"] += 1
self.metrics["errors"] += 1 self.metrics["errors"] += 1

@ -1,5 +1,5 @@
"""Callback Handler that prints to std out.""" """Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -27,9 +27,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Do nothing.""" """Do nothing."""
pass pass
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
pass pass
@ -44,9 +42,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Print out that we finished a chain.""" """Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m") print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
pass pass
@ -80,9 +76,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
if llm_prefix is not None: if llm_prefix is not None:
print_text(f"\n{llm_prefix}") print_text(f"\n{llm_prefix}")
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing.""" """Do nothing."""
pass pass

@ -37,9 +37,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.done.set() self.done.set()
async def on_llm_error( async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self.done.set() self.done.set()
# TODO implement the other methods # TODO implement the other methods

@ -1,6 +1,6 @@
"""Callback Handler streams to stdout on new llm token.""" """Callback Handler streams to stdout on new llm token."""
import sys import sys
from typing import Any, Dict, List, Union from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -31,9 +31,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running.""" """Run when LLM ends running."""
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
def on_chain_start( def on_chain_start(
@ -44,9 +42,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running.""" """Run when chain ends running."""
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
def on_tool_start( def on_tool_start(
@ -61,9 +57,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_tool_end(self, output: str, **kwargs: Any) -> None: def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.""" """Run when tool ends running."""
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None: def on_text(self, text: str, **kwargs: Any) -> None:

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander from langchain.callbacks.streamlit.mutable_expander import MutableExpander
@ -163,9 +163,7 @@ class LLMThought:
# data is redundant # data is redundant
self._reset_llm_token_stream() self._reset_llm_token_stream()
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._container.markdown("**LLM encountered an error...**") self._container.markdown("**LLM encountered an error...**")
self._container.exception(error) self._container.exception(error)
@ -191,9 +189,7 @@ class LLMThought:
) -> None: ) -> None:
self._container.markdown(f"**{output}**") self._container.markdown(f"**{output}**")
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._container.markdown("**Tool encountered an error...**") self._container.markdown("**Tool encountered an error...**")
self._container.exception(error) self._container.exception(error)
@ -353,9 +349,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
self._require_current_thought().on_llm_end(response, **kwargs) self._require_current_thought().on_llm_end(response, **kwargs)
self._prune_old_thought_containers() self._prune_old_thought_containers()
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._require_current_thought().on_llm_error(error, **kwargs) self._require_current_thought().on_llm_error(error, **kwargs)
self._prune_old_thought_containers() self._prune_old_thought_containers()
@ -378,9 +372,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
) )
self._complete_current_thought() self._complete_current_thought()
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
self._require_current_thought().on_tool_error(error, **kwargs) self._require_current_thought().on_tool_error(error, **kwargs)
self._prune_old_thought_containers() self._prune_old_thought_containers()
@ -401,9 +393,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass pass
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
pass pass
def on_agent_action( def on_agent_action(

@ -211,7 +211,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_llm_error( def on_llm_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
@ -294,7 +294,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_chain_error( def on_chain_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
inputs: Optional[Dict[str, Any]] = None, inputs: Optional[Dict[str, Any]] = None,
run_id: UUID, run_id: UUID,
@ -365,7 +365,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_tool_error( def on_tool_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,
@ -420,7 +420,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_retriever_error( def on_retriever_error(
self, self,
error: Union[Exception, KeyboardInterrupt], error: BaseException,
*, *,
run_id: UUID, run_id: UUID,
**kwargs: Any, **kwargs: Any,

@ -282,9 +282,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.run.log(generation_resp) self.run.log(generation_resp)
def on_llm_error( def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when LLM errors.""" """Run when LLM errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -337,9 +335,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.run.log(resp) self.run.log(resp)
def on_chain_error( def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when chain errors.""" """Run when chain errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1
@ -377,9 +373,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs: if self.stream_logs:
self.run.log(resp) self.run.log(resp)
def on_tool_error( def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Run when tool errors.""" """Run when tool errors."""
self.step += 1 self.step += 1
self.errors += 1 self.errors += 1

@ -287,7 +287,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported if new_arg_supported
else self._call(inputs) else self._call(inputs)
) )
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise e raise e
run_manager.on_chain_end(outputs) run_manager.on_chain_end(outputs)
@ -356,7 +356,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported if new_arg_supported
else await self._acall(inputs) else await self._acall(inputs)
) )
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise e raise e
await run_manager.on_chain_end(outputs) await run_manager.on_chain_end(outputs)

@ -186,7 +186,7 @@ class LLMChain(Chain):
) )
try: try:
response = self.generate(input_list, run_manager=run_manager) response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise e raise e
outputs = self.create_outputs(response) outputs = self.create_outputs(response)
@ -206,7 +206,7 @@ class LLMChain(Chain):
) )
try: try:
response = await self.agenerate(input_list, run_manager=run_manager) response = await self.agenerate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise e raise e
outputs = self.create_outputs(response) outputs = self.create_outputs(response)

@ -186,7 +186,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_llm_error(e) run_manager.on_llm_error(e)
raise e raise e
else: else:
@ -233,7 +233,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_llm_error(e) await run_manager.on_llm_error(e)
raise e raise e
else: else:
@ -303,7 +303,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
**kwargs, **kwargs,
) )
) )
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
if run_managers: if run_managers:
run_managers[i].on_llm_error(e) run_managers[i].on_llm_error(e)
raise e raise e
@ -364,7 +364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
) )
exceptions = [] exceptions = []
for i, res in enumerate(results): for i, res in enumerate(results):
if isinstance(res, Exception): if isinstance(res, BaseException):
if run_managers: if run_managers:
await run_managers[i].on_llm_error(res) await run_managers[i].on_llm_error(res)
exceptions.append(res) exceptions.append(res)

@ -388,7 +388,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_llm_error(e) run_manager.on_llm_error(e)
raise e raise e
else: else:
@ -435,7 +435,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else: else:
generation += chunk generation += chunk
assert generation is not None assert generation is not None
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_llm_error(e) await run_manager.on_llm_error(e)
raise e raise e
else: else:
@ -523,7 +523,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported if new_arg_supported
else self._generate(prompts, stop=stop) else self._generate(prompts, stop=stop)
) )
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
for run_manager in run_managers: for run_manager in run_managers:
run_manager.on_llm_error(e) run_manager.on_llm_error(e)
raise e raise e
@ -674,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported if new_arg_supported
else await self._agenerate(prompts, stop=stop) else await self._agenerate(prompts, stop=stop)
) )
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await asyncio.gather( await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers] *[run_manager.on_llm_error(e) for run_manager in run_managers]
) )

@ -319,7 +319,7 @@ class Runnable(Generic[Input, Output], ABC):
) )
try: try:
output = call_func_with_variable_args(func, input, run_manager, config) output = call_func_with_variable_args(func, input, run_manager, config)
except Exception as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
else: else:
@ -354,7 +354,7 @@ class Runnable(Generic[Input, Output], ABC):
output = await acall_func_with_variable_args( output = await acall_func_with_variable_args(
func, input, run_manager, config func, input, run_manager, config
) )
except Exception as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
else: else:
@ -408,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(func): if accepts_run_manager(func):
kwargs["run_manager"] = run_managers kwargs["run_manager"] = run_managers
output = func(input, **kwargs) # type: ignore[call-arg] output = func(input, **kwargs) # type: ignore[call-arg]
except Exception as e: except BaseException as e:
for run_manager in run_managers: for run_manager in run_managers:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
if return_exceptions: if return_exceptions:
@ -481,7 +481,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(func): if accepts_run_manager(func):
kwargs["run_manager"] = run_managers kwargs["run_manager"] = run_managers
output = await func(input, **kwargs) # type: ignore[call-arg] output = await func(input, **kwargs) # type: ignore[call-arg]
except Exception as e: except BaseException as e:
await asyncio.gather( await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers) *(run_manager.on_chain_error(e) for run_manager in run_managers)
) )
@ -573,7 +573,7 @@ class Runnable(Generic[Input, Output], ABC):
except TypeError: except TypeError:
final_input = None final_input = None
final_input_supported = False final_input_supported = False
except Exception as e: except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input) run_manager.on_chain_error(e, inputs=final_input)
raise raise
else: else:
@ -651,7 +651,7 @@ class Runnable(Generic[Input, Output], ABC):
except TypeError: except TypeError:
final_input = None final_input = None
final_input_supported = False final_input_supported = False
except Exception as e: except BaseException as e:
await run_manager.on_chain_error(e, inputs=final_input) await run_manager.on_chain_error(e, inputs=final_input)
raise raise
else: else:
@ -981,7 +981,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
), ),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
else: else:
@ -1013,7 +1013,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
), ),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
else: else:
@ -1119,7 +1119,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) )
# finish the root runs # finish the root runs
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
for rm in run_managers: for rm in run_managers:
rm.on_chain_error(e) rm.on_chain_error(e)
if return_exceptions: if return_exceptions:
@ -1242,7 +1242,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
], ],
) )
# finish the root runs # finish the root runs
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions: if return_exceptions:
return cast(List[Output], [e for _ in inputs]) return cast(List[Output], [e for _ in inputs])
@ -1450,7 +1450,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
] ]
output = {key: future.result() for key, future in zip(steps, futures)} output = {key: future.result() for key, future in zip(steps, futures)}
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
raise raise
else: else:
@ -1489,7 +1489,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
) )
output = {key: value for key, value in zip(steps, results)} output = {key: value for key, value in zip(steps, results)}
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except BaseException as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
raise raise
else: else:

Loading…
Cancel
Save