mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Add support for tags (#5898)
<!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> <!-- Remove if not applicable --> Fixes # (issue) #### Before submitting <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #### Who can review? Tag maintainers/contributors who might be interested: <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 -->
This commit is contained in:
parent
1281fdf0f2
commit
11ab0be11a
@ -124,6 +124,7 @@ class CallbackManagerMixin:
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
@ -135,6 +136,7 @@ class CallbackManagerMixin:
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when a chat model starts running."""
|
"""Run when a chat model starts running."""
|
||||||
@ -149,6 +151,7 @@ class CallbackManagerMixin:
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running."""
|
||||||
@ -160,6 +163,7 @@ class CallbackManagerMixin:
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
@ -221,6 +225,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
@ -232,6 +237,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Run when a chat model starts running."""
|
"""Run when a chat model starts running."""
|
||||||
@ -276,6 +282,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when chain starts running."""
|
"""Run when chain starts running."""
|
||||||
@ -307,6 +314,7 @@ class AsyncCallbackHandler(BaseCallbackHandler):
|
|||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run when tool starts running."""
|
"""Run when tool starts running."""
|
||||||
@ -370,6 +378,9 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
handlers: List[BaseCallbackHandler],
|
handlers: List[BaseCallbackHandler],
|
||||||
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
inheritable_handlers: Optional[List[BaseCallbackHandler]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
*,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize callback manager."""
|
"""Initialize callback manager."""
|
||||||
self.handlers: List[BaseCallbackHandler] = handlers
|
self.handlers: List[BaseCallbackHandler] = handlers
|
||||||
@ -377,6 +388,8 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
inheritable_handlers or []
|
inheritable_handlers or []
|
||||||
)
|
)
|
||||||
self.parent_run_id: Optional[UUID] = parent_run_id
|
self.parent_run_id: Optional[UUID] = parent_run_id
|
||||||
|
self.tags = tags or []
|
||||||
|
self.inheritable_tags = inheritable_tags or []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_async(self) -> bool:
|
def is_async(self) -> bool:
|
||||||
@ -406,3 +419,16 @@ class BaseCallbackManager(CallbackManagerMixin):
|
|||||||
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
def set_handler(self, handler: BaseCallbackHandler, inherit: bool = True) -> None:
|
||||||
"""Set handler as the only handler on the callback manager."""
|
"""Set handler as the only handler on the callback manager."""
|
||||||
self.set_handlers([handler], inherit=inherit)
|
self.set_handlers([handler], inherit=inherit)
|
||||||
|
|
||||||
|
def add_tags(self, tags: List[str], inherit: bool = True) -> None:
|
||||||
|
for tag in tags:
|
||||||
|
if tag in self.tags:
|
||||||
|
self.remove_tags([tag])
|
||||||
|
self.tags.extend(tags)
|
||||||
|
if inherit:
|
||||||
|
self.inheritable_tags.extend(tags)
|
||||||
|
|
||||||
|
def remove_tags(self, tags: List[str]) -> None:
|
||||||
|
for tag in tags:
|
||||||
|
self.tags.remove(tag)
|
||||||
|
self.inheritable_tags.remove(tag)
|
||||||
|
@ -269,21 +269,32 @@ class BaseRunManager(RunManagerMixin):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
handlers: List[BaseCallbackHandler],
|
handlers: List[BaseCallbackHandler],
|
||||||
inheritable_handlers: List[BaseCallbackHandler],
|
inheritable_handlers: List[BaseCallbackHandler],
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
|
tags: List[str],
|
||||||
|
inheritable_tags: List[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize run manager."""
|
"""Initialize run manager."""
|
||||||
self.run_id = run_id
|
self.run_id = run_id
|
||||||
self.handlers = handlers
|
self.handlers = handlers
|
||||||
self.inheritable_handlers = inheritable_handlers
|
self.inheritable_handlers = inheritable_handlers
|
||||||
|
self.tags = tags
|
||||||
|
self.inheritable_tags = inheritable_tags
|
||||||
self.parent_run_id = parent_run_id
|
self.parent_run_id = parent_run_id
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
def get_noop_manager(cls: Type[BRM]) -> BRM:
|
||||||
"""Return a manager that doesn't perform any operations."""
|
"""Return a manager that doesn't perform any operations."""
|
||||||
return cls(uuid4(), [], [])
|
return cls(
|
||||||
|
run_id=uuid4(),
|
||||||
|
handlers=[],
|
||||||
|
inheritable_handlers=[],
|
||||||
|
tags=[],
|
||||||
|
inheritable_tags=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunManager(BaseRunManager):
|
class RunManager(BaseRunManager):
|
||||||
@ -425,10 +436,13 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|||||||
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
||||||
"""Callback manager for chain run."""
|
"""Callback manager for chain run."""
|
||||||
|
|
||||||
def get_child(self) -> CallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager."""
|
||||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
|
manager.add_tags(self.inheritable_tags)
|
||||||
|
if tag is not None:
|
||||||
|
manager.add_tags([tag], False)
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
@ -487,10 +501,13 @@ class CallbackManagerForChainRun(RunManager, ChainManagerMixin):
|
|||||||
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
||||||
"""Async callback manager for chain run."""
|
"""Async callback manager for chain run."""
|
||||||
|
|
||||||
def get_child(self) -> AsyncCallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager."""
|
||||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
|
manager.add_tags(self.inheritable_tags)
|
||||||
|
if tag is not None:
|
||||||
|
manager.add_tags([tag], False)
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
|
||||||
@ -549,10 +566,13 @@ class AsyncCallbackManagerForChainRun(AsyncRunManager, ChainManagerMixin):
|
|||||||
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
||||||
"""Callback manager for tool run."""
|
"""Callback manager for tool run."""
|
||||||
|
|
||||||
def get_child(self) -> CallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> CallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager."""
|
||||||
manager = CallbackManager([], parent_run_id=self.run_id)
|
manager = CallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
|
manager.add_tags(self.inheritable_tags)
|
||||||
|
if tag is not None:
|
||||||
|
manager.add_tags([tag], False)
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
@ -591,10 +611,13 @@ class CallbackManagerForToolRun(RunManager, ToolManagerMixin):
|
|||||||
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
class AsyncCallbackManagerForToolRun(AsyncRunManager, ToolManagerMixin):
|
||||||
"""Async callback manager for tool run."""
|
"""Async callback manager for tool run."""
|
||||||
|
|
||||||
def get_child(self) -> AsyncCallbackManager:
|
def get_child(self, tag: Optional[str] = None) -> AsyncCallbackManager:
|
||||||
"""Get a child callback manager."""
|
"""Get a child callback manager."""
|
||||||
manager = AsyncCallbackManager([], parent_run_id=self.run_id)
|
manager = AsyncCallbackManager(handlers=[], parent_run_id=self.run_id)
|
||||||
manager.set_handlers(self.inheritable_handlers)
|
manager.set_handlers(self.inheritable_handlers)
|
||||||
|
manager.add_tags(self.inheritable_tags)
|
||||||
|
if tag is not None:
|
||||||
|
manager.add_tags([tag], False)
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
||||||
@ -648,11 +671,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
prompts,
|
prompts,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CallbackManagerForLLMRun(
|
return CallbackManagerForLLMRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
@ -673,13 +702,19 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
messages,
|
messages,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Re-use the LLM Run Manager since the outputs are treated
|
# Re-use the LLM Run Manager since the outputs are treated
|
||||||
# the same for now
|
# the same for now
|
||||||
return CallbackManagerForLLMRun(
|
return CallbackManagerForLLMRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
@ -701,11 +736,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
inputs,
|
inputs,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CallbackManagerForChainRun(
|
return CallbackManagerForChainRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
@ -728,11 +769,17 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
input_str,
|
input_str,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CallbackManagerForToolRun(
|
return CallbackManagerForToolRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -741,9 +788,18 @@ class CallbackManager(BaseCallbackManager):
|
|||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
local_callbacks: Callbacks = None,
|
local_callbacks: Callbacks = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> CallbackManager:
|
) -> CallbackManager:
|
||||||
"""Configure the callback manager."""
|
"""Configure the callback manager."""
|
||||||
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
|
return _configure(
|
||||||
|
cls,
|
||||||
|
inheritable_callbacks,
|
||||||
|
local_callbacks,
|
||||||
|
verbose,
|
||||||
|
inheritable_tags,
|
||||||
|
local_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncCallbackManager(BaseCallbackManager):
|
class AsyncCallbackManager(BaseCallbackManager):
|
||||||
@ -773,11 +829,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
prompts,
|
prompts,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AsyncCallbackManagerForLLMRun(
|
return AsyncCallbackManagerForLLMRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_chat_model_start(
|
async def on_chat_model_start(
|
||||||
@ -798,11 +860,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
messages,
|
messages,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AsyncCallbackManagerForLLMRun(
|
return AsyncCallbackManagerForLLMRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
@ -824,11 +892,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
inputs,
|
inputs,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AsyncCallbackManagerForChainRun(
|
return AsyncCallbackManagerForChainRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
@ -851,11 +925,17 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
input_str,
|
input_str,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
parent_run_id=self.parent_run_id,
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AsyncCallbackManagerForToolRun(
|
return AsyncCallbackManagerForToolRun(
|
||||||
run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
|
run_id=run_id,
|
||||||
|
handlers=self.handlers,
|
||||||
|
inheritable_handlers=self.inheritable_handlers,
|
||||||
|
parent_run_id=self.parent_run_id,
|
||||||
|
tags=self.tags,
|
||||||
|
inheritable_tags=self.inheritable_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -864,9 +944,18 @@ class AsyncCallbackManager(BaseCallbackManager):
|
|||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
local_callbacks: Callbacks = None,
|
local_callbacks: Callbacks = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> AsyncCallbackManager:
|
) -> AsyncCallbackManager:
|
||||||
"""Configure the callback manager."""
|
"""Configure the callback manager."""
|
||||||
return _configure(cls, inheritable_callbacks, local_callbacks, verbose)
|
return _configure(
|
||||||
|
cls,
|
||||||
|
inheritable_callbacks,
|
||||||
|
local_callbacks,
|
||||||
|
verbose,
|
||||||
|
inheritable_tags,
|
||||||
|
local_tags,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
|
||||||
@ -887,9 +976,11 @@ def _configure(
|
|||||||
inheritable_callbacks: Callbacks = None,
|
inheritable_callbacks: Callbacks = None,
|
||||||
local_callbacks: Callbacks = None,
|
local_callbacks: Callbacks = None,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
|
inheritable_tags: Optional[List[str]] = None,
|
||||||
|
local_tags: Optional[List[str]] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Configure the callback manager."""
|
"""Configure the callback manager."""
|
||||||
callback_manager = callback_manager_cls([])
|
callback_manager = callback_manager_cls(handlers=[])
|
||||||
if inheritable_callbacks or local_callbacks:
|
if inheritable_callbacks or local_callbacks:
|
||||||
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
if isinstance(inheritable_callbacks, list) or inheritable_callbacks is None:
|
||||||
inheritable_callbacks_ = inheritable_callbacks or []
|
inheritable_callbacks_ = inheritable_callbacks or []
|
||||||
@ -902,6 +993,8 @@ def _configure(
|
|||||||
handlers=inheritable_callbacks.handlers,
|
handlers=inheritable_callbacks.handlers,
|
||||||
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
|
inheritable_handlers=inheritable_callbacks.inheritable_handlers,
|
||||||
parent_run_id=inheritable_callbacks.parent_run_id,
|
parent_run_id=inheritable_callbacks.parent_run_id,
|
||||||
|
tags=inheritable_callbacks.tags,
|
||||||
|
inheritable_tags=inheritable_callbacks.inheritable_tags,
|
||||||
)
|
)
|
||||||
local_handlers_ = (
|
local_handlers_ = (
|
||||||
local_callbacks
|
local_callbacks
|
||||||
@ -910,6 +1003,9 @@ def _configure(
|
|||||||
)
|
)
|
||||||
for handler in local_handlers_:
|
for handler in local_handlers_:
|
||||||
callback_manager.add_handler(handler, False)
|
callback_manager.add_handler(handler, False)
|
||||||
|
if inheritable_tags or local_tags:
|
||||||
|
callback_manager.add_tags(inheritable_tags or [])
|
||||||
|
callback_manager.add_tags(local_tags or [], False)
|
||||||
|
|
||||||
tracer = tracing_callback_var.get()
|
tracer = tracing_callback_var.get()
|
||||||
wandb_tracer = wandb_tracing_callback_var.get()
|
wandb_tracer = wandb_tracing_callback_var.get()
|
||||||
|
@ -85,6 +85,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -101,6 +102,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
run_type=RunTypeEnum.llm,
|
run_type=RunTypeEnum.llm,
|
||||||
|
tags=tags or [],
|
||||||
)
|
)
|
||||||
self._start_trace(llm_run)
|
self._start_trace(llm_run)
|
||||||
self._on_llm_start(llm_run)
|
self._on_llm_start(llm_run)
|
||||||
@ -145,6 +147,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
inputs: Dict[str, Any],
|
inputs: Dict[str, Any],
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -162,6 +165,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
run_type=RunTypeEnum.chain,
|
run_type=RunTypeEnum.chain,
|
||||||
|
tags=tags or [],
|
||||||
)
|
)
|
||||||
self._start_trace(chain_run)
|
self._start_trace(chain_run)
|
||||||
self._on_chain_start(chain_run)
|
self._on_chain_start(chain_run)
|
||||||
@ -206,6 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
input_str: str,
|
input_str: str,
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -223,6 +228,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
run_type=RunTypeEnum.tool,
|
run_type=RunTypeEnum.tool,
|
||||||
|
tags=tags or [],
|
||||||
)
|
)
|
||||||
self._start_trace(tool_run)
|
self._start_trace(tool_run)
|
||||||
self._on_tool_start(tool_run)
|
self._on_tool_start(tool_run)
|
||||||
|
@ -59,6 +59,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
messages: List[List[BaseMessage]],
|
messages: List[List[BaseMessage]],
|
||||||
*,
|
*,
|
||||||
run_id: UUID,
|
run_id: UUID,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -75,6 +76,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
run_type=RunTypeEnum.llm,
|
run_type=RunTypeEnum.llm,
|
||||||
|
tags=tags,
|
||||||
)
|
)
|
||||||
self._start_trace(chat_model_run)
|
self._start_trace(chat_model_run)
|
||||||
self._on_chat_model_start(chat_model_run)
|
self._on_chat_model_start(chat_model_run)
|
||||||
|
@ -94,6 +94,7 @@ class Run(BaseRunV2):
|
|||||||
execution_order: int
|
execution_order: int
|
||||||
child_execution_order: int
|
child_execution_order: int
|
||||||
child_runs: List[Run] = Field(default_factory=list)
|
child_runs: List[Run] = Field(default_factory=list)
|
||||||
|
tags: Optional[List[str]] = Field(default_factory=list)
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def assign_name(cls, values: dict) -> dict:
|
def assign_name(cls, values: dict) -> dict:
|
||||||
|
@ -36,6 +36,7 @@ class Chain(Serializable, ABC):
|
|||||||
verbose: bool = Field(
|
verbose: bool = Field(
|
||||||
default_factory=_get_verbosity
|
default_factory=_get_verbosity
|
||||||
) # Whether to print the response text
|
) # Whether to print the response text
|
||||||
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
@ -111,6 +112,7 @@ class Chain(Serializable, ABC):
|
|||||||
return_only_outputs: bool = False,
|
return_only_outputs: bool = False,
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
include_run_info: bool = False,
|
include_run_info: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the logic of this chain and add to output if desired.
|
"""Run the logic of this chain and add to output if desired.
|
||||||
@ -129,7 +131,7 @@ class Chain(Serializable, ABC):
|
|||||||
"""
|
"""
|
||||||
inputs = self.prep_inputs(inputs)
|
inputs = self.prep_inputs(inputs)
|
||||||
callback_manager = CallbackManager.configure(
|
callback_manager = CallbackManager.configure(
|
||||||
callbacks, self.callbacks, self.verbose
|
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||||
)
|
)
|
||||||
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
@ -159,6 +161,7 @@ class Chain(Serializable, ABC):
|
|||||||
return_only_outputs: bool = False,
|
return_only_outputs: bool = False,
|
||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
*,
|
*,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
include_run_info: bool = False,
|
include_run_info: bool = False,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the logic of this chain and add to output if desired.
|
"""Run the logic of this chain and add to output if desired.
|
||||||
@ -177,7 +180,7 @@ class Chain(Serializable, ABC):
|
|||||||
"""
|
"""
|
||||||
inputs = self.prep_inputs(inputs)
|
inputs = self.prep_inputs(inputs)
|
||||||
callback_manager = AsyncCallbackManager.configure(
|
callback_manager = AsyncCallbackManager.configure(
|
||||||
callbacks, self.callbacks, self.verbose
|
callbacks, self.callbacks, self.verbose, tags, self.tags
|
||||||
)
|
)
|
||||||
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
|
||||||
run_manager = await callback_manager.on_chain_start(
|
run_manager = await callback_manager.on_chain_start(
|
||||||
@ -244,7 +247,13 @@ class Chain(Serializable, ABC):
|
|||||||
"""Call the chain on all inputs in the list."""
|
"""Call the chain on all inputs in the list."""
|
||||||
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
return [self(inputs, callbacks=callbacks) for inputs in input_list]
|
||||||
|
|
||||||
def run(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
def run(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||||
if len(self.output_keys) != 1:
|
if len(self.output_keys) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -255,10 +264,10 @@ class Chain(Serializable, ABC):
|
|||||||
if args and not kwargs:
|
if args and not kwargs:
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
raise ValueError("`run` supports only one positional argument.")
|
raise ValueError("`run` supports only one positional argument.")
|
||||||
return self(args[0], callbacks=callbacks)[self.output_keys[0]]
|
return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]]
|
||||||
|
|
||||||
if kwargs and not args:
|
if kwargs and not args:
|
||||||
return self(kwargs, callbacks=callbacks)[self.output_keys[0]]
|
return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]]
|
||||||
|
|
||||||
if not kwargs and not args:
|
if not kwargs and not args:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -271,7 +280,13 @@ class Chain(Serializable, ABC):
|
|||||||
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
f" but not both. Got args: {args} and kwargs: {kwargs}."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def arun(self, *args: Any, callbacks: Callbacks = None, **kwargs: Any) -> str:
|
async def arun(
|
||||||
|
self,
|
||||||
|
*args: Any,
|
||||||
|
callbacks: Callbacks = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
"""Run the chain as text in, text out or multiple variables, text out."""
|
"""Run the chain as text in, text out or multiple variables, text out."""
|
||||||
if len(self.output_keys) != 1:
|
if len(self.output_keys) != 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -282,10 +297,14 @@ class Chain(Serializable, ABC):
|
|||||||
if args and not kwargs:
|
if args and not kwargs:
|
||||||
if len(args) != 1:
|
if len(args) != 1:
|
||||||
raise ValueError("`run` supports only one positional argument.")
|
raise ValueError("`run` supports only one positional argument.")
|
||||||
return (await self.acall(args[0], callbacks=callbacks))[self.output_keys[0]]
|
return (await self.acall(args[0], callbacks=callbacks, tags=tags))[
|
||||||
|
self.output_keys[0]
|
||||||
|
]
|
||||||
|
|
||||||
if kwargs and not args:
|
if kwargs and not args:
|
||||||
return (await self.acall(kwargs, callbacks=callbacks))[self.output_keys[0]]
|
return (await self.acall(kwargs, callbacks=callbacks, tags=tags))[
|
||||||
|
self.output_keys[0]
|
||||||
|
]
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`run` supported with either positional arguments or keyword arguments"
|
f"`run` supported with either positional arguments or keyword arguments"
|
||||||
|
@ -98,7 +98,7 @@ class ConstitutionalChain(Chain):
|
|||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
response = self.chain.run(
|
response = self.chain.run(
|
||||||
**inputs,
|
**inputs,
|
||||||
callbacks=_run_manager.get_child(),
|
callbacks=_run_manager.get_child("original"),
|
||||||
)
|
)
|
||||||
initial_response = response
|
initial_response = response
|
||||||
input_prompt = self.chain.prompt.format(**inputs)
|
input_prompt = self.chain.prompt.format(**inputs)
|
||||||
@ -116,7 +116,7 @@ class ConstitutionalChain(Chain):
|
|||||||
input_prompt=input_prompt,
|
input_prompt=input_prompt,
|
||||||
output_from_model=response,
|
output_from_model=response,
|
||||||
critique_request=constitutional_principle.critique_request,
|
critique_request=constitutional_principle.critique_request,
|
||||||
callbacks=_run_manager.get_child(),
|
callbacks=_run_manager.get_child("critique"),
|
||||||
)
|
)
|
||||||
critique = self._parse_critique(
|
critique = self._parse_critique(
|
||||||
output_string=raw_critique,
|
output_string=raw_critique,
|
||||||
@ -137,7 +137,7 @@ class ConstitutionalChain(Chain):
|
|||||||
critique_request=constitutional_principle.critique_request,
|
critique_request=constitutional_principle.critique_request,
|
||||||
critique=critique,
|
critique=critique,
|
||||||
revision_request=constitutional_principle.revision_request,
|
revision_request=constitutional_principle.revision_request,
|
||||||
callbacks=_run_manager.get_child(),
|
callbacks=_run_manager.get_child("revision"),
|
||||||
).strip()
|
).strip()
|
||||||
response = revision
|
response = revision
|
||||||
critiques_and_revisions.append((critique, revision))
|
critiques_and_revisions.append((critique, revision))
|
||||||
|
@ -283,7 +283,7 @@ class LLMChain(Chain):
|
|||||||
return "llm_chain"
|
return "llm_chain"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, llm: BaseLanguageModel, template: str) -> Chain:
|
def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
|
||||||
"""Create LLMChain from LLM and template."""
|
"""Create LLMChain from LLM and template."""
|
||||||
prompt_template = PromptTemplate.from_template(template)
|
prompt_template = PromptTemplate.from_template(template)
|
||||||
return cls(llm=llm, prompt=prompt_template)
|
return cls(llm=llm, prompt=prompt_template)
|
||||||
|
@ -174,7 +174,7 @@ class SimpleSequentialChain(Chain):
|
|||||||
_input = inputs[self.input_key]
|
_input = inputs[self.input_key]
|
||||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||||
for i, chain in enumerate(self.chains):
|
for i, chain in enumerate(self.chains):
|
||||||
_input = chain.run(_input, callbacks=_run_manager.get_child())
|
_input = chain.run(_input, callbacks=_run_manager.get_child(f"step_{i+1}"))
|
||||||
if self.strip_outputs:
|
if self.strip_outputs:
|
||||||
_input = _input.strip()
|
_input = _input.strip()
|
||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
|
@ -13,6 +13,8 @@ from langchain.callbacks.manager import (
|
|||||||
tracing_v2_enabled,
|
tracing_v2_enabled,
|
||||||
)
|
)
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
|
from langchain.chains.constitutional_ai.base import ConstitutionalChain
|
||||||
|
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
@ -160,6 +162,25 @@ def test_tracing_v2_context_manager() -> None:
|
|||||||
agent.run(questions[0]) # this should not be traced
|
agent.run(questions[0]) # this should not be traced
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracing_v2_chain_with_tags() -> None:
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
chain = ConstitutionalChain.from_llm(
|
||||||
|
llm,
|
||||||
|
chain=LLMChain.from_string(llm, "Q: {question} A:"),
|
||||||
|
tags=["only-root"],
|
||||||
|
constitutional_principles=[
|
||||||
|
ConstitutionalPrinciple(
|
||||||
|
critique_request="Tell if this answer is good.",
|
||||||
|
revision_request="Give a better answer.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||||
|
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||||
|
with tracing_v2_enabled():
|
||||||
|
chain.run("what is the meaning of life", tags=["a-tag"])
|
||||||
|
|
||||||
|
|
||||||
def test_trace_as_group() -> None:
|
def test_trace_as_group() -> None:
|
||||||
llm = OpenAI(temperature=0.9)
|
llm = OpenAI(temperature=0.9)
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
|
@ -86,7 +86,7 @@ def test_callback_manager() -> None:
|
|||||||
"""Test the CallbackManager."""
|
"""Test the CallbackManager."""
|
||||||
handler1 = FakeCallbackHandler()
|
handler1 = FakeCallbackHandler()
|
||||||
handler2 = FakeCallbackHandler()
|
handler2 = FakeCallbackHandler()
|
||||||
manager = CallbackManager([handler1, handler2])
|
manager = CallbackManager(handlers=[handler1, handler2])
|
||||||
_test_callback_manager(manager, handler1, handler2)
|
_test_callback_manager(manager, handler1, handler2)
|
||||||
|
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ async def test_async_callback_manager() -> None:
|
|||||||
"""Test the AsyncCallbackManager."""
|
"""Test the AsyncCallbackManager."""
|
||||||
handler1 = FakeAsyncCallbackHandler()
|
handler1 = FakeAsyncCallbackHandler()
|
||||||
handler2 = FakeAsyncCallbackHandler()
|
handler2 = FakeAsyncCallbackHandler()
|
||||||
manager = AsyncCallbackManager([handler1, handler2])
|
manager = AsyncCallbackManager(handlers=[handler1, handler2])
|
||||||
await _test_callback_manager_async(manager, handler1, handler2)
|
await _test_callback_manager_async(manager, handler1, handler2)
|
||||||
|
|
||||||
|
|
||||||
@ -153,7 +153,7 @@ async def test_async_callback_manager_sync_handler() -> None:
|
|||||||
handler1 = FakeCallbackHandler()
|
handler1 = FakeCallbackHandler()
|
||||||
handler2 = FakeAsyncCallbackHandler()
|
handler2 = FakeAsyncCallbackHandler()
|
||||||
handler3 = FakeAsyncCallbackHandler()
|
handler3 = FakeAsyncCallbackHandler()
|
||||||
manager = AsyncCallbackManager([handler1, handler2, handler3])
|
manager = AsyncCallbackManager(handlers=[handler1, handler2, handler3])
|
||||||
await _test_callback_manager_async(manager, handler1, handler2, handler3)
|
await _test_callback_manager_async(manager, handler1, handler2, handler3)
|
||||||
|
|
||||||
|
|
||||||
@ -165,11 +165,11 @@ def test_callback_manager_inheritance() -> None:
|
|||||||
FakeCallbackHandler(),
|
FakeCallbackHandler(),
|
||||||
)
|
)
|
||||||
|
|
||||||
callback_manager1 = CallbackManager([handler1, handler2])
|
callback_manager1 = CallbackManager(handlers=[handler1, handler2])
|
||||||
assert callback_manager1.handlers == [handler1, handler2]
|
assert callback_manager1.handlers == [handler1, handler2]
|
||||||
assert callback_manager1.inheritable_handlers == []
|
assert callback_manager1.inheritable_handlers == []
|
||||||
|
|
||||||
callback_manager2 = CallbackManager([])
|
callback_manager2 = CallbackManager(handlers=[])
|
||||||
assert callback_manager2.handlers == []
|
assert callback_manager2.handlers == []
|
||||||
assert callback_manager2.inheritable_handlers == []
|
assert callback_manager2.inheritable_handlers == []
|
||||||
|
|
||||||
@ -229,7 +229,7 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
|
assert isinstance(configured_manager.handlers[4], StdOutCallbackHandler)
|
||||||
assert isinstance(configured_manager, CallbackManager)
|
assert isinstance(configured_manager, CallbackManager)
|
||||||
|
|
||||||
async_local_callbacks = AsyncCallbackManager([handler3, handler4])
|
async_local_callbacks = AsyncCallbackManager(handlers=[handler3, handler4])
|
||||||
async_configured_manager = AsyncCallbackManager.configure(
|
async_configured_manager = AsyncCallbackManager.configure(
|
||||||
inheritable_callbacks=inheritable_callbacks,
|
inheritable_callbacks=inheritable_callbacks,
|
||||||
local_callbacks=async_local_callbacks,
|
local_callbacks=async_local_callbacks,
|
||||||
|
Loading…
Reference in New Issue
Block a user