From 1e40427755f3034c5c411c1d0a921cdb3e13849d Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Fri, 14 Jul 2023 10:03:16 +0100 Subject: [PATCH] Enabled nesting chain group (#7697) --- langchain/callbacks/manager.py | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index bd0aa12632..b40afcaaef 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -177,6 +177,7 @@ def tracing_v2_enabled( @contextmanager def trace_as_chain_group( group_name: str, + callback_manager: Optional[CallbackManager] = None, *, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, @@ -203,12 +204,19 @@ def trace_as_chain_group( ... # Use the callback manager for the chain group ... llm.predict("Foo", callbacks=manager) """ - cb = LangChainTracer( - project_name=project_name, - example_id=example_id, + cb = cast( + Callbacks, + [ + LangChainTracer( + project_name=project_name, + example_id=example_id, + ) + ] + if callback_manager is None + else callback_manager, ) cm = CallbackManager.configure( - inheritable_callbacks=[cb], + inheritable_callbacks=cb, inheritable_tags=tags, ) @@ -220,6 +228,7 @@ def trace_as_chain_group( @asynccontextmanager async def atrace_as_chain_group( group_name: str, + callback_manager: Optional[AsyncCallbackManager] = None, *, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, @@ -245,13 +254,18 @@ async def atrace_as_chain_group( ... # Use the async callback manager for the chain group ... await llm.apredict("Foo", callbacks=manager) """ - cb = LangChainTracer( - project_name=project_name, - example_id=example_id, - ) - cm = AsyncCallbackManager.configure( - inheritable_callbacks=[cb], inheritable_tags=tags + cb = cast( + Callbacks, + [ + LangChainTracer( + project_name=project_name, + example_id=example_id, + ) + ] + if callback_manager is None + else callback_manager, ) + cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) run_manager = await cm.on_chain_start({"name": group_name}, {}) try: