From 287c81db89c1adcce66cc01c5bc711e6eaf11cad Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 18 Sep 2023 08:19:35 -0700 Subject: [PATCH] Catch Base Exception (#10607) Currently the on_*_error isn't called for CancellationError's. This is because in python 3.8, the inheritance changed from Exception to BaseException https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError --- .../langchain/agents/agent_iterator.py | 4 ++-- .../langchain/callbacks/aim_callback.py | 14 ++++------- .../langchain/callbacks/argilla_callback.py | 14 ++++------- .../langchain/callbacks/arize_callback.py | 14 ++++------- .../langchain/callbacks/arthur_callback.py | 14 ++++------- libs/langchain/langchain/callbacks/base.py | 16 ++++++------- .../langchain/callbacks/clearml_callback.py | 14 ++++------- .../langchain/callbacks/comet_ml_callback.py | 14 ++++------- .../langchain/callbacks/confident_callback.py | 12 +++------- .../langchain/callbacks/flyte_callback.py | 14 ++++------- .../langchain/callbacks/infino_callback.py | 14 ++++------- .../callbacks/labelstudio_callback.py | 12 +++------- .../langchain/callbacks/llmonitor_callback.py | 6 ++--- libs/langchain/langchain/callbacks/manager.py | 12 +++++----- .../langchain/callbacks/mlflow_callback.py | 12 +++------- .../langchain/callbacks/sagemaker_callback.py | 14 ++++------- libs/langchain/langchain/callbacks/stdout.py | 14 ++++------- .../langchain/callbacks/streaming_aiter.py | 4 +--- .../langchain/callbacks/streaming_stdout.py | 14 ++++------- .../streamlit/streamlit_callback_handler.py | 22 +++++------------ .../langchain/callbacks/tracers/base.py | 8 +++---- .../langchain/callbacks/wandb_callback.py | 12 +++------- libs/langchain/langchain/chains/base.py | 4 ++-- libs/langchain/langchain/chains/llm.py | 4 ++-- libs/langchain/langchain/chat_models/base.py | 8 +++---- libs/langchain/langchain/llms/base.py | 8 +++---- .../langchain/schema/runnable/base.py | 24 +++++++++---------- 27 files changed, 110 insertions(+), 212 deletions(-) diff --git a/libs/langchain/langchain/agents/agent_iterator.py b/libs/langchain/langchain/agents/agent_iterator.py index 36b11f7b05..b7b706fea3 100644 --- a/libs/langchain/langchain/agents/agent_iterator.py +++ b/libs/langchain/langchain/agents/agent_iterator.py @@ -282,7 +282,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator): return self._call_next() except StopIteration: raise - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: if self.run_manager: self.run_manager.on_chain_error(e) raise @@ -304,7 +304,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator): await self.timeout_manager.__aexit__(None, None, None) self.timeout_manager = None return await self._astop() - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: if self.run_manager: assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun) await self.run_manager.on_chain_error(e) diff --git a/libs/langchain/langchain/callbacks/aim_callback.py b/libs/langchain/langchain/callbacks/aim_callback.py index 9941f92989..9526f34f14 100644 --- a/libs/langchain/langchain/callbacks/aim_callback.py +++ b/libs/langchain/langchain/callbacks/aim_callback.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -255,9 +255,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.step += 1 self.llm_streams += 1 - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 @@ -296,9 +294,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): aim.Text(outputs_res["output"]), name="on_chain_end", context=resp ) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 @@ -329,9 +325,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self._run.track(aim.Text(output), name="on_tool_end", context=resp) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 diff --git a/libs/langchain/langchain/callbacks/argilla_callback.py b/libs/langchain/langchain/callbacks/argilla_callback.py index 84a1596386..4763a7f11e 100644 --- a/libs/langchain/langchain/callbacks/argilla_callback.py +++ b/libs/langchain/langchain/callbacks/argilla_callback.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from packaging.version import parse @@ -236,9 +236,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): # Push the records to Argilla self.dataset.push_to_argilla() - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM outputs an error.""" pass @@ -313,9 +311,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): # Push the records to Argilla self.dataset.push_to_argilla() - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM chain outputs an error.""" pass @@ -342,9 +338,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler): """Do nothing when tool ends.""" pass - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when tool outputs an error.""" pass diff --git a/libs/langchain/langchain/callbacks/arize_callback.py b/libs/langchain/langchain/callbacks/arize_callback.py index 62f952588a..a57de3a905 100644 --- a/libs/langchain/langchain/callbacks/arize_callback.py +++ b/libs/langchain/langchain/callbacks/arize_callback.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import import_pandas @@ -163,9 +163,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): else: print(f'❌ Logging failed "{response_from_arize.text}"') - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing.""" pass @@ -178,9 +176,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing.""" pass @@ -205,9 +201,7 @@ class ArizeCallbackHandler(BaseCallbackHandler): ) -> None: pass - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: pass def on_text(self, text: str, **kwargs: Any) -> None: diff --git a/libs/langchain/langchain/callbacks/arthur_callback.py b/libs/langchain/langchain/callbacks/arthur_callback.py index c7c8c2317a..5584175b7b 100644 --- a/libs/langchain/langchain/callbacks/arthur_callback.py +++ b/libs/langchain/langchain/callbacks/arthur_callback.py @@ -6,7 +6,7 @@ import uuid from collections import defaultdict from datetime import datetime from time import time -from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional import numpy as np @@ -257,17 +257,13 @@ class ArthurCallbackHandler(BaseCallbackHandler): def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """On chain end, do nothing.""" - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM outputs an error.""" def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """On new token, pass.""" - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM chain outputs an error.""" def on_tool_start( @@ -290,9 +286,7 @@ class ArthurCallbackHandler(BaseCallbackHandler): ) -> None: """Do nothing when tool ends.""" - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when tool outputs an error.""" def on_text(self, text: str, **kwargs: Any) -> None: diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index d6155536b0..519379fc50 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -18,7 +18,7 @@ class RetrieverManagerMixin: def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -69,7 +69,7 @@ class LLMManagerMixin: def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -93,7 +93,7 @@ class ChainManagerMixin: def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -137,7 +137,7 @@ class ToolManagerMixin: def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -344,7 +344,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -379,7 +379,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -414,7 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, @@ -492,7 +492,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): async def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, diff --git a/libs/langchain/langchain/callbacks/clearml_callback.py b/libs/langchain/langchain/callbacks/clearml_callback.py index 2b32428eee..4d6610bf81 100644 --- a/libs/langchain/langchain/callbacks/clearml_callback.py +++ b/libs/langchain/langchain/callbacks/clearml_callback.py @@ -1,7 +1,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( @@ -155,9 +155,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.logger.report_text(generation_resp) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 @@ -210,9 +208,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.logger.report_text(resp) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 @@ -250,9 +246,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.logger.report_text(resp) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 diff --git a/libs/langchain/langchain/callbacks/comet_ml_callback.py b/libs/langchain/langchain/callbacks/comet_ml_callback.py index cb593d00b3..1e8aabb1c1 100644 --- a/libs/langchain/langchain/callbacks/comet_ml_callback.py +++ b/libs/langchain/langchain/callbacks/comet_ml_callback.py @@ -1,7 +1,7 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence import langchain from langchain.callbacks.base import BaseCallbackHandler @@ -223,9 +223,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self._log_text_metrics(output_complexity_metrics, step=self.step) self._log_text_metrics(output_custom_metrics, step=self.step) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 @@ -280,9 +278,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): f"Output Value for {chain_output_key} will not be logged" ) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 @@ -320,9 +316,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): resp.update({"output": output}) self.action_records.append(resp) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 diff --git a/libs/langchain/langchain/callbacks/confident_callback.py b/libs/langchain/langchain/callbacks/confident_callback.py index d65ad8a0a2..9d8f494c93 100644 --- a/libs/langchain/langchain/callbacks/confident_callback.py +++ b/libs/langchain/langchain/callbacks/confident_callback.py @@ -128,9 +128,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler): callbacks.""" ) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM outputs an error.""" pass @@ -144,9 +142,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler): """Do nothing when chain ends.""" pass - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM chain outputs an error.""" pass @@ -173,9 +169,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler): """Do nothing when tool ends.""" pass - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when tool outputs an error.""" pass diff --git a/libs/langchain/langchain/callbacks/flyte_callback.py b/libs/langchain/langchain/callbacks/flyte_callback.py index 7e22d58e98..ad696b8139 100644 --- a/libs/langchain/langchain/callbacks/flyte_callback.py +++ b/libs/langchain/langchain/callbacks/flyte_callback.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Tuple from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( @@ -221,9 +221,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): ) self.deck.append(self.markdown_renderer().to_html(generation.text)) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 @@ -266,9 +264,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" ) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 @@ -306,9 +302,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n" ) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 diff --git a/libs/langchain/langchain/callbacks/infino_callback.py b/libs/langchain/langchain/callbacks/infino_callback.py index 1eb75d087d..fa1cfdc7a9 100644 --- a/libs/langchain/langchain/callbacks/infino_callback.py +++ b/libs/langchain/langchain/callbacks/infino_callback.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -113,9 +113,7 @@ class InfinoCallbackHandler(BaseCallbackHandler): for generation in generations: self._send_to_infino("prompt_response", generation.text, is_ts=False) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Set the error flag.""" self.error = 1 @@ -129,9 +127,7 @@ class InfinoCallbackHandler(BaseCallbackHandler): """Do nothing when LLM chain ends.""" pass - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Need to log the error.""" pass @@ -158,9 +154,7 @@ class InfinoCallbackHandler(BaseCallbackHandler): """Do nothing when tool ends.""" pass - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when tool outputs an error.""" pass diff --git a/libs/langchain/langchain/callbacks/labelstudio_callback.py b/libs/langchain/langchain/callbacks/labelstudio_callback.py index 50e468aa57..d74dcd52bd 100644 --- a/libs/langchain/langchain/callbacks/labelstudio_callback.py +++ b/libs/langchain/langchain/callbacks/labelstudio_callback.py @@ -334,9 +334,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler): # Pop current run from `self.runs` self.payload.pop(run_id) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM outputs an error.""" pass @@ -348,9 +346,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler): def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: pass - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when LLM chain outputs an error.""" pass @@ -377,9 +373,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler): """Do nothing when tool ends.""" pass - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing when tool outputs an error.""" pass diff --git a/libs/langchain/langchain/callbacks/llmonitor_callback.py b/libs/langchain/langchain/callbacks/llmonitor_callback.py index 9aadc9df71..6e9e4e5328 100644 --- a/libs/langchain/langchain/callbacks/llmonitor_callback.py +++ b/libs/langchain/langchain/callbacks/llmonitor_callback.py @@ -406,7 +406,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Union[UUID, None] = None, @@ -423,7 +423,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Union[UUID, None] = None, @@ -440,7 +440,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler): def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, parent_run_id: Union[UUID, None] = None, diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index eb4e0ebdec..d0e984dee3 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -707,7 +707,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when LLM errors. @@ -773,7 +773,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): async def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when LLM errors. @@ -985,7 +985,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when tool errors. @@ -1027,7 +1027,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): async def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when tool errors. @@ -1069,7 +1069,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when retriever errors.""" @@ -1108,7 +1108,7 @@ class AsyncCallbackManagerForRetrieverRun( async def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, **kwargs: Any, ) -> None: """Run when retriever errors.""" diff --git a/libs/langchain/langchain/callbacks/mlflow_callback.py b/libs/langchain/langchain/callbacks/mlflow_callback.py index c51db69bf0..85ffcfdb06 100644 --- a/libs/langchain/langchain/callbacks/mlflow_callback.py +++ b/libs/langchain/langchain/callbacks/mlflow_callback.py @@ -384,9 +384,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text)) self.mlflg.html(entities, "ent-" + hash_string(generation.text)) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 @@ -434,9 +432,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.records["action_records"].append(resp) self.mlflg.jsonf(resp, f"chain_end_{chain_ends}") - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 @@ -480,9 +476,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): self.records["action_records"].append(resp) self.mlflg.jsonf(resp, f"tool_end_{tool_ends}") - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 diff --git a/libs/langchain/langchain/callbacks/sagemaker_callback.py b/libs/langchain/langchain/callbacks/sagemaker_callback.py index c97461c330..9b532a04c1 100644 --- a/libs/langchain/langchain/callbacks/sagemaker_callback.py +++ b/libs/langchain/langchain/callbacks/sagemaker_callback.py @@ -3,7 +3,7 @@ import os import shutil import tempfile from copy import deepcopy -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.utils import ( @@ -121,9 +121,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler): f"llm_end_{llm_ends}_generation_{idx}", ) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 @@ -164,9 +162,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler): self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}") - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 @@ -202,9 +198,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler): self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}") - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.metrics["step"] += 1 self.metrics["errors"] += 1 diff --git a/libs/langchain/langchain/callbacks/stdout.py b/libs/langchain/langchain/callbacks/stdout.py index 56e0e7d904..a9738c9bf9 100644 --- a/libs/langchain/langchain/callbacks/stdout.py +++ b/libs/langchain/langchain/callbacks/stdout.py @@ -1,5 +1,5 @@ """Callback Handler that prints to std out.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -27,9 +27,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Do nothing.""" pass - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing.""" pass @@ -44,9 +42,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): """Print out that we finished a chain.""" print("\n\033[1m> Finished chain.\033[0m") - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing.""" pass @@ -80,9 +76,7 @@ class StdOutCallbackHandler(BaseCallbackHandler): if llm_prefix is not None: print_text(f"\n{llm_prefix}") - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Do nothing.""" pass diff --git a/libs/langchain/langchain/callbacks/streaming_aiter.py b/libs/langchain/langchain/callbacks/streaming_aiter.py index 17e962ac87..e49bea629d 100644 --- a/libs/langchain/langchain/callbacks/streaming_aiter.py +++ b/libs/langchain/langchain/callbacks/streaming_aiter.py @@ -37,9 +37,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler): async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: self.done.set() - async def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: self.done.set() # TODO implement the other methods diff --git a/libs/langchain/langchain/callbacks/streaming_stdout.py b/libs/langchain/langchain/callbacks/streaming_stdout.py index 2c71bc769c..fefba70683 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout.py @@ -1,6 +1,6 @@ """Callback Handler streams to stdout on new llm token.""" import sys -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult @@ -31,9 +31,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: """Run when LLM ends running.""" - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" def on_chain_start( @@ -44,9 +42,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running.""" - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" def on_tool_start( @@ -61,9 +57,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running.""" - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" def on_text(self, text: str, **kwargs: Any) -> None: diff --git a/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py b/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py index 87cc4e058f..d43af17a35 100644 --- a/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py +++ b/libs/langchain/langchain/callbacks/streamlit/streamlit_callback_handler.py @@ -3,7 +3,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.streamlit.mutable_expander import MutableExpander @@ -163,9 +163,7 @@ class LLMThought: # data is redundant self._reset_llm_token_stream() - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: self._container.markdown("**LLM encountered an error...**") self._container.exception(error) @@ -191,9 +189,7 @@ class LLMThought: ) -> None: self._container.markdown(f"**{output}**") - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: self._container.markdown("**Tool encountered an error...**") self._container.exception(error) @@ -353,9 +349,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler): self._require_current_thought().on_llm_end(response, **kwargs) self._prune_old_thought_containers() - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: self._require_current_thought().on_llm_error(error, **kwargs) self._prune_old_thought_containers() @@ -378,9 +372,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler): ) self._complete_current_thought() - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: self._require_current_thought().on_tool_error(error, **kwargs) self._prune_old_thought_containers() @@ -401,9 +393,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler): def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: pass - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: pass def on_agent_action( diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index bee30a515f..a527798a91 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -211,7 +211,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_llm_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, **kwargs: Any, @@ -294,7 +294,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_chain_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, inputs: Optional[Dict[str, Any]] = None, run_id: UUID, @@ -365,7 +365,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_tool_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, **kwargs: Any, @@ -420,7 +420,7 @@ class BaseTracer(BaseCallbackHandler, ABC): def on_retriever_error( self, - error: Union[Exception, KeyboardInterrupt], + error: BaseException, *, run_id: UUID, **kwargs: Any, diff --git a/libs/langchain/langchain/callbacks/wandb_callback.py b/libs/langchain/langchain/callbacks/wandb_callback.py index d49805e861..18102b34bd 100644 --- a/libs/langchain/langchain/callbacks/wandb_callback.py +++ b/libs/langchain/langchain/callbacks/wandb_callback.py @@ -282,9 +282,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.run.log(generation_resp) - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_llm_error(self, error: BaseException, **kwargs: Any) -> None: """Run when LLM errors.""" self.step += 1 self.errors += 1 @@ -337,9 +335,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.run.log(resp) - def on_chain_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_chain_error(self, error: BaseException, **kwargs: Any) -> None: """Run when chain errors.""" self.step += 1 self.errors += 1 @@ -377,9 +373,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler): if self.stream_logs: self.run.log(resp) - def on_tool_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: + def on_tool_error(self, error: BaseException, **kwargs: Any) -> None: """Run when tool errors.""" self.step += 1 self.errors += 1 diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index 848d0940a6..52ff280755 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -287,7 +287,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): if new_arg_supported else self._call(inputs) ) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_chain_error(e) raise e run_manager.on_chain_end(outputs) @@ -356,7 +356,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC): if new_arg_supported else await self._acall(inputs) ) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_chain_error(e) raise e await run_manager.on_chain_end(outputs) diff --git a/libs/langchain/langchain/chains/llm.py b/libs/langchain/langchain/chains/llm.py index efa16ff4fc..c32858d97e 100644 --- a/libs/langchain/langchain/chains/llm.py +++ b/libs/langchain/langchain/chains/llm.py @@ -186,7 +186,7 @@ class LLMChain(Chain): ) try: response = self.generate(input_list, run_manager=run_manager) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_chain_error(e) raise e outputs = self.create_outputs(response) @@ -206,7 +206,7 @@ class LLMChain(Chain): ) try: response = await self.agenerate(input_list, run_manager=run_manager) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_chain_error(e) raise e outputs = self.create_outputs(response) diff --git a/libs/langchain/langchain/chat_models/base.py b/libs/langchain/langchain/chat_models/base.py index 2d0db37c0a..61bf020823 100644 --- a/libs/langchain/langchain/chat_models/base.py +++ b/libs/langchain/langchain/chat_models/base.py @@ -186,7 +186,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): else: generation += chunk assert generation is not None - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_llm_error(e) raise e else: @@ -233,7 +233,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): else: generation += chunk assert generation is not None - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_llm_error(e) raise e else: @@ -303,7 +303,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): **kwargs, ) ) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: if run_managers: run_managers[i].on_llm_error(e) raise e @@ -364,7 +364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC): ) exceptions = [] for i, res in enumerate(results): - if isinstance(res, Exception): + if isinstance(res, BaseException): if run_managers: await run_managers[i].on_llm_error(res) exceptions.append(res) diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py index 5d6e074b8d..0c9c3158ec 100644 --- a/libs/langchain/langchain/llms/base.py +++ b/libs/langchain/langchain/llms/base.py @@ -388,7 +388,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): else: generation += chunk assert generation is not None - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_llm_error(e) raise e else: @@ -435,7 +435,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): else: generation += chunk assert generation is not None - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_llm_error(e) raise e else: @@ -523,7 +523,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): if new_arg_supported else self._generate(prompts, stop=stop) ) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: for run_manager in run_managers: run_manager.on_llm_error(e) raise e @@ -674,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC): if new_arg_supported else await self._agenerate(prompts, stop=stop) ) - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await asyncio.gather( *[run_manager.on_llm_error(e) for run_manager in run_managers] ) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 49c3abf462..3c879ef8c5 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -319,7 +319,7 @@ class Runnable(Generic[Input, Output], ABC): ) try: output = call_func_with_variable_args(func, input, run_manager, config) - except Exception as e: + except BaseException as e: run_manager.on_chain_error(e) raise else: @@ -354,7 +354,7 @@ class Runnable(Generic[Input, Output], ABC): output = await acall_func_with_variable_args( func, input, run_manager, config ) - except Exception as e: + except BaseException as e: await run_manager.on_chain_error(e) raise else: @@ -408,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC): if accepts_run_manager(func): kwargs["run_manager"] = run_managers output = func(input, **kwargs) # type: ignore[call-arg] - except Exception as e: + except BaseException as e: for run_manager in run_managers: run_manager.on_chain_error(e) if return_exceptions: @@ -481,7 +481,7 @@ class Runnable(Generic[Input, Output], ABC): if accepts_run_manager(func): kwargs["run_manager"] = run_managers output = await func(input, **kwargs) # type: ignore[call-arg] - except Exception as e: + except BaseException as e: await asyncio.gather( *(run_manager.on_chain_error(e) for run_manager in run_managers) ) @@ -573,7 +573,7 @@ class Runnable(Generic[Input, Output], ABC): except TypeError: final_input = None final_input_supported = False - except Exception as e: + except BaseException as e: run_manager.on_chain_error(e, inputs=final_input) raise else: @@ -651,7 +651,7 @@ class Runnable(Generic[Input, Output], ABC): except TypeError: final_input = None final_input_supported = False - except Exception as e: + except BaseException as e: await run_manager.on_chain_error(e, inputs=final_input) raise else: @@ -981,7 +981,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ), ) # finish the root run - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_chain_error(e) raise else: @@ -1013,7 +1013,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ), ) # finish the root run - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_chain_error(e) raise else: @@ -1119,7 +1119,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ) # finish the root runs - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: for rm in run_managers: rm.on_chain_error(e) if return_exceptions: @@ -1242,7 +1242,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]): ], ) # finish the root runs - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers)) if return_exceptions: return cast(List[Output], [e for _ in inputs]) @@ -1450,7 +1450,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ] output = {key: future.result() for key, future in zip(steps, futures)} # finish the root run - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: run_manager.on_chain_error(e) raise else: @@ -1489,7 +1489,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]): ) output = {key: value for key, value in zip(steps, results)} # finish the root run - except (KeyboardInterrupt, Exception) as e: + except BaseException as e: await run_manager.on_chain_error(e) raise else: