Add support for tags (#5898)

<!--
Thank you for contributing to LangChain! Your PR will appear in our
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.

Finally, we'd love to show appreciation for your contribution - if you'd
like us to shout you out on Twitter, please also include your handle!
-->

<!-- Remove if not applicable -->

Fixes # (issue)

#### Before submitting

<!-- If you're adding a new integration, please include:

1. a test for the integration - favor unit tests that does not rely on
network access.
2. an example notebook showing its use


See contribution guidelines for more information on how to write tests,
lint
etc:


https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
-->

#### Who can review?

Tag maintainers/contributors who might be interested:

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead

  Tracing / Callbacks
  - @agola11

  Async
  - @agola11

  DataLoaders
  - @eyurtsev

  Models
  - @hwchase17
  - @agola11

  Agents / Tools / Toolkits
  - @vowelparrot

  VectorStores / Retrievers / Memory
  - @dev2049

 -->
This commit is contained in:
Nuno Campos 2023-06-13 20:30:59 +01:00 committed by GitHub
parent 1281fdf0f2
commit 11ab0be11a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 210 additions and 39 deletions

View File

@ -124,6 +124,7 @@ class CallbackManagerMixin:
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when LLM starts running.""" """Run when LLM starts running."""
@ -135,6 +136,7 @@ class CallbackManagerMixin:
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when a chat model starts running.""" """Run when a chat model starts running."""
@ -149,6 +151,7 @@ class CallbackManagerMixin:
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when chain starts running.""" """Run when chain starts running."""
@ -160,6 +163,7 @@ class CallbackManagerMixin:
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when tool starts running.""" """Run when tool starts running."""
@ -221,6 +225,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when LLM starts running.""" """Run when LLM starts running."""
@ -232,6 +237,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Run when a chat model starts running.""" """Run when a chat model starts running."""
@ -276,6 +282,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when chain starts running.""" """Run when chain starts running."""
@ -307,6 +314,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
*, *,
run_id: UUID, run_id: UUID,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Run when tool starts running.""" """Run when tool starts running."""
@ -370,6 +378,9 @@ class BaseCallbackManager(CallbackManagerMixin):
handlers: List[BaseCallbackHandler], handlers: List[BaseCallbackHandler],
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None, inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
*,
tags: Optional[List[str]] = None,
inheritable_tags: Optional[List[str]] = None,
) -> None: ) -> None:
"""Initialize callback manager.""" """Initialize callback manager."""
self.handlers: List[BaseCallbackHandler] = handlers self.handlers: List[BaseCallbackHandler] = handlers
@ -377,6 +388,8 @@ class BaseCallbackManager(CallbackManagerMixin):
inheritable_handlers or [] inheritable_handlers or []
) )
self.parent_run_id: Optional[UUID] = parent_run_id self.parent_run_id: Optional[UUID] = parent_run_id
self.tags = tags or []
self.inheritable_tags = inheritable_tags or []
@property @property
def is_async(self) -> bool: def is_async(self) -> bool:
@ -406,3 +419,16 @@ class BaseCallbackManager(CallbackManagerMixin):
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None: def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
"""Set handler as the only handler on the callback manager.""" """Set handler as the only handler on the callback manager."""
self.set_handlers([handler], inherit=inherit) self.set_handlers([handler], inherit=inherit)
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
for tag in tags:
if tag in self.tags:
self.remove_tags([tag])
self.tags.extend(tags)
if inherit:
self.inheritable_tags.extend(tags)
def remove_tags(self, tags: List[str]) -> None:
for tag in tags:
self.tags.remove(tag)
self.inheritable_tags.remove(tag)

View File

@ -269,21 +269,32 @@ class BaseRunManager(RunManagerMixin):
def __init__( def __init__(
self, self,
*,
run_id: UUID, run_id: UUID,
handlers: List[BaseCallbackHandler], handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler], inheritable_handlers: List[BaseCallbackHandler],
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
tags: List[str],
inheritable_tags: List[str],
) -> None: ) -> None:
"""Initialize run manager.""" """Initialize run manager."""
self.run_id = run_id self.run_id = run_id
self.handlers = handlers self.handlers = handlers
self.inheritable_handlers = inheritable_handlers self.inheritable_handlers = inheritable_handlers
self.tags = tags
self.inheritable_tags = inheritable_tags
self.parent_run_id = parent_run_id self.parent_run_id = parent_run_id
@classmethod @classmethod
def get_noop_manager(cls: Type[BRM]) -> BRM: def get_noop_manager(cls: Type[BRM]) -> BRM:
"""Return a manager that doesn't perform any operations.""" """Return a manager that doesn't perform any operations."""
return cls(uuid4(), [], []) return cls(
run_id=uuid4(),
handlers=[],
inheritable_handlers=[],
tags=[],
inheritable_tags=[],
)
class RunManager(BaseRunManager): class RunManager(BaseRunManager):
@ -425,10 +436,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
class CallbackManagerForChainRun(RunManager, ChainManagerMixin): class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
"""Callback manager for chain run.""" """Callback manager for chain run."""
def get_child(self) -> CallbackManager: def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.""" """Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id) manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers) manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager return manager
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
@ -487,10 +501,13 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin): class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
"""Async callback manager for chain run.""" """Async callback manager for chain run."""
def get_child(self) -> AsyncCallbackManager: def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.""" """Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id) manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers) manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager return manager
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
@ -549,10 +566,13 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
class CallbackManagerForToolRun(RunManager, ToolManagerMixin): class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
"""Callback manager for tool run.""" """Callback manager for tool run."""
def get_child(self) -> CallbackManager: def get_child(self, tag: Optional[str] = None) -> CallbackManager:
"""Get a child callback manager.""" """Get a child callback manager."""
manager = CallbackManager([], parent_run_id=self.run_id) manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers) manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager return manager
def on_tool_end( def on_tool_end(
@ -591,10 +611,13 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin): class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
"""Async callback manager for tool run.""" """Async callback manager for tool run."""
def get_child(self) -> AsyncCallbackManager: def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
"""Get a child callback manager.""" """Get a child callback manager."""
manager = AsyncCallbackManager([], parent_run_id=self.run_id) manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
manager.set_handlers(self.inheritable_handlers) manager.set_handlers(self.inheritable_handlers)
manager.add_tags(self.inheritable_tags)
if tag is not None:
manager.add_tags([tag], False)
return manager return manager
async def on_tool_end(self, output: str, **kwargs: Any) -> None: async def on_tool_end(self, output: str, **kwargs: Any) -> None:
@ -648,11 +671,17 @@ class CallbackManager(BaseCallbackManager):
prompts, prompts,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return CallbackManagerForLLMRun( return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
def on_chat_model_start( def on_chat_model_start(
@ -673,13 +702,19 @@ class CallbackManager(BaseCallbackManager):
messages, messages,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
# Re-use the LLM Run Manager since the outputs are treated # Re-use the LLM Run Manager since the outputs are treated
# the same for now # the same for now
return CallbackManagerForLLMRun( return CallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
def on_chain_start( def on_chain_start(
@ -701,11 +736,17 @@ class CallbackManager(BaseCallbackManager):
inputs, inputs,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return CallbackManagerForChainRun( return CallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
def on_tool_start( def on_tool_start(
@ -728,11 +769,17 @@ class CallbackManager(BaseCallbackManager):
input_str, input_str,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return CallbackManagerForToolRun( return CallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
@classmethod @classmethod
@ -741,9 +788,18 @@ class CallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None, inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None, local_callbacks: Callbacks = None,
verbose: bool = False, verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> CallbackManager: ) -> CallbackManager:
"""Configure the callback manager.""" """Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose) return _configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
)
class AsyncCallbackManager(BaseCallbackManager): class AsyncCallbackManager(BaseCallbackManager):
@ -773,11 +829,17 @@ class AsyncCallbackManager(BaseCallbackManager):
prompts, prompts,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return AsyncCallbackManagerForLLMRun( return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
async def on_chat_model_start( async def on_chat_model_start(
@ -798,11 +860,17 @@ class AsyncCallbackManager(BaseCallbackManager):
messages, messages,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return AsyncCallbackManagerForLLMRun( return AsyncCallbackManagerForLLMRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
async def on_chain_start( async def on_chain_start(
@ -824,11 +892,17 @@ class AsyncCallbackManager(BaseCallbackManager):
inputs, inputs,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return AsyncCallbackManagerForChainRun( return AsyncCallbackManagerForChainRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
async def on_tool_start( async def on_tool_start(
@ -851,11 +925,17 @@ class AsyncCallbackManager(BaseCallbackManager):
input_str, input_str,
run_id=run_id, run_id=run_id,
parent_run_id=self.parent_run_id, parent_run_id=self.parent_run_id,
tags=self.tags,
**kwargs, **kwargs,
) )
return AsyncCallbackManagerForToolRun( return AsyncCallbackManagerForToolRun(
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id run_id=run_id,
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
) )
@classmethod @classmethod
@ -864,9 +944,18 @@ class AsyncCallbackManager(BaseCallbackManager):
inheritable_callbacks: Callbacks = None, inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None, local_callbacks: Callbacks = None,
verbose: bool = False, verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> AsyncCallbackManager: ) -> AsyncCallbackManager:
"""Configure the callback manager.""" """Configure the callback manager."""
return _configure(cls, inheritable_callbacks, local_callbacks, verbose) return _configure(
cls,
inheritable_callbacks,
local_callbacks,
verbose,
inheritable_tags,
local_tags,
)
T = TypeVar("T", CallbackManager, AsyncCallbackManager) T = TypeVar("T", CallbackManager, AsyncCallbackManager)
@ -887,9 +976,11 @@ def _configure(
inheritable_callbacks: Callbacks = None, inheritable_callbacks: Callbacks = None,
local_callbacks: Callbacks = None, local_callbacks: Callbacks = None,
verbose: bool = False, verbose: bool = False,
inheritable_tags: Optional[List[str]] = None,
local_tags: Optional[List[str]] = None,
) -> T: ) -> T:
"""Configure the callback manager.""" """Configure the callback manager."""
callback_manager = callback_manager_cls([]) callback_manager = callback_manager_cls(handlers=[])
if inheritable_callbacks or local_callbacks: if inheritable_callbacks or local_callbacks:
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None: if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
inheritable_callbacks_ = inheritable_callbacks or [] inheritable_callbacks_ = inheritable_callbacks or []
@ -902,6 +993,8 @@ def _configure(
handlers=inheritable_callbacks.handlers, handlers=inheritable_callbacks.handlers,
inheritable_handlers=inheritable_callbacks.inheritable_handlers, inheritable_handlers=inheritable_callbacks.inheritable_handlers,
parent_run_id=inheritable_callbacks.parent_run_id, parent_run_id=inheritable_callbacks.parent_run_id,
tags=inheritable_callbacks.tags,
inheritable_tags=inheritable_callbacks.inheritable_tags,
) )
local_handlers_ = ( local_handlers_ = (
local_callbacks local_callbacks
@ -910,6 +1003,9 @@ def _configure(
) )
for handler in local_handlers_: for handler in local_handlers_:
callback_manager.add_handler(handler, False) callback_manager.add_handler(handler, False)
if inheritable_tags or local_tags:
callback_manager.add_tags(inheritable_tags or [])
callback_manager.add_tags(local_tags or [], False)
tracer = tracing_callback_var.get() tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get() wandb_tracer = wandb_tracing_callback_var.get()

View File

@ -85,6 +85,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
prompts: List[str], prompts: List[str],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -101,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
run_type=RunTypeEnum.llm, run_type=RunTypeEnum.llm,
tags=tags or [],
) )
self._start_trace(llm_run) self._start_trace(llm_run)
self._on_llm_start(llm_run) self._on_llm_start(llm_run)
@ -145,6 +147,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
inputs: Dict[str, Any], inputs: Dict[str, Any],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -162,6 +165,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order, child_execution_order=execution_order,
child_runs=[], child_runs=[],
run_type=RunTypeEnum.chain, run_type=RunTypeEnum.chain,
tags=tags or [],
) )
self._start_trace(chain_run) self._start_trace(chain_run)
self._on_chain_start(chain_run) self._on_chain_start(chain_run)
@ -206,6 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
input_str: str, input_str: str,
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -223,6 +228,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order, child_execution_order=execution_order,
child_runs=[], child_runs=[],
run_type=RunTypeEnum.tool, run_type=RunTypeEnum.tool,
tags=tags or [],
) )
self._start_trace(tool_run) self._start_trace(tool_run)
self._on_tool_start(tool_run) self._on_tool_start(tool_run)

View File

@ -59,6 +59,7 @@ class LangChainTracer(BaseTracer):
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
*, *,
run_id: UUID, run_id: UUID,
tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
@ -75,6 +76,7 @@ class LangChainTracer(BaseTracer):
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
run_type=RunTypeEnum.llm, run_type=RunTypeEnum.llm,
tags=tags,
) )
self._start_trace(chat_model_run) self._start_trace(chat_model_run)
self._on_chat_model_start(chat_model_run) self._on_chat_model_start(chat_model_run)

View File

@ -94,6 +94,7 @@ class Run(BaseRunV2):
execution_order: int execution_order: int
child_execution_order: int child_execution_order: int
child_runs: List[Run] = Field(default_factory=list) child_runs: List[Run] = Field(default_factory=list)
tags: Optional[List[str]] = Field(default_factory=list)
@root_validator(pre=True) @root_validator(pre=True)
def assign_name(cls, values: dict) -> dict: def assign_name(cls, values: dict) -> dict:

View File

@ -36,6 +36,7 @@ class Chain(Serializable, ABC):
verbose: bool = Field( verbose: bool = Field(
default_factory=_get_verbosity default_factory=_get_verbosity
) # Whether to print the response text ) # Whether to print the response text
tags: Optional[List[str]] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -111,6 +112,7 @@ class Chain(Serializable, ABC):
return_only_outputs: bool = False, return_only_outputs: bool = False,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired. """Run the logic of this chain and add to output if desired.
@ -129,7 +131,7 @@ class Chain(Serializable, ABC):
""" """
inputs = self.prep_inputs(inputs) inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure( callback_manager = CallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose, tags, self.tags
) )
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
@ -159,6 +161,7 @@ class Chain(Serializable, ABC):
return_only_outputs: bool = False, return_only_outputs: bool = False,
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
tags: Optional[List[str]] = None,
include_run_info: bool = False, include_run_info: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired. """Run the logic of this chain and add to output if desired.
@ -177,7 +180,7 @@ class Chain(Serializable, ABC):
""" """
inputs = self.prep_inputs(inputs) inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure( callback_manager = AsyncCallbackManager.configure(
callbacks, self.callbacks, self.verbose callbacks, self.callbacks, self.verbose, tags, self.tags
) )
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
run_manager = await callback_manager.on_chain_start( run_manager = await callback_manager.on_chain_start(
@ -244,7 +247,13 @@ class Chain(Serializable, ABC):
"""Call the chain on all inputs in the list.""" """Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list] return [self(inputs, callbacks=callbacks) for inputs in input_list]
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: def run(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out.""" """Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1: if len(self.output_keys) != 1:
raise ValueError( raise ValueError(
@ -255,10 +264,10 @@ class Chain(Serializable, ABC):
if args and not kwargs: if args and not kwargs:
if len(args) != 1: if len(args) != 1:
raise ValueError("`run` supports only one positional argument.") raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks)[self.output_keys[0]] return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]]
if kwargs and not args: if kwargs and not args:
return self(kwargs, callbacks=callbacks)[self.output_keys[0]] return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]]
if not kwargs and not args: if not kwargs and not args:
raise ValueError( raise ValueError(
@ -271,7 +280,13 @@ class Chain(Serializable, ABC):
f" but not both. Got args: {args} and kwargs: {kwargs}." f" but not both. Got args: {args} and kwargs: {kwargs}."
) )
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str: async def arun(
self,
*args: Any,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out.""" """Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1: if len(self.output_keys) != 1:
raise ValueError( raise ValueError(
@ -282,10 +297,14 @@ class Chain(Serializable, ABC):
if args and not kwargs: if args and not kwargs:
if len(args) != 1: if len(args) != 1:
raise ValueError("`run` supports only one positional argument.") raise ValueError("`run` supports only one positional argument.")
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]] return (await self.acall(args[0], callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
if kwargs and not args: if kwargs and not args:
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]] return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[
self.output_keys[0]
]
raise ValueError( raise ValueError(
f"`run` supported with either positional arguments or keyword arguments" f"`run` supported with either positional arguments or keyword arguments"

View File

@ -98,7 +98,7 @@ class ConstitutionalChain(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
response = self.chain.run( response = self.chain.run(
**inputs, **inputs,
callbacks=_run_manager.get_child(), callbacks=_run_manager.get_child("original"),
) )
initial_response = response initial_response = response
input_prompt = self.chain.prompt.format(**inputs) input_prompt = self.chain.prompt.format(**inputs)
@ -116,7 +116,7 @@ class ConstitutionalChain(Chain):
input_prompt=input_prompt, input_prompt=input_prompt,
output_from_model=response, output_from_model=response,
critique_request=constitutional_principle.critique_request, critique_request=constitutional_principle.critique_request,
callbacks=_run_manager.get_child(), callbacks=_run_manager.get_child("critique"),
) )
critique = self._parse_critique( critique = self._parse_critique(
output_string=raw_critique, output_string=raw_critique,
@ -137,7 +137,7 @@ class ConstitutionalChain(Chain):
critique_request=constitutional_principle.critique_request, critique_request=constitutional_principle.critique_request,
critique=critique, critique=critique,
revision_request=constitutional_principle.revision_request, revision_request=constitutional_principle.revision_request,
callbacks=_run_manager.get_child(), callbacks=_run_manager.get_child("revision"),
).strip() ).strip()
response = revision response = revision
critiques_and_revisions.append((critique, revision)) critiques_and_revisions.append((critique, revision))

View File

@ -283,7 +283,7 @@ class LLMChain(Chain):
return "llm_chain" return "llm_chain"
@classmethod @classmethod
def from_string(cls, llm: BaseLanguageModel, template: str) -> Chain: def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
"""Create LLMChain from LLM and template.""" """Create LLMChain from LLM and template."""
prompt_template = PromptTemplate.from_template(template) prompt_template = PromptTemplate.from_template(template)
return cls(llm=llm, prompt=prompt_template) return cls(llm=llm, prompt=prompt_template)

View File

@ -174,7 +174,7 @@ class SimpleSequentialChain(Chain):
_input = inputs[self.input_key] _input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))]) color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains): for i, chain in enumerate(self.chains):
_input = chain.run(_input, callbacks=_run_manager.get_child()) _input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
if self.strip_outputs: if self.strip_outputs:
_input = _input.strip() _input = _input.strip()
_run_manager.on_text( _run_manager.on_text(

View File

@ -13,6 +13,8 @@ from langchain.callbacks.manager import (
tracing_v2_enabled, tracing_v2_enabled,
) )
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.chains.constitutional_ai.base import ConstitutionalChain
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
@ -160,6 +162,25 @@ def test_tracing_v2_context_manager() -> None:
agent.run(questions[0]) # this should not be traced agent.run(questions[0]) # this should not be traced
def test_tracing_v2_chain_with_tags() -> None:
llm = OpenAI(temperature=0)
chain = ConstitutionalChain.from_llm(
llm,
chain=LLMChain.from_string(llm, "Q: {question} A:"),
tags=["only-root"],
constitutional_principles=[
ConstitutionalPrinciple(
critique_request="Tell if this answer is good.",
revision_request="Give a better answer.",
)
],
)
if "LANGCHAIN_TRACING_V2" in os.environ:
del os.environ["LANGCHAIN_TRACING_V2"]
with tracing_v2_enabled():
chain.run("what is the meaning of life", tags=["a-tag"])
def test_trace_as_group() -> None: def test_trace_as_group() -> None:
llm = OpenAI(temperature=0.9) llm = OpenAI(temperature=0.9)
prompt = PromptTemplate( prompt = PromptTemplate(

View File

@ -86,7 +86,7 @@ def test_callback_manager() -> None:
"""Test the CallbackManager.""" """Test the CallbackManager."""
handler1 = FakeCallbackHandler() handler1 = FakeCallbackHandler()
handler2 = FakeCallbackHandler() handler2 = FakeCallbackHandler()
manager = CallbackManager([handler1, handler2]) manager = CallbackManager(handlers=[handler1, handler2])
_test_callback_manager(manager, handler1, handler2) _test_callback_manager(manager, handler1, handler2)
@ -143,7 +143,7 @@ async def test_async_callback_manager() -> None:
"""Test the AsyncCallbackManager.""" """Test the AsyncCallbackManager."""
handler1 = FakeAsyncCallbackHandler() handler1 = FakeAsyncCallbackHandler()
handler2 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2]) manager = AsyncCallbackManager(handlers=[handler1, handler2])
await _test_callback_manager_async(manager, handler1, handler2) await _test_callback_manager_async(manager, handler1, handler2)
@ -153,7 +153,7 @@ async def test_async_callback_manager_sync_handler() -> None:
handler1 = FakeCallbackHandler() handler1 = FakeCallbackHandler()
handler2 = FakeAsyncCallbackHandler() handler2 = FakeAsyncCallbackHandler()
handler3 = FakeAsyncCallbackHandler() handler3 = FakeAsyncCallbackHandler()
manager = AsyncCallbackManager([handler1, handler2, handler3]) manager = AsyncCallbackManager(handlers=[handler1, handler2, handler3])
await _test_callback_manager_async(manager, handler1, handler2, handler3) await _test_callback_manager_async(manager, handler1, handler2, handler3)
@ -165,11 +165,11 @@ def test_callback_manager_inheritance() -> None:
FakeCallbackHandler(), FakeCallbackHandler(),
) )
callback_manager1 = CallbackManager([handler1, handler2]) callback_manager1 = CallbackManager(handlers=[handler1, handler2])
assert callback_manager1.handlers == [handler1, handler2] assert callback_manager1.handlers == [handler1, handler2]
assert callback_manager1.inheritable_handlers == [] assert callback_manager1.inheritable_handlers == []
callback_manager2 = CallbackManager([]) callback_manager2 = CallbackManager(handlers=[])
assert callback_manager2.handlers == [] assert callback_manager2.handlers == []
assert callback_manager2.inheritable_handlers == [] assert callback_manager2.inheritable_handlers == []
@ -229,7 +229,7 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler) assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
assert isinstance(configured_manager, CallbackManager) assert isinstance(configured_manager, CallbackManager)
async_local_callbacks = AsyncCallbackManager([handler3, handler4]) async_local_callbacks = AsyncCallbackManager(handlers=[handler3, handler4])
async_configured_manager = AsyncCallbackManager.configure( async_configured_manager = AsyncCallbackManager.configure(
inheritable_callbacks=inheritable_callbacks, inheritable_callbacks=inheritable_callbacks,
local_callbacks=async_local_callbacks, local_callbacks=async_local_callbacks,