From 98dbea6310828ec817dd1f89dfa26a7c1c2de579 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 3 Jul 2023 18:39:46 +0100 Subject: [PATCH] Add tags to all callback handler methods (#7073) --- langchain/callbacks/base.py | 13 +++++++++++++ langchain/callbacks/manager.py | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 4d9d4c75c6..a3a9bbe0f7 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -291,6 +291,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on new LLM token. Only available when streaming is enabled.""" @@ -301,6 +302,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when LLM ends running.""" @@ -311,6 +313,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when LLM errors.""" @@ -333,6 +336,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when chain ends running.""" @@ -343,6 +347,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when chain errors.""" @@ -365,6 +370,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when tool ends running.""" @@ -375,6 +381,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when tool errors.""" @@ -385,6 +392,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on arbitrary text.""" @@ -395,6 +403,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on agent action.""" @@ -405,6 +414,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on agent end.""" @@ -415,6 +425,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on retriever start.""" @@ -425,6 +436,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on retriever end.""" @@ -435,6 +447,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run on retriever error.""" diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 00d87af384..42783176e2 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -442,6 +442,7 @@ class RunManager(BaseRunManager): text, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -469,6 +470,7 @@ class AsyncRunManager(BaseRunManager): text, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -493,6 +495,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): token=token, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -509,6 +512,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): response, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -529,6 +533,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -553,6 +558,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): token, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -569,6 +575,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): response, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -589,6 +596,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -626,6 +634,7 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): outputs, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -646,6 +655,7 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -665,6 +675,7 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): action, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -684,6 +695,7 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): finish, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -721,6 +733,7 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): outputs, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -741,6 +754,7 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -760,6 +774,7 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): action, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -779,6 +794,7 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): finish, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -820,6 +836,7 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin): output, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -840,6 +857,7 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -877,6 +895,7 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): output, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -897,6 +916,7 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -926,6 +946,7 @@ class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin): documents, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -942,6 +963,7 @@ class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin): error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -972,6 +994,7 @@ class AsyncCallbackManagerForRetrieverRun( documents, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -988,6 +1011,7 @@ class AsyncCallbackManagerForRetrieverRun( error, run_id=self.run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -1188,6 +1212,7 @@ class CallbackManager(BaseCallbackManager): query, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, ) @@ -1454,6 +1479,7 @@ class AsyncCallbackManager(BaseCallbackManager): query, run_id=run_id, parent_run_id=self.parent_run_id, + tags=self.tags, **kwargs, )