mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Catch Base Exception (#10607)
Currently the on_*_error isn't called for CancellationError's. This is because in python 3.8, the inheritance changed from Exception to BaseException https://docs.python.org/3/library/asyncio-exceptions.html#asyncio.CancelledError
This commit is contained in:
parent
39c1c94272
commit
287c81db89
@ -282,7 +282,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
|
|||||||
return self._call_next()
|
return self._call_next()
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise
|
raise
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
if self.run_manager:
|
if self.run_manager:
|
||||||
self.run_manager.on_chain_error(e)
|
self.run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
@ -304,7 +304,7 @@ class AgentExecutorIterator(BaseAgentExecutorIterator):
|
|||||||
await self.timeout_manager.__aexit__(None, None, None)
|
await self.timeout_manager.__aexit__(None, None, None)
|
||||||
self.timeout_manager = None
|
self.timeout_manager = None
|
||||||
return await self._astop()
|
return await self._astop()
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
if self.run_manager:
|
if self.run_manager:
|
||||||
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
assert isinstance(self.run_manager, AsyncCallbackManagerForChainRun)
|
||||||
await self.run_manager.on_chain_error(e)
|
await self.run_manager.on_chain_error(e)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
@ -255,9 +255,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.step += 1
|
self.step += 1
|
||||||
self.llm_streams += 1
|
self.llm_streams += 1
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -296,9 +294,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
|
aim.Text(outputs_res["output"]), name="on_chain_end", context=resp
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -329,9 +325,7 @@ class AimCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
|
|
||||||
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
|
self._run.track(aim.Text(output), name="on_tool_end", context=resp)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from packaging.version import parse
|
from packaging.version import parse
|
||||||
|
|
||||||
@ -236,9 +236,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
# Push the records to Argilla
|
# Push the records to Argilla
|
||||||
self.dataset.push_to_argilla()
|
self.dataset.push_to_argilla()
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM outputs an error."""
|
"""Do nothing when LLM outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -313,9 +311,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
# Push the records to Argilla
|
# Push the records to Argilla
|
||||||
self.dataset.push_to_argilla()
|
self.dataset.push_to_argilla()
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM chain outputs an error."""
|
"""Do nothing when LLM chain outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -342,9 +338,7 @@ class ArgillaCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when tool ends."""
|
"""Do nothing when tool ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when tool outputs an error."""
|
"""Do nothing when tool outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.utils import import_pandas
|
from langchain.callbacks.utils import import_pandas
|
||||||
@ -163,9 +163,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
|||||||
else:
|
else:
|
||||||
print(f'❌ Logging failed "{response_from_arize.text}"')
|
print(f'❌ Logging failed "{response_from_arize.text}"')
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -178,9 +176,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -205,9 +201,7 @@ class ArizeCallbackHandler(BaseCallbackHandler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
@ -6,7 +6,7 @@ import uuid
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import time
|
from time import time
|
||||||
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -257,17 +257,13 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""On chain end, do nothing."""
|
"""On chain end, do nothing."""
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM outputs an error."""
|
"""Do nothing when LLM outputs an error."""
|
||||||
|
|
||||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
"""On new token, pass."""
|
"""On new token, pass."""
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM chain outputs an error."""
|
"""Do nothing when LLM chain outputs an error."""
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
@ -290,9 +286,7 @@ class ArthurCallbackHandler(BaseCallbackHandler):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Do nothing when tool ends."""
|
"""Do nothing when tool ends."""
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when tool outputs an error."""
|
"""Do nothing when tool outputs an error."""
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
@ -18,7 +18,7 @@ class RetrieverManagerMixin:
|
|||||||
|
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -69,7 +69,7 @@ class LLMManagerMixin:
|
|||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -93,7 +93,7 @@ class ChainManagerMixin:
|
|||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -137,7 +137,7 @@ class ToolManagerMixin:
|
|||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -344,7 +344,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_llm_error(
|
async def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -379,7 +379,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_chain_error(
|
async def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -414,7 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_tool_error(
|
async def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
@ -492,7 +492,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
async def on_retriever_error(
|
async def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.utils import (
|
from langchain.callbacks.utils import (
|
||||||
@ -155,9 +155,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.logger.report_text(generation_resp)
|
self.logger.report_text(generation_resp)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -210,9 +208,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.logger.report_text(resp)
|
self.logger.report_text(resp)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -250,9 +246,7 @@ class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.logger.report_text(resp)
|
self.logger.report_text(resp)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
from typing import Any, Callable, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
import langchain
|
import langchain
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
@ -223,9 +223,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self._log_text_metrics(output_complexity_metrics, step=self.step)
|
self._log_text_metrics(output_complexity_metrics, step=self.step)
|
||||||
self._log_text_metrics(output_custom_metrics, step=self.step)
|
self._log_text_metrics(output_custom_metrics, step=self.step)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -280,9 +278,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
f"Output Value for {chain_output_key} will not be logged"
|
f"Output Value for {chain_output_key} will not be logged"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -320,9 +316,7 @@ class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
resp.update({"output": output})
|
resp.update({"output": output})
|
||||||
self.action_records.append(resp)
|
self.action_records.append(resp)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
@ -128,9 +128,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
|||||||
callbacks."""
|
callbacks."""
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM outputs an error."""
|
"""Do nothing when LLM outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -144,9 +142,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when chain ends."""
|
"""Do nothing when chain ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM chain outputs an error."""
|
"""Do nothing when LLM chain outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -173,9 +169,7 @@ class DeepEvalCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when tool ends."""
|
"""Do nothing when tool ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when tool outputs an error."""
|
"""Do nothing when tool outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.utils import (
|
from langchain.callbacks.utils import (
|
||||||
@ -221,9 +221,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
)
|
)
|
||||||
self.deck.append(self.markdown_renderer().to_html(generation.text))
|
self.deck.append(self.markdown_renderer().to_html(generation.text))
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -266,9 +264,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -306,9 +302,7 @@ class FlyteCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
self.table_renderer().to_html(self.pandas.DataFrame([resp])) + "\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
@ -113,9 +113,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
|||||||
for generation in generations:
|
for generation in generations:
|
||||||
self._send_to_infino("prompt_response", generation.text, is_ts=False)
|
self._send_to_infino("prompt_response", generation.text, is_ts=False)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Set the error flag."""
|
"""Set the error flag."""
|
||||||
self.error = 1
|
self.error = 1
|
||||||
|
|
||||||
@ -129,9 +127,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when LLM chain ends."""
|
"""Do nothing when LLM chain ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Need to log the error."""
|
"""Need to log the error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -158,9 +154,7 @@ class InfinoCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when tool ends."""
|
"""Do nothing when tool ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when tool outputs an error."""
|
"""Do nothing when tool outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -334,9 +334,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
|||||||
# Pop current run from `self.runs`
|
# Pop current run from `self.runs`
|
||||||
self.payload.pop(run_id)
|
self.payload.pop(run_id)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM outputs an error."""
|
"""Do nothing when LLM outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -348,9 +346,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when LLM chain outputs an error."""
|
"""Do nothing when LLM chain outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -377,9 +373,7 @@ class LabelStudioCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing when tool ends."""
|
"""Do nothing when tool ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing when tool outputs an error."""
|
"""Do nothing when tool outputs an error."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -406,7 +406,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Union[UUID, None] = None,
|
parent_run_id: Union[UUID, None] = None,
|
||||||
@ -423,7 +423,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Union[UUID, None] = None,
|
parent_run_id: Union[UUID, None] = None,
|
||||||
@ -440,7 +440,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Union[UUID, None] = None,
|
parent_run_id: Union[UUID, None] = None,
|
||||||
|
@ -707,7 +707,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM errors.
|
"""Run when LLM errors.
|
||||||
@ -773,7 +773,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
|
|
||||||
async def on_llm_error(
|
async def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM errors.
|
"""Run when LLM errors.
|
||||||
@ -985,7 +985,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
|||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool errors.
|
"""Run when tool errors.
|
||||||
@ -1027,7 +1027,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
|||||||
|
|
||||||
async def on_tool_error(
|
async def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool errors.
|
"""Run when tool errors.
|
||||||
@ -1069,7 +1069,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
|
|||||||
|
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when retriever errors."""
|
"""Run when retriever errors."""
|
||||||
@ -1108,7 +1108,7 @@ class AsyncCallbackManagerForRetrieverRun(
|
|||||||
|
|
||||||
async def on_retriever_error(
|
async def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when retriever errors."""
|
"""Run when retriever errors."""
|
||||||
|
@ -384,9 +384,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
|
self.mlflg.html(dependency_tree, "dep-" + hash_string(generation.text))
|
||||||
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
self.mlflg.html(entities, "ent-" + hash_string(generation.text))
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
@ -434,9 +432,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.records["action_records"].append(resp)
|
self.records["action_records"].append(resp)
|
||||||
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
|
self.mlflg.jsonf(resp, f"chain_end_{chain_ends}")
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
@ -480,9 +476,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
self.records["action_records"].append(resp)
|
self.records["action_records"].append(resp)
|
||||||
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
|
self.mlflg.jsonf(resp, f"tool_end_{tool_ends}")
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.utils import (
|
from langchain.callbacks.utils import (
|
||||||
@ -121,9 +121,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
|||||||
f"llm_end_{llm_ends}_generation_{idx}",
|
f"llm_end_{llm_ends}_generation_{idx}",
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
@ -164,9 +162,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
|
self.jsonf(resp, self.temp_dir, f"chain_end_{chain_ends}")
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
@ -202,9 +198,7 @@ class SageMakerCallbackHandler(BaseCallbackHandler):
|
|||||||
|
|
||||||
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
|
self.jsonf(resp, self.temp_dir, f"tool_end_{tool_ends}")
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.metrics["step"] += 1
|
self.metrics["step"] += 1
|
||||||
self.metrics["errors"] += 1
|
self.metrics["errors"] += 1
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
"""Callback Handler that prints to std out."""
|
"""Callback Handler that prints to std out."""
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
@ -27,9 +27,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -44,9 +42,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
"""Print out that we finished a chain."""
|
"""Print out that we finished a chain."""
|
||||||
print("\n\033[1m> Finished chain.\033[0m")
|
print("\n\033[1m> Finished chain.\033[0m")
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -80,9 +76,7 @@ class StdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
if llm_prefix is not None:
|
if llm_prefix is not None:
|
||||||
print_text(f"\n{llm_prefix}")
|
print_text(f"\n{llm_prefix}")
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Do nothing."""
|
"""Do nothing."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -37,9 +37,7 @@ class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
|||||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
self.done.set()
|
self.done.set()
|
||||||
|
|
||||||
async def on_llm_error(
|
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self.done.set()
|
self.done.set()
|
||||||
|
|
||||||
# TODO implement the other methods
|
# TODO implement the other methods
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Callback Handler streams to stdout on new llm token."""
|
"""Callback Handler streams to stdout on new llm token."""
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
@ -31,9 +31,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||||
"""Run when LLM ends running."""
|
"""Run when LLM ends running."""
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
@ -44,9 +42,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
"""Run when chain ends running."""
|
"""Run when chain ends running."""
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
@ -61,9 +57,7 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
"""Run when tool ends running."""
|
"""Run when tool ends running."""
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
|
|
||||||
def on_text(self, text: str, **kwargs: Any) -> None:
|
def on_text(self, text: str, **kwargs: Any) -> None:
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
|
from langchain.callbacks.streamlit.mutable_expander import MutableExpander
|
||||||
@ -163,9 +163,7 @@ class LLMThought:
|
|||||||
# data is redundant
|
# data is redundant
|
||||||
self._reset_llm_token_stream()
|
self._reset_llm_token_stream()
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self._container.markdown("**LLM encountered an error...**")
|
self._container.markdown("**LLM encountered an error...**")
|
||||||
self._container.exception(error)
|
self._container.exception(error)
|
||||||
|
|
||||||
@ -191,9 +189,7 @@ class LLMThought:
|
|||||||
) -> None:
|
) -> None:
|
||||||
self._container.markdown(f"**{output}**")
|
self._container.markdown(f"**{output}**")
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self._container.markdown("**Tool encountered an error...**")
|
self._container.markdown("**Tool encountered an error...**")
|
||||||
self._container.exception(error)
|
self._container.exception(error)
|
||||||
|
|
||||||
@ -353,9 +349,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
self._require_current_thought().on_llm_end(response, **kwargs)
|
self._require_current_thought().on_llm_end(response, **kwargs)
|
||||||
self._prune_old_thought_containers()
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self._require_current_thought().on_llm_error(error, **kwargs)
|
self._require_current_thought().on_llm_error(error, **kwargs)
|
||||||
self._prune_old_thought_containers()
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
@ -378,9 +372,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
)
|
)
|
||||||
self._complete_current_thought()
|
self._complete_current_thought()
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
self._require_current_thought().on_tool_error(error, **kwargs)
|
self._require_current_thought().on_tool_error(error, **kwargs)
|
||||||
self._prune_old_thought_containers()
|
self._prune_old_thought_containers()
|
||||||
|
|
||||||
@ -401,9 +393,7 @@ class StreamlitCallbackHandler(BaseCallbackHandler):
|
|||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_agent_action(
|
def on_agent_action(
|
||||||
|
@ -211,7 +211,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -294,7 +294,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
inputs: Optional[Dict[str, Any]] = None,
|
inputs: Optional[Dict[str, Any]] = None,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
@ -365,7 +365,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -420,7 +420,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
|
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
error: Union[Exception, KeyboardInterrupt],
|
error: BaseException,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
|
@ -282,9 +282,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.run.log(generation_resp)
|
self.run.log(generation_resp)
|
||||||
|
|
||||||
def on_llm_error(
|
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when LLM errors."""
|
"""Run when LLM errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -337,9 +335,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.run.log(resp)
|
self.run.log(resp)
|
||||||
|
|
||||||
def on_chain_error(
|
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when chain errors."""
|
"""Run when chain errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
@ -377,9 +373,7 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
if self.stream_logs:
|
if self.stream_logs:
|
||||||
self.run.log(resp)
|
self.run.log(resp)
|
||||||
|
|
||||||
def on_tool_error(
|
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
|
||||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
|
||||||
) -> None:
|
|
||||||
"""Run when tool errors."""
|
"""Run when tool errors."""
|
||||||
self.step += 1
|
self.step += 1
|
||||||
self.errors += 1
|
self.errors += 1
|
||||||
|
@ -287,7 +287,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else self._call(inputs)
|
else self._call(inputs)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
run_manager.on_chain_end(outputs)
|
run_manager.on_chain_end(outputs)
|
||||||
@ -356,7 +356,7 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else await self._acall(inputs)
|
else await self._acall(inputs)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
await run_manager.on_chain_end(outputs)
|
await run_manager.on_chain_end(outputs)
|
||||||
|
@ -186,7 +186,7 @@ class LLMChain(Chain):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
response = self.generate(input_list, run_manager=run_manager)
|
response = self.generate(input_list, run_manager=run_manager)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
outputs = self.create_outputs(response)
|
outputs = self.create_outputs(response)
|
||||||
@ -206,7 +206,7 @@ class LLMChain(Chain):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
response = await self.agenerate(input_list, run_manager=run_manager)
|
response = await self.agenerate(input_list, run_manager=run_manager)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
outputs = self.create_outputs(response)
|
outputs = self.create_outputs(response)
|
||||||
|
@ -186,7 +186,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_llm_error(e)
|
run_manager.on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
@ -233,7 +233,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_llm_error(e)
|
await run_manager.on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
@ -303,7 +303,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
if run_managers:
|
if run_managers:
|
||||||
run_managers[i].on_llm_error(e)
|
run_managers[i].on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
@ -364,7 +364,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
|
|||||||
)
|
)
|
||||||
exceptions = []
|
exceptions = []
|
||||||
for i, res in enumerate(results):
|
for i, res in enumerate(results):
|
||||||
if isinstance(res, Exception):
|
if isinstance(res, BaseException):
|
||||||
if run_managers:
|
if run_managers:
|
||||||
await run_managers[i].on_llm_error(res)
|
await run_managers[i].on_llm_error(res)
|
||||||
exceptions.append(res)
|
exceptions.append(res)
|
||||||
|
@ -388,7 +388,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_llm_error(e)
|
run_manager.on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
@ -435,7 +435,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
else:
|
else:
|
||||||
generation += chunk
|
generation += chunk
|
||||||
assert generation is not None
|
assert generation is not None
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_llm_error(e)
|
await run_manager.on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
@ -523,7 +523,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else self._generate(prompts, stop=stop)
|
else self._generate(prompts, stop=stop)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
for run_manager in run_managers:
|
for run_manager in run_managers:
|
||||||
run_manager.on_llm_error(e)
|
run_manager.on_llm_error(e)
|
||||||
raise e
|
raise e
|
||||||
@ -674,7 +674,7 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
if new_arg_supported
|
if new_arg_supported
|
||||||
else await self._agenerate(prompts, stop=stop)
|
else await self._agenerate(prompts, stop=stop)
|
||||||
)
|
)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[run_manager.on_llm_error(e) for run_manager in run_managers]
|
*[run_manager.on_llm_error(e) for run_manager in run_managers]
|
||||||
)
|
)
|
||||||
|
@ -319,7 +319,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = call_func_with_variable_args(func, input, run_manager, config)
|
output = call_func_with_variable_args(func, input, run_manager, config)
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -354,7 +354,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
output = await acall_func_with_variable_args(
|
output = await acall_func_with_variable_args(
|
||||||
func, input, run_manager, config
|
func, input, run_manager, config
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -408,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
if accepts_run_manager(func):
|
if accepts_run_manager(func):
|
||||||
kwargs["run_manager"] = run_managers
|
kwargs["run_manager"] = run_managers
|
||||||
output = func(input, **kwargs) # type: ignore[call-arg]
|
output = func(input, **kwargs) # type: ignore[call-arg]
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
for run_manager in run_managers:
|
for run_manager in run_managers:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
@ -481,7 +481,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
if accepts_run_manager(func):
|
if accepts_run_manager(func):
|
||||||
kwargs["run_manager"] = run_managers
|
kwargs["run_manager"] = run_managers
|
||||||
output = await func(input, **kwargs) # type: ignore[call-arg]
|
output = await func(input, **kwargs) # type: ignore[call-arg]
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
*(run_manager.on_chain_error(e) for run_manager in run_managers)
|
||||||
)
|
)
|
||||||
@ -573,7 +573,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
final_input = None
|
final_input = None
|
||||||
final_input_supported = False
|
final_input_supported = False
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, inputs=final_input)
|
run_manager.on_chain_error(e, inputs=final_input)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -651,7 +651,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
final_input = None
|
final_input = None
|
||||||
final_input_supported = False
|
final_input_supported = False
|
||||||
except Exception as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e, inputs=final_input)
|
await run_manager.on_chain_error(e, inputs=final_input)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -981,7 +981,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -1013,7 +1013,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -1119,7 +1119,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# finish the root runs
|
# finish the root runs
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
for rm in run_managers:
|
for rm in run_managers:
|
||||||
rm.on_chain_error(e)
|
rm.on_chain_error(e)
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
@ -1242,7 +1242,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
# finish the root runs
|
# finish the root runs
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
|
||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
return cast(List[Output], [e for _ in inputs])
|
return cast(List[Output], [e for _ in inputs])
|
||||||
@ -1450,7 +1450,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
]
|
]
|
||||||
output = {key: future.result() for key, future in zip(steps, futures)}
|
output = {key: future.result() for key, future in zip(steps, futures)}
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
@ -1489,7 +1489,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|||||||
)
|
)
|
||||||
output = {key: value for key, value in zip(steps, results)}
|
output = {key: value for key, value in zip(steps, results)}
|
||||||
# finish the root run
|
# finish the root run
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except BaseException as e:
|
||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user