From 4c1c05c2c7dcbd3f8a023983b4b63f04fbd282c2 Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Wed, 5 Jul 2023 11:11:38 -0700 Subject: [PATCH] support adding custom metadata to runs (#7120) - [x] wire up tools - [x] wire up retrievers - [x] add integration test --- langchain/callbacks/base.py | 24 ++ langchain/callbacks/manager.py | 207 ++++++++++-------- langchain/callbacks/tracers/base.py | 12 + langchain/callbacks/tracers/langchain.py | 3 + langchain/chains/base.py | 54 ++++- langchain/chat_models/base.py | 8 + .../agents/trajectory_eval_chain.py | 3 + langchain/llms/base.py | 42 +++- langchain/schema/retriever.py | 60 ++++- langchain/tools/base.py | 36 ++- .../callbacks/test_langchain_tracer.py | 34 +++ 11 files changed, 369 insertions(+), 114 deletions(-) diff --git a/langchain/callbacks/base.py b/langchain/callbacks/base.py index 039b4fd1e2..930d03fbbe 100644 --- a/langchain/callbacks/base.py +++ b/langchain/callbacks/base.py @@ -147,6 +147,7 @@ class CallbackManagerMixin: run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when LLM starts running.""" @@ -159,6 +160,7 @@ class CallbackManagerMixin: run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running.""" @@ -174,6 +176,7 @@ class CallbackManagerMixin: run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when Retriever starts running.""" @@ -186,6 +189,7 @@ class CallbackManagerMixin: run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when chain starts running.""" @@ -198,6 +202,7 @@ class CallbackManagerMixin: run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when tool starts running.""" @@ -268,6 +273,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when LLM starts running.""" @@ -280,6 +286,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run when a chat model starts running.""" @@ -328,6 +335,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when chain starts running.""" @@ -362,6 +370,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when tool starts running.""" @@ -429,6 +438,7 @@ class AsyncCallbackHandler(BaseCallbackHandler): run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run on retriever start.""" @@ -467,6 +477,8 @@ class BaseCallbackManager(CallbackManagerMixin): *, tags: Optional[List[str]] = None, inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, ) -> None: """Initialize callback manager.""" self.handlers: List[BaseCallbackHandler] = handlers @@ -476,6 +488,8 @@ class BaseCallbackManager(CallbackManagerMixin): self.parent_run_id: Optional[UUID] = parent_run_id self.tags = tags or [] self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} @property def is_async(self) -> bool: @@ -518,3 +532,13 @@ class BaseCallbackManager(CallbackManagerMixin): for tag in tags: self.tags.remove(tag) self.inheritable_tags.remove(tag) + + def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None: + self.metadata.update(metadata) + if inherit: + self.inheritable_metadata.update(metadata) + + def remove_metadata(self, keys: List[str]) -> None: + for key in keys: + self.metadata.pop(key) + self.inheritable_metadata.pop(key) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 3ae9e61240..77a83c4123 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -383,6 +383,8 @@ class BaseRunManager(RunManagerMixin): parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, inheritable_tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the run manager. @@ -395,6 +397,8 @@ class BaseRunManager(RunManagerMixin): Defaults to None. tags (Optional[List[str]]): The list of tags. inheritable_tags (Optional[List[str]]): The list of inheritable tags. + metadata (Optional[Dict[str, Any]]): The metadata. + inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata. """ self.run_id = run_id self.handlers = handlers @@ -402,6 +406,8 @@ class BaseRunManager(RunManagerMixin): self.parent_run_id = parent_run_id self.tags = tags or [] self.inheritable_tags = inheritable_tags or [] + self.metadata = metadata or {} + self.inheritable_metadata = inheritable_metadata or {} @classmethod def get_noop_manager(cls: Type[BRM]) -> BRM: @@ -416,6 +422,8 @@ class BaseRunManager(RunManagerMixin): inheritable_handlers=[], tags=[], inheritable_tags=[], + metadata={}, + inheritable_metadata={}, ) @@ -447,6 +455,28 @@ class RunManager(BaseRunManager): ) +class ParentRunManager(RunManager): + """Sync Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> CallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + CallbackManager: The child callback manager. + """ + manager = CallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + class AsyncRunManager(BaseRunManager): """Async Run Manager.""" @@ -475,6 +505,28 @@ class AsyncRunManager(BaseRunManager): ) +class AsyncParentRunManager(AsyncRunManager): + """Async Parent Run Manager.""" + + def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: + """Get a child callback manager. + + Args: + tag (str, optional): The tag for the child callback manager. + Defaults to None. + + Returns: + AsyncCallbackManager: The child callback manager. + """ + manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) + manager.set_handlers(self.inheritable_handlers) + manager.add_tags(self.inheritable_tags) + manager.add_metadata(self.inheritable_metadata) + if tag is not None: + manager.add_tags([tag], False) + return manager + + class CallbackManagerForLLMRun(RunManager, LLMManagerMixin): """Callback manager for LLM run.""" @@ -601,26 +653,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin): ) -class CallbackManagerForChainRun(RunManager, ChainManagerMixin): +class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin): """Callback manager for chain run.""" - def get_child(self, tag: Optional[str] = None) -> CallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - CallbackManager: The child callback manager. - """ - manager = CallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running. @@ -700,26 +735,9 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin): ) -class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): +class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin): """Async callback manager for chain run.""" - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - AsyncCallbackManager: The child callback manager. - """ - manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: """Run when chain ends running. @@ -799,26 +817,9 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): ) -class CallbackManagerForToolRun(RunManager, ToolManagerMixin): +class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin): """Callback manager for tool run.""" - def get_child(self, tag: Optional[str] = None) -> CallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag for the child callback manager. - Defaults to None. - - Returns: - CallbackManager: The child callback manager. - """ - manager = CallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - def on_tool_end( self, output: str, @@ -862,26 +863,9 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin): ) -class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): +class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin): """Async callback manager for tool run.""" - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: - """Get a child callback manager. - - Args: - tag (str, optional): The tag to add to the child - callback manager. Defaults to None. - - Returns: - AsyncCallbackManager: The child callback manager. - """ - manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - async def on_tool_end(self, output: str, **kwargs: Any) -> None: """Run when tool ends running. @@ -921,18 +905,9 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): ) -class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin): +class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin): """Callback manager for retriever run.""" - def get_child(self, tag: Optional[str] = None) -> CallbackManager: - """Get a child callback manager.""" - manager = CallbackManager([], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - def on_retriever_end( self, documents: Sequence[Document], @@ -969,20 +944,11 @@ class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin): class AsyncCallbackManagerForRetrieverRun( - AsyncRunManager, + AsyncParentRunManager, RetrieverManagerMixin, ): """Async callback manager for retriever run.""" - def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager: - """Get a child callback manager.""" - manager = AsyncCallbackManager([], parent_run_id=self.run_id) - manager.set_handlers(self.inheritable_handlers) - manager.add_tags(self.inheritable_tags) - if tag is not None: - manager.add_tags([tag], False) - return manager - async def on_retriever_end( self, documents: Sequence[Document], **kwargs: Any ) -> None: @@ -1048,6 +1014,7 @@ class CallbackManager(BaseCallbackManager): run_id=run_id_, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1059,6 +1026,8 @@ class CallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) ) @@ -1094,6 +1063,7 @@ class CallbackManager(BaseCallbackManager): run_id=run_id_, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1105,6 +1075,8 @@ class CallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) ) @@ -1139,6 +1111,7 @@ class CallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1149,6 +1122,8 @@ class CallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) def on_tool_start( @@ -1182,6 +1157,7 @@ class CallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1192,6 +1168,8 @@ class CallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) def on_retriever_start( @@ -1215,6 +1193,7 @@ class CallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1225,6 +1204,8 @@ class CallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) @classmethod @@ -1235,6 +1216,8 @@ class CallbackManager(BaseCallbackManager): verbose: bool = False, inheritable_tags: Optional[List[str]] = None, local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, ) -> CallbackManager: """Configure the callback manager. @@ -1248,6 +1231,10 @@ class CallbackManager(BaseCallbackManager): Defaults to None. local_tags (Optional[List[str]], optional): The local tags. Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. Returns: CallbackManager: The configured callback manager. @@ -1259,6 +1246,8 @@ class CallbackManager(BaseCallbackManager): verbose, inheritable_tags, local_tags, + inheritable_metadata, + local_metadata, ) @@ -1305,6 +1294,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id=run_id_, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) ) @@ -1317,6 +1307,8 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) ) @@ -1358,6 +1350,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id=run_id_, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) ) @@ -1370,6 +1363,8 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) ) @@ -1406,6 +1401,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1416,6 +1412,8 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) async def on_tool_start( @@ -1451,6 +1449,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1461,6 +1460,8 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) async def on_retriever_start( @@ -1484,6 +1485,7 @@ class AsyncCallbackManager(BaseCallbackManager): run_id=run_id, parent_run_id=self.parent_run_id, tags=self.tags, + metadata=self.metadata, **kwargs, ) @@ -1494,6 +1496,8 @@ class AsyncCallbackManager(BaseCallbackManager): parent_run_id=self.parent_run_id, tags=self.tags, inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, ) @classmethod @@ -1504,6 +1508,8 @@ class AsyncCallbackManager(BaseCallbackManager): verbose: bool = False, inheritable_tags: Optional[List[str]] = None, local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, ) -> AsyncCallbackManager: """Configure the async callback manager. @@ -1517,6 +1523,10 @@ class AsyncCallbackManager(BaseCallbackManager): Defaults to None. local_tags (Optional[List[str]], optional): The local tags. Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. Returns: AsyncCallbackManager: The configured async callback manager. @@ -1528,6 +1538,8 @@ class AsyncCallbackManager(BaseCallbackManager): verbose, inheritable_tags, local_tags, + inheritable_metadata, + local_metadata, ) @@ -1558,6 +1570,8 @@ def _configure( verbose: bool = False, inheritable_tags: Optional[List[str]] = None, local_tags: Optional[List[str]] = None, + inheritable_metadata: Optional[Dict[str, Any]] = None, + local_metadata: Optional[Dict[str, Any]] = None, ) -> T: """Configure the callback manager. @@ -1571,6 +1585,10 @@ def _configure( inheritable_tags (Optional[List[str]], optional): The inheritable tags. Defaults to None. local_tags (Optional[List[str]], optional): The local tags. Defaults to None. + inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable + metadata. Defaults to None. + local_metadata (Optional[Dict[str, Any]], optional): The local metadata. + Defaults to None. Returns: T: The configured callback manager. @@ -1590,6 +1608,8 @@ def _configure( parent_run_id=inheritable_callbacks.parent_run_id, tags=inheritable_callbacks.tags, inheritable_tags=inheritable_callbacks.inheritable_tags, + metadata=inheritable_callbacks.metadata, + inheritable_metadata=inheritable_callbacks.inheritable_metadata, ) local_handlers_ = ( local_callbacks @@ -1601,6 +1621,9 @@ def _configure( if inheritable_tags or local_tags: callback_manager.add_tags(inheritable_tags or []) callback_manager.add_tags(local_tags or [], False) + if inheritable_metadata or local_metadata: + callback_manager.add_metadata(inheritable_metadata or {}) + callback_manager.add_metadata(local_metadata or {}, False) tracer = tracing_callback_var.get() wandb_tracer = wandb_tracing_callback_var.get() diff --git a/langchain/callbacks/tracers/base.py b/langchain/callbacks/tracers/base.py index 0c25264ae4..4df1df07de 100644 --- a/langchain/callbacks/tracers/base.py +++ b/langchain/callbacks/tracers/base.py @@ -89,12 +89,15 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) llm_run = Run( id=run_id, parent_run_id=parent_run_id, @@ -186,12 +189,15 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Start a trace for a chain run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) chain_run = Run( id=run_id, parent_run_id=parent_run_id, @@ -253,12 +259,15 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Start a trace for a tool run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) tool_run = Run( id=run_id, parent_run_id=parent_run_id, @@ -317,12 +326,15 @@ class BaseTracer(BaseCallbackHandler, ABC): *, run_id: UUID, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Run when Retriever starts running.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) retrieval_run = Run( id=run_id, name="Retriever", diff --git a/langchain/callbacks/tracers/langchain.py b/langchain/callbacks/tracers/langchain.py index 378dd62d0c..5759019cc7 100644 --- a/langchain/callbacks/tracers/langchain.py +++ b/langchain/callbacks/tracers/langchain.py @@ -70,12 +70,15 @@ class LangChainTracer(BaseTracer): run_id: UUID, tags: Optional[List[str]] = None, parent_run_id: Optional[UUID] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: """Start a trace for an LLM run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) start_time = datetime.utcnow() + if metadata: + kwargs.update({"metadata": metadata}) chat_model_run = Run( id=run_id, parent_run_id=parent_run_id, diff --git a/langchain/chains/base.py b/langchain/chains/base.py index 9899aafbec..fd4a15b806 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -54,6 +54,12 @@ class Chain(Serializable, ABC): and passed as arguments to the handlers defined in `callbacks`. You can use these to eg identify a specific instance of a chain with its use case. """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the chain. Defaults to None + This metadata will be associated with each call to this chain, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a chain with its use case. + """ class Config: """Configuration for this pydantic object.""" @@ -130,6 +136,7 @@ class Chain(Serializable, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -143,12 +150,20 @@ class Chain(Serializable, ABC): chain will be returned. Defaults to False. callbacks: Callbacks to use for this chain run. If not provided, will use the callbacks provided to the chain. + tags: Optional list of tags associated with the chain. Defaults to None + metadata: Optional metadata associated with the chain. Defaults to None include_run_info: Whether to include run info in the response. Defaults to False. """ inputs = self.prep_inputs(inputs) callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose, tags, self.tags + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, ) new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") run_manager = callback_manager.on_chain_start( @@ -179,6 +194,7 @@ class Chain(Serializable, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -192,12 +208,20 @@ class Chain(Serializable, ABC): chain will be returned. Defaults to False. callbacks: Callbacks to use for this chain run. If not provided, will use the callbacks provided to the chain. + tags: Optional list of tags associated with the chain. Defaults to None + metadata: Optional metadata associated with the chain. Defaults to None include_run_info: Whether to include run info in the response. Defaults to False. """ inputs = self.prep_inputs(inputs) callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, self.verbose, tags, self.tags + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, ) new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") run_manager = await callback_manager.on_chain_start( @@ -278,6 +302,7 @@ class Chain(Serializable, ABC): *args: Any, callbacks: Callbacks = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> str: """Run the chain as text in, text out or multiple variables, text out.""" @@ -287,10 +312,14 @@ class Chain(Serializable, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return self(args[0], callbacks=callbacks, tags=tags)[_output_key] + return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[ + _output_key + ] if kwargs and not args: - return self(kwargs, callbacks=callbacks, tags=tags)[_output_key] + return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[ + _output_key + ] if not kwargs and not args: raise ValueError( @@ -308,6 +337,7 @@ class Chain(Serializable, ABC): *args: Any, callbacks: Callbacks = None, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> str: """Run the chain as text in, text out or multiple variables, text out.""" @@ -320,14 +350,18 @@ class Chain(Serializable, ABC): if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return (await self.acall(args[0], callbacks=callbacks, tags=tags))[ - self.output_keys[0] - ] + return ( + await self.acall( + args[0], callbacks=callbacks, tags=tags, metadata=metadata + ) + )[self.output_keys[0]] if kwargs and not args: - return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[ - self.output_keys[0] - ] + return ( + await self.acall( + kwargs, callbacks=callbacks, tags=tags, metadata=metadata + ) + )[self.output_keys[0]] raise ValueError( f"`run` supported with either positional arguments or keyword arguments" diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index 132971da42..f32ec3c77e 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -40,6 +40,8 @@ class BaseChatModel(BaseLanguageModel, ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) tags: Optional[List[str]] = Field(default=None, exclude=True) """Tags to add to the run trace.""" + metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) + """Metadata to add to the run trace.""" @root_validator() def raise_deprecation(cls, values: Dict) -> Dict: @@ -86,6 +88,7 @@ class BaseChatModel(BaseLanguageModel, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> LLMResult: """Top Level call""" @@ -98,6 +101,8 @@ class BaseChatModel(BaseLanguageModel, ABC): self.verbose, tags, self.tags, + metadata, + self.metadata, ) run_managers = callback_manager.on_chat_model_start( dumpd(self), messages, invocation_params=params, options=options @@ -139,6 +144,7 @@ class BaseChatModel(BaseLanguageModel, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> LLMResult: """Top Level call""" @@ -151,6 +157,8 @@ class BaseChatModel(BaseLanguageModel, ABC): self.verbose, tags, self.tags, + metadata, + self.metadata, ) run_managers = await callback_manager.on_chat_model_start( diff --git a/langchain/evaluation/agents/trajectory_eval_chain.py b/langchain/evaluation/agents/trajectory_eval_chain.py index 184bcbfcee..6d0f07e40f 100644 --- a/langchain/evaluation/agents/trajectory_eval_chain.py +++ b/langchain/evaluation/agents/trajectory_eval_chain.py @@ -244,6 +244,7 @@ The following is the expected answer. Use this to measure correctness: callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, include_run_info: bool = False, ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -257,6 +258,8 @@ The following is the expected answer. Use this to measure correctness: chain will be returned. Defaults to False. callbacks: Callbacks to use for this chain run. If not provided, will use the callbacks provided to the chain. + tags: Tags to add to the chain run. + metadata: Metadata to add to the chain run. include_run_info: Whether to include run info in the response. Defaults to False. """ diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 9a15e5b173..cdb5ea5973 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -80,6 +80,8 @@ class BaseLLM(BaseLanguageModel, ABC): callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) tags: Optional[List[str]] = Field(default=None, exclude=True) """Tags to add to the run trace.""" + metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True) + """Metadata to add to the run trace.""" class Config: """Configuration for this pydantic object.""" @@ -190,6 +192,7 @@ class BaseLLM(BaseLanguageModel, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" @@ -209,7 +212,13 @@ class BaseLLM(BaseLanguageModel, ABC): ) = get_prompts(params, prompts) disregard_cache = self.cache is not None and not self.cache callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose, tags, self.tags + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, ) new_arg_supported = inspect.signature(self._generate).parameters.get( "run_manager" @@ -293,6 +302,7 @@ class BaseLLM(BaseLanguageModel, ABC): callbacks: Callbacks = None, *, tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" @@ -307,7 +317,13 @@ class BaseLLM(BaseLanguageModel, ABC): ) = get_prompts(params, prompts) disregard_cache = self.cache is not None and not self.cache callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, self.verbose, tags, self.tags + callbacks, + self.callbacks, + self.verbose, + tags, + self.tags, + metadata, + self.metadata, ) new_arg_supported = inspect.signature(self._agenerate).parameters.get( "run_manager" @@ -350,6 +366,9 @@ class BaseLLM(BaseLanguageModel, ABC): prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" @@ -360,7 +379,14 @@ class BaseLLM(BaseLanguageModel, ABC): "`generate` instead." ) return ( - self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs) + self.generate( + [prompt], + stop=stop, + callbacks=callbacks, + tags=tags, + metadata=metadata, + **kwargs, + ) .generations[0][0] .text ) @@ -370,11 +396,19 @@ class BaseLLM(BaseLanguageModel, ABC): prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" result = await self.agenerate( - [prompt], stop=stop, callbacks=callbacks, **kwargs + [prompt], + stop=stop, + callbacks=callbacks, + tags=tags, + metadata=metadata, + **kwargs, ) return result.generations[0][0].text diff --git a/langchain/schema/retriever.py b/langchain/schema/retriever.py index ce9e6ce664..b25ef0e692 100644 --- a/langchain/schema/retriever.py +++ b/langchain/schema/retriever.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional from langchain.load.dump import dumpd from langchain.load.serializable import Serializable @@ -55,6 +55,20 @@ class BaseRetriever(Serializable, ABC): _new_arg_supported: bool = False _expects_other_args: bool = False + tags: Optional[List[str]] = None + """Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a retriever with its + use case. + """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a retriever with its + use case. + """ def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -117,19 +131,37 @@ class BaseRetriever(Serializable, ABC): """ def get_relevant_documents( - self, query: str, *, callbacks: Callbacks = None, **kwargs: Any + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> List[Document]: """Retrieve documents relevant to a query. Args: query: string to find relevant documents for callbacks: Callback manager or list of callbacks + tags: Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + metadata: Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. Returns: List of relevant documents """ from langchain.callbacks.manager import CallbackManager callback_manager = CallbackManager.configure( - callbacks, None, verbose=kwargs.get("verbose", False) + callbacks, + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=tags, + local_tags=self.tags, + inheritable_metadata=metadata, + local_metadata=self.metadata, ) run_manager = callback_manager.on_retriever_start( dumpd(self), @@ -155,19 +187,37 @@ class BaseRetriever(Serializable, ABC): return result async def aget_relevant_documents( - self, query: str, *, callbacks: Callbacks = None, **kwargs: Any + self, + query: str, + *, + callbacks: Callbacks = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> List[Document]: """Asynchronously get documents relevant to a query. Args: query: string to find relevant documents for callbacks: Callback manager or list of callbacks + tags: Optional list of tags associated with the retriever. Defaults to None + These tags will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. + metadata: Optional metadata associated with the retriever. Defaults to None + This metadata will be associated with each call to this retriever, + and passed as arguments to the handlers defined in `callbacks`. Returns: List of relevant documents """ from langchain.callbacks.manager import AsyncCallbackManager callback_manager = AsyncCallbackManager.configure( - callbacks, None, verbose=kwargs.get("verbose", False) + callbacks, + None, + verbose=kwargs.get("verbose", False), + inheritable_tags=tags, + local_tags=self.tags, + inheritable_metadata=metadata, + local_metadata=self.metadata, ) run_manager = await callback_manager.on_retriever_start( dumpd(self), diff --git a/langchain/tools/base.py b/langchain/tools/base.py index a8d4c88863..c41462ae2b 100644 --- a/langchain/tools/base.py +++ b/langchain/tools/base.py @@ -4,7 +4,7 @@ from __future__ import annotations import warnings from abc import ABC, abstractmethod from inspect import signature -from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union from pydantic import ( BaseModel, @@ -153,6 +153,18 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): """Callbacks to be called during tool execution.""" callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True) """Deprecated. Please use callbacks instead.""" + tags: Optional[List[str]] = None + """Optional list of tags associated with the tool. Defaults to None + These tags will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ + metadata: Optional[Dict[str, Any]] = None + """Optional metadata associated with the tool. Defaults to None + This metadata will be associated with each call to this tool, + and passed as arguments to the handlers defined in `callbacks`. + You can use these to eg identify a specific instance of a tool with its use case. + """ handle_tool_error: Optional[ Union[bool, str, Callable[[ToolException], str]] @@ -246,6 +258,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): start_color: Optional[str] = "green", color: Optional[str] = "green", callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run the tool.""" @@ -255,7 +270,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): else: verbose_ = self.verbose callback_manager = CallbackManager.configure( - callbacks, self.callbacks, verbose=verbose_ + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, ) # TODO: maybe also pass through run_manager is _run supports kwargs new_arg_supported = signature(self._run).parameters.get("run_manager") @@ -310,6 +331,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): start_color: Optional[str] = "green", color: Optional[str] = "green", callbacks: Callbacks = None, + *, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Run the tool asynchronously.""" @@ -319,7 +343,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass): else: verbose_ = self.verbose callback_manager = AsyncCallbackManager.configure( - callbacks, self.callbacks, verbose=verbose_ + callbacks, + self.callbacks, + verbose_, + tags, + self.tags, + metadata, + self.metadata, ) new_arg_supported = signature(self._arun).parameters.get("run_manager") run_manager = await callback_manager.on_tool_start( diff --git a/tests/integration_tests/callbacks/test_langchain_tracer.py b/tests/integration_tests/callbacks/test_langchain_tracer.py index 5415b05022..dde02bbd78 100644 --- a/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -181,6 +181,40 @@ def test_tracing_v2_chain_with_tags() -> None: chain.run("what is the meaning of life", tags=["a-tag"]) +def test_tracing_v2_agent_with_metadata() -> None: + os.environ["LANGCHAIN_TRACING_V2"] = "true" + llm = OpenAI(temperature=0) + chat = ChatOpenAI(temperature=0) + tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + chat_agent = initialize_agent( + tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"}) + chat_agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"}) + + +@pytest.mark.asyncio +async def test_tracing_v2_async_agent_with_metadata() -> None: + os.environ["LANGCHAIN_TRACING_V2"] = "true" + llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"}) + chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"}) + async_tools = load_tools(["llm-math", "serpapi"], llm=llm) + agent = initialize_agent( + async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True + ) + chat_agent = initialize_agent( + async_tools, + chat, + agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, + verbose=True, + ) + await agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"}) + await chat_agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"}) + + def test_trace_as_group() -> None: llm = OpenAI(temperature=0.9) prompt = PromptTemplate(