Catch Base Exception (#10607)

Currently the on_*_error isn't called for CancellationError's. This is
because in python 3.8, the inheritance changed from Exception to
BaseException


https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError
pull/10740/head
William FH 11 months ago committed by GitHub
parent 39c1c94272
commit 287c81db89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -282,7 +282,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
return self._call_next()
except StopIteration:
raise
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
if self.run_manager:
self.run_manager.on_chain_error(e)
raise
@ -304,7 +304,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
await self.timeout_manager.__aexit__(None, None, None)
self.timeout_manager = None
return await self._astop()
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
if self.run_manager:
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
await self.run_manager.on_chain_error(e)

@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -255,9 +255,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.step += 1
self.llm_streams += 1
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
@ -296,9 +294,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
@ -329,9 +325,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1

@ -1,6 +1,6 @@
import os
import warnings
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from packaging.version import parse
@ -236,9 +236,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
@ -313,9 +311,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
# Push the records to Argilla
self.dataset.push_to_argilla()
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
@ -342,9 +338,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass

@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import import_pandas
@ -163,9 +163,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
else:
print(f'❌ Logging failed "{response_from_arize.text}"')
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
@ -178,9 +176,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
@ -205,9 +201,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
) -> None:
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_text(self, text: str, **kwargs: Any) -> None:

@ -6,7 +6,7 @@ import uuid
from collections import defaultdict
from datetime import datetime
from time import time
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
import numpy as np
@ -257,17 +257,13 @@ class ArthurCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""On chain end, do nothing."""
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""On new token, pass."""
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
def on_tool_start(
@ -290,9 +286,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
) -> None:
"""Do nothing when tool ends."""
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
def on_text(self, text: str, **kwargs: Any) -> None:

@ -18,7 +18,7 @@ class RetrieverManagerMixin:
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -69,7 +69,7 @@ class LLMManagerMixin:
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -93,7 +93,7 @@ class ChainManagerMixin:
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -137,7 +137,7 @@ class ToolManagerMixin:
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -344,7 +344,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -379,7 +379,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -414,7 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
@ -492,7 +492,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,

@ -1,7 +1,7 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
@ -155,9 +155,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.logger.report_text(generation_resp)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
@ -210,9 +208,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.logger.report_text(resp)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
@ -250,9 +246,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.logger.report_text(resp)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1

@ -1,7 +1,7 @@
import tempfile
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence
import langchain
from langchain.callbacks.base import BaseCallbackHandler
@ -223,9 +223,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self._log_text_metrics(output_complexity_metrics, step=self.step)
self._log_text_metrics(output_custom_metrics, step=self.step)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
@ -280,9 +278,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
f"Output Value for {chain_output_key} will not be logged"
)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
@ -320,9 +316,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
resp.update({"output": output})
self.action_records.append(resp)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1

@ -128,9 +128,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
callbacks."""
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
@ -144,9 +142,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Do nothing when chain ends."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
@ -173,9 +169,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass

@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
@ -221,9 +221,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
self.deck.append(self.markdown_renderer().to_html(generation.text))
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
@ -266,9 +264,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
@ -306,9 +302,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1

@ -1,5 +1,5 @@
import time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -113,9 +113,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
for generation in generations:
self._send_to_infino("prompt_response", generation.text, is_ts=False)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Set the error flag."""
self.error = 1
@ -129,9 +127,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
"""Do nothing when LLM chain ends."""
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Need to log the error."""
pass
@ -158,9 +154,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass

@ -334,9 +334,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
# Pop current run from `self.runs`
self.payload.pop(run_id)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM outputs an error."""
pass
@ -348,9 +346,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when LLM chain outputs an error."""
pass
@ -377,9 +373,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
"""Do nothing when tool ends."""
pass
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing when tool outputs an error."""
pass

@ -406,7 +406,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
@ -423,7 +423,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,
@ -440,7 +440,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
parent_run_id: Union[UUID, None] = None,

@ -707,7 +707,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when LLM errors.
@ -773,7 +773,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
async def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when LLM errors.
@ -985,7 +985,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when tool errors.
@ -1027,7 +1027,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
async def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when tool errors.
@ -1069,7 +1069,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
@ -1108,7 +1108,7 @@ class AsyncCallbackManagerForRetrieverRun(
async def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when retriever errors."""

@ -384,9 +384,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
@ -434,9 +432,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
@ -480,9 +476,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
self.records["action_records"].append(resp)
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1

@ -3,7 +3,7 @@ import os
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.utils import (
@ -121,9 +121,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
f"llm_end_{llm_ends}_generation_{idx}",
)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
@ -164,9 +162,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1
@ -202,9 +198,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.metrics["step"] += 1
self.metrics["errors"] += 1

@ -1,5 +1,5 @@
"""Callback Handler that prints to std out."""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -27,9 +27,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
@ -44,9 +42,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
"""Print out that we finished a chain."""
print("\n\033[1m> Finished chain.\033[0m")
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass
@ -80,9 +76,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
if llm_prefix is not None:
print_text(f"\n{llm_prefix}")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Do nothing."""
pass

@ -37,9 +37,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
self.done.set()
async def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self.done.set()
# TODO implement the other methods

@ -1,6 +1,6 @@
"""Callback Handler streams to stdout on new llm token."""
import sys
from typing import Any, Dict, List, Union
from typing import Any, Dict, List
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
@ -31,9 +31,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Run when LLM ends running."""
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
def on_chain_start(
@ -44,9 +42,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running."""
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
def on_tool_start(
@ -61,9 +57,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running."""
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
def on_text(self, text: str, **kwargs: Any) -> None:

@ -3,7 +3,7 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
@ -163,9 +163,7 @@ class LLMThought:
# data is redundant
self._reset_llm_token_stream()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**LLM encountered an error...**")
self._container.exception(error)
@ -191,9 +189,7 @@ class LLMThought:
) -> None:
self._container.markdown(f"**{output}**")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._container.markdown("**Tool encountered an error...**")
self._container.exception(error)
@ -353,9 +349,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
self._require_current_thought().on_llm_end(response, **kwargs)
self._prune_old_thought_containers()
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_llm_error(error, **kwargs)
self._prune_old_thought_containers()
@ -378,9 +372,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
)
self._complete_current_thought()
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self._require_current_thought().on_tool_error(error, **kwargs)
self._prune_old_thought_containers()
@ -401,9 +393,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
pass
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
pass
def on_agent_action(

@ -211,7 +211,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_llm_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,
@ -294,7 +294,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_chain_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
@ -365,7 +365,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_tool_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,
@ -420,7 +420,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
def on_retriever_error(
self,
error: Union[Exception, KeyboardInterrupt],
error: BaseException,
*,
run_id: UUID,
**kwargs: Any,

@ -282,9 +282,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.run.log(generation_resp)
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when LLM errors."""
self.step += 1
self.errors += 1
@ -337,9 +335,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.run.log(resp)
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when chain errors."""
self.step += 1
self.errors += 1
@ -377,9 +373,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
if self.stream_logs:
self.run.log(resp)
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
"""Run when tool errors."""
self.step += 1
self.errors += 1

@ -287,7 +287,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported
else self._call(inputs)
)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
@ -356,7 +356,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
if new_arg_supported
else await self._acall(inputs)
)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)

@ -186,7 +186,7 @@ class LLMChain(Chain):
)
try:
response = self.generate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)
@ -206,7 +206,7 @@ class LLMChain(Chain):
)
try:
response = await self.agenerate(input_list, run_manager=run_manager)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_chain_error(e)
raise e
outputs = self.create_outputs(response)

@ -186,7 +186,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
@ -233,7 +233,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_llm_error(e)
raise e
else:
@ -303,7 +303,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
**kwargs,
)
)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
if run_managers:
run_managers[i].on_llm_error(e)
raise e
@ -364,7 +364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, Exception):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(res)
exceptions.append(res)

@ -388,7 +388,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
@ -435,7 +435,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_llm_error(e)
raise e
else:
@ -523,7 +523,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported
else self._generate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
for run_manager in run_managers:
run_manager.on_llm_error(e)
raise e
@ -674,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
)
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await asyncio.gather(
*[run_manager.on_llm_error(e) for run_manager in run_managers]
)

@ -319,7 +319,7 @@ class Runnable(Generic[Input, Output], ABC):
)
try:
output = call_func_with_variable_args(func, input, run_manager, config)
except Exception as e:
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
@ -354,7 +354,7 @@ class Runnable(Generic[Input, Output], ABC):
output = await acall_func_with_variable_args(
func, input, run_manager, config
)
except Exception as e:
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
@ -408,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = func(input, **kwargs) # type: ignore[call-arg]
except Exception as e:
except BaseException as e:
for run_manager in run_managers:
run_manager.on_chain_error(e)
if return_exceptions:
@ -481,7 +481,7 @@ class Runnable(Generic[Input, Output], ABC):
if accepts_run_manager(func):
kwargs["run_manager"] = run_managers
output = await func(input, **kwargs) # type: ignore[call-arg]
except Exception as e:
except BaseException as e:
await asyncio.gather(
*(run_manager.on_chain_error(e) for run_manager in run_managers)
)
@ -573,7 +573,7 @@ class Runnable(Generic[Input, Output], ABC):
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
except BaseException as e:
run_manager.on_chain_error(e, inputs=final_input)
raise
else:
@ -651,7 +651,7 @@ class Runnable(Generic[Input, Output], ABC):
except TypeError:
final_input = None
final_input_supported = False
except Exception as e:
except BaseException as e:
await run_manager.on_chain_error(e, inputs=final_input)
raise
else:
@ -981,7 +981,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
@ -1013,7 +1013,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
@ -1119,7 +1119,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
for rm in run_managers:
rm.on_chain_error(e)
if return_exceptions:
@ -1242,7 +1242,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
],
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
if return_exceptions:
return cast(List[Output], [e for _ in inputs])
@ -1450,7 +1450,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
]
output = {key: future.result() for key, future in zip(steps, futures)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
@ -1489,7 +1489,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
)
output = {key: value for key, value in zip(steps, results)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:

Loading…
Cancel
Save