support adding custom metadata to runs (#7120)

- [x] wire up tools
- [x] wire up retrievers
- [x] add integration test

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
pull/7218/head
Ankush Gola 1 year ago committed by GitHub
parent 30d8d1d3d0
commit 4c1c05c2c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -147,6 +147,7 @@ class CallbackManagerMixin:
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when LLM starts running."""
@ -159,6 +160,7 @@ class CallbackManagerMixin:
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
@ -174,6 +176,7 @@ class CallbackManagerMixin:
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when Retriever starts running."""
@ -186,6 +189,7 @@ class CallbackManagerMixin:
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when chain starts running."""
@ -198,6 +202,7 @@ class CallbackManagerMixin:
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when tool starts running."""
@ -268,6 +273,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when LLM starts running."""
@ -280,6 +286,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run when a chat model starts running."""
@ -328,6 +335,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when chain starts running."""
@ -362,6 +370,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when tool starts running."""
@ -429,6 +438,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run on retriever start."""
@ -467,6 +477,8 @@ class BaseCallbackManager(CallbackManagerMixin):
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers
@ -476,6 +488,8 @@ class BaseCallbackManager(CallbackManagerMixin):
self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
@property
def is_async(self) -> bool:
@ -518,3 +532,13 @@ class BaseCallbackManager(CallbackManagerMixin):
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)
def add_metadata(self, metadata: Dict[str, Any], inherit: bool = True) -> None:
self.metadata.update(metadata)
if inherit:
self.inheritable_metadata.update(metadata)
def remove_metadata(self, keys: List[str]) -> None:
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)

@ -383,6 +383,8 @@ class BaseRunManager(RunManagerMixin):
parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the run manager.
@ -395,6 +397,8 @@ class BaseRunManager(RunManagerMixin):
Defaults to None.
tags (Optional[List[str]]): The list of tags.
inheritable_tags (Optional[List[str]]): The list of inheritable tags.
metadata (Optional[Dict[str, Any]]): The metadata.
inheritable_metadata (Optional[Dict[str, Any]]): The inheritable metadata.
"""
self.run_id = run_id
self.handlers = handlers
@ -402,6 +406,8 @@ class BaseRunManager(RunManagerMixin):
self.parent_run_id = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
@classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM:
@ -416,6 +422,8 @@ class BaseRunManager(RunManagerMixin):
inheritable_handlers=[],
tags=[],
inheritable_tags=[],
metadata={},
inheritable_metadata={},
)
@ -447,6 +455,28 @@ class RunManager(BaseRunManager):
)
class ParentRunManager(RunManager):
"""Sync Parent Run Manager."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
return manager
class AsyncRunManager(BaseRunManager):
"""Async Run Manager."""
@ -475,6 +505,28 @@ class AsyncRunManager(BaseRunManager):
)
class AsyncParentRunManager(AsyncRunManager):
"""Async Parent Run Manager."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
manager.add_metadata(self.inheritable_metadata)
if tag is not None:
manager.add_tags([tag], False)
return manager
class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
"""Callback manager for LLM run."""
@ -601,26 +653,9 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
)
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
"""Callback manager for chain run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running.
@ -700,26 +735,9 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
)
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
"""Async callback manager for chain run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
"""Run when chain ends running.
@ -799,26 +817,9 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
)
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
"""Callback manager for tool run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag for the child callback manager.
Defaults to None.
Returns:
CallbackManager: The child callback manager.
"""
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_tool_end(
self,
output: str,
@ -862,26 +863,9 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
)
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
"""Async callback manager for tool run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.
Args:
tag (str, optional): The tag to add to the child
callback manager. Defaults to None.
Returns:
AsyncCallbackManager: The child callback manager.
"""
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
"""Run when tool ends running.
@ -921,18 +905,9 @@ class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
)
class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
"""Callback manager for retriever run."""
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
def on_retriever_end(
self,
documents: Sequence[Document],
@ -969,20 +944,11 @@ class CallbackManagerForRetrieverRun(RunManager, RetrieverManagerMixin):
class AsyncCallbackManagerForRetrieverRun(
AsyncRunManager,
AsyncParentRunManager,
RetrieverManagerMixin,
):
"""Async callback manager for retriever run."""
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager
async def on_retriever_end(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
@ -1048,6 +1014,7 @@ class CallbackManager(BaseCallbackManager):
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1059,6 +1026,8 @@ class CallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
)
@ -1094,6 +1063,7 @@ class CallbackManager(BaseCallbackManager):
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1105,6 +1075,8 @@ class CallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
)
@ -1139,6 +1111,7 @@ class CallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1149,6 +1122,8 @@ class CallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
def on_tool_start(
@ -1182,6 +1157,7 @@ class CallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1192,6 +1168,8 @@ class CallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
def on_retriever_start(
@ -1215,6 +1193,7 @@ class CallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1225,6 +1204,8 @@ class CallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@classmethod
@ -1235,6 +1216,8 @@ class CallbackManager(BaseCallbackManager):
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
) -> CallbackManager:
"""Configure the callback manager.
@ -1248,6 +1231,10 @@ class CallbackManager(BaseCallbackManager):
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags.
Defaults to None.
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
Returns:
CallbackManager: The configured callback manager.
@ -1259,6 +1246,8 @@ class CallbackManager(BaseCallbackManager):
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
)
@ -1305,6 +1294,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
)
@ -1317,6 +1307,8 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
)
@ -1358,6 +1350,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id=run_id_,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
)
@ -1370,6 +1363,8 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
)
@ -1406,6 +1401,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1416,6 +1412,8 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_tool_start(
@ -1451,6 +1449,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1461,6 +1460,8 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
async def on_retriever_start(
@ -1484,6 +1485,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id=run_id,
parent_run_id=self.parent_run_id,
tags=self.tags,
metadata=self.metadata,
**kwargs,
)
@ -1494,6 +1496,8 @@ class AsyncCallbackManager(BaseCallbackManager):
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@classmethod
@ -1504,6 +1508,8 @@ class AsyncCallbackManager(BaseCallbackManager):
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
) -> AsyncCallbackManager:
"""Configure the async callback manager.
@ -1517,6 +1523,10 @@ class AsyncCallbackManager(BaseCallbackManager):
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags.
Defaults to None.
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
Returns:
AsyncCallbackManager: The configured async callback manager.
@ -1528,6 +1538,8 @@ class AsyncCallbackManager(BaseCallbackManager):
verbose,
inheritable_tags,
local_tags,
inheritable_metadata,
local_metadata,
)
@ -1558,6 +1570,8 @@ def _configure(
verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
inheritable_metadata: Optional[Dict[str, Any]] = None,
local_metadata: Optional[Dict[str, Any]] = None,
) -> T:
"""Configure the callback manager.
@ -1571,6 +1585,10 @@ def _configure(
inheritable_tags (Optional[List[str]], optional): The inheritable tags.
Defaults to None.
local_tags (Optional[List[str]], optional): The local tags. Defaults to None.
inheritable_metadata (Optional[Dict[str, Any]], optional): The inheritable
metadata. Defaults to None.
local_metadata (Optional[Dict[str, Any]], optional): The local metadata.
Defaults to None.
Returns:
T: The configured callback manager.
@ -1590,6 +1608,8 @@ def _configure(
parent_run_id=inheritable_callbacks.parent_run_id,
tags=inheritable_callbacks.tags,
inheritable_tags=inheritable_callbacks.inheritable_tags,
metadata=inheritable_callbacks.metadata,
inheritable_metadata=inheritable_callbacks.inheritable_metadata,
)
local_handlers_ = (
local_callbacks
@ -1601,6 +1621,9 @@ def _configure(
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
if inheritable_metadata or local_metadata:
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get()

@ -89,12 +89,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
llm_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@ -186,12 +189,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a chain run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
chain_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@ -253,12 +259,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for a tool run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
tool_run = Run(
id=run_id,
parent_run_id=parent_run_id,
@ -317,12 +326,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Run when Retriever starts running."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
retrieval_run = Run(
id=run_id,
name="Retriever",

@ -70,12 +70,15 @@ class LangChainTracer(BaseTracer):
run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,

@ -54,6 +54,12 @@ class Chain(Serializable, ABC):
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the chain. Defaults to None
This metadata will be associated with each call to this chain,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a chain with its use case.
"""
class Config:
"""Configuration for this pydantic object."""
@ -130,6 +136,7 @@ class Chain(Serializable, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@ -143,12 +150,20 @@ class Chain(Serializable, ABC):
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
tags: Optional list of tags associated with the chain. Defaults to None
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start(
@ -179,6 +194,7 @@ class Chain(Serializable, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@ -192,12 +208,20 @@ class Chain(Serializable, ABC):
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
tags: Optional list of tags associated with the chain. Defaults to None
metadata: Optional metadata associated with the chain. Defaults to None
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start(
@ -278,6 +302,7 @@ class Chain(Serializable, ABC):
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
@ -287,10 +312,14 @@ class Chain(Serializable, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags)[_output_key]
return self(args[0], callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[_output_key]
return self(kwargs, callbacks=callbacks, tags=tags, metadata=metadata)[
_output_key
]
if not kwargs and not args:
raise ValueError(
@ -308,6 +337,7 @@ class Chain(Serializable, ABC):
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
@ -320,14 +350,18 @@ class Chain(Serializable, ABC):
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return (await self.acall(args[0], callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
return (
await self.acall(
args[0], callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
if kwargs and not args:
return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
return (
await self.acall(
kwargs, callbacks=callbacks, tags=tags, metadata=metadata
)
)[self.output_keys[0]]
raise ValueError(
f"`run` supported with either positional arguments or keyword arguments"

@ -40,6 +40,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
@ -86,6 +88,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@ -98,6 +101,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = callback_manager.on_chat_model_start(
dumpd(self), messages, invocation_params=params, options=options
@ -139,6 +144,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@ -151,6 +157,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
run_managers = await callback_manager.on_chat_model_start(

@ -244,6 +244,7 @@ The following is the expected answer. Use this to measure correctness:
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@ -257,6 +258,8 @@ The following is the expected answer. Use this to measure correctness:
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
tags: Tags to add to the chain run.
metadata: Metadata to add to the chain run.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""

@ -80,6 +80,8 @@ class BaseLLM(BaseLanguageModel, ABC):
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
class Config:
"""Configuration for this pydantic object."""
@ -190,6 +192,7 @@ class BaseLLM(BaseLanguageModel, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@ -209,7 +212,13 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
@ -293,6 +302,7 @@ class BaseLLM(BaseLanguageModel, ABC):
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@ -307,7 +317,13 @@ class BaseLLM(BaseLanguageModel, ABC):
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose, tags, self.tags
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
@ -350,6 +366,9 @@ class BaseLLM(BaseLanguageModel, ABC):
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
@ -360,7 +379,14 @@ class BaseLLM(BaseLanguageModel, ABC):
"`generate` instead."
)
return (
self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
self.generate(
[prompt],
stop=stop,
callbacks=callbacks,
tags=tags,
metadata=metadata,
**kwargs,
)
.generations[0][0]
.text
)
@ -370,11 +396,19 @@ class BaseLLM(BaseLanguageModel, ABC):
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
result = await self.agenerate(
[prompt], stop=stop, callbacks=callbacks, **kwargs
[prompt],
stop=stop,
callbacks=callbacks,
tags=tags,
metadata=metadata,
**kwargs,
)
return result.generations[0][0].text

@ -3,7 +3,7 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
@ -55,6 +55,20 @@ class BaseRetriever(Serializable, ABC):
_new_arg_supported: bool = False
_expects_other_args: bool = False
tags: Optional[List[str]] = None
"""Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a retriever with its
use case.
"""
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
@ -117,19 +131,37 @@ class BaseRetriever(Serializable, ABC):
"""
def get_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
dumpd(self),
@ -155,19 +187,37 @@ class BaseRetriever(Serializable, ABC):
return result
async def aget_relevant_documents(
self, query: str, *, callbacks: Callbacks = None, **kwargs: Any
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
Args:
query: string to find relevant documents for
callbacks: Callback manager or list of callbacks
tags: Optional list of tags associated with the retriever. Defaults to None
These tags will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
metadata: Optional metadata associated with the retriever. Defaults to None
This metadata will be associated with each call to this retriever,
and passed as arguments to the handlers defined in `callbacks`.
Returns:
List of relevant documents
"""
from langchain.callbacks.manager import AsyncCallbackManager
callback_manager = AsyncCallbackManager.configure(
callbacks, None, verbose=kwargs.get("verbose", False)
callbacks,
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=tags,
local_tags=self.tags,
inheritable_metadata=metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
dumpd(self),

@ -4,7 +4,7 @@ from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Type, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from pydantic import (
BaseModel,
@ -153,6 +153,18 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
"""Callbacks to be called during tool execution."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""Deprecated. Please use callbacks instead."""
tags: Optional[List[str]] = None
"""Optional list of tags associated with the tool. Defaults to None
These tags will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
metadata: Optional[Dict[str, Any]] = None
"""Optional metadata associated with the tool. Defaults to None
This metadata will be associated with each call to this tool,
and passed as arguments to the handlers defined in `callbacks`.
You can use these to eg identify a specific instance of a tool with its use case.
"""
handle_tool_error: Optional[
Union[bool, str, Callable[[ToolException], str]]
@ -246,6 +258,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -255,7 +270,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
else:
verbose_ = self.verbose
callback_manager = CallbackManager.configure(
callbacks, self.callbacks, verbose=verbose_
callbacks,
self.callbacks,
verbose_,
tags,
self.tags,
metadata,
self.metadata,
)
# TODO: maybe also pass through run_manager is _run supports kwargs
new_arg_supported = signature(self._run).parameters.get("run_manager")
@ -310,6 +331,9 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
start_color: Optional[str] = "green",
color: Optional[str] = "green",
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -319,7 +343,13 @@ class BaseTool(ABC, BaseModel, metaclass=ToolMetaclass):
else:
verbose_ = self.verbose
callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, verbose=verbose_
callbacks,
self.callbacks,
verbose_,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = signature(self._arun).parameters.get("run_manager")
run_manager = await callback_manager.on_tool_start(

@ -181,6 +181,40 @@ def test_tracing_v2_chain_with_tags() -> None:
chain.run("what is the meaning of life", tags=["a-tag"])
def test_tracing_v2_agent_with_metadata() -> None:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
llm = OpenAI(temperature=0)
chat = ChatOpenAI(temperature=0)
tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
chat_agent = initialize_agent(
tools, chat, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
chat_agent.run(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
@pytest.mark.asyncio
async def test_tracing_v2_async_agent_with_metadata() -> None:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
llm = OpenAI(temperature=0, metadata={"f": "g", "h": "i"})
chat = ChatOpenAI(temperature=0, metadata={"f": "g", "h": "i"})
async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
agent = initialize_agent(
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
)
chat_agent = initialize_agent(
async_tools,
chat,
agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
await agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
await chat_agent.arun(questions[0], tags=["a-tag"], metadata={"a": "b", "c": "d"})
def test_trace_as_group() -> None:
llm = OpenAI(temperature=0.9)
prompt = PromptTemplate(

Loading…
Cancel
Save