From c5078fb13c768104962d2a28e7f65c527d9c4dd0 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Sun, 17 Sep 2023 00:47:51 -0700 Subject: [PATCH] Add support for showing IO to chain group (#10510) As well as error propagation --- libs/langchain/langchain/callbacks/manager.py | 172 ++++++++++++++++-- .../callbacks/test_langchain_tracer.py | 33 ++-- 2 files changed, 176 insertions(+), 29 deletions(-) diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index ccd84a85e6..eb4e0ebdec 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -213,17 +213,20 @@ def trace_as_chain_group( group_name: str, callback_manager: Optional[CallbackManager] = None, *, + inputs: Optional[Dict[str, Any]] = None, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, -) -> Generator[CallbackManager, None, None]: +) -> Generator[CallbackManagerForChainGroup, None, None]: """Get a callback manager for a chain group in a context manager. Useful for grouping different calls together as a single run even if they aren't composed in a single chain. Args: group_name (str): The name of the chain group. + callback_manager (CallbackManager, optional): The callback manager to use. + inputs (Dict[str, Any], optional): The inputs to the chain group. project_name (str, optional): The name of the project. Defaults to None. example_id (str or UUID, optional): The ID of the example. @@ -233,13 +236,17 @@ def trace_as_chain_group( Defaults to None. Returns: - CallbackManager: The callback manager for the chain group. + CallbackManagerForChainGroup: The callback manager for the chain group. Example: - >>> with trace_as_chain_group("group_name") as manager: - ... # Use the callback manager for the chain group - ... llm.predict("Foo", callbacks=manager) - """ + .. code-block:: python + + llm_input = "Foo" + with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the callback manager for the chain group + res = llm.predict(llm_input, callbacks=manager) + manager.on_chain_end({"output": res}) + """ # noqa: E501 cb = cast( Callbacks, [ @@ -256,9 +263,27 @@ def trace_as_chain_group( inheritable_tags=tags, ) - run_manager = cm.on_chain_start({"name": group_name}, {}, run_id=run_id) - yield run_manager.get_child() - run_manager.on_chain_end({}) + run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id) + child_cm = run_manager.get_child() + group_cm = CallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) + try: + yield group_cm + except Exception as e: + if not group_cm.ended: + run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + run_manager.on_chain_end({}) @asynccontextmanager @@ -266,17 +291,20 @@ async def atrace_as_chain_group( group_name: str, callback_manager: Optional[AsyncCallbackManager] = None, *, + inputs: Optional[Dict[str, Any]] = None, project_name: Optional[str] = None, example_id: Optional[Union[str, UUID]] = None, run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, -) -> AsyncGenerator[AsyncCallbackManager, None]: +) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]: """Get an async callback manager for a chain group in a context manager. Useful for grouping different async calls together as a single run even if they aren't composed in a single chain. Args: group_name (str): The name of the chain group. + callback_manager (AsyncCallbackManager, optional): The async callback manager to use, + which manages tracing and other callback behavior. project_name (str, optional): The name of the project. Defaults to None. example_id (str or UUID, optional): The ID of the example. @@ -288,10 +316,14 @@ async def atrace_as_chain_group( AsyncCallbackManager: The async callback manager for the chain group. Example: - >>> async with atrace_as_chain_group("group_name") as manager: - ... # Use the async callback manager for the chain group - ... await llm.apredict("Foo", callbacks=manager) - """ + .. code-block:: python + + llm_input = "Foo" + async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager: + # Use the async callback manager for the chain group + res = await llm.apredict(llm_input, callbacks=manager) + await manager.on_chain_end({"output": res}) + """ # noqa: E501 cb = cast( Callbacks, [ @@ -305,11 +337,29 @@ async def atrace_as_chain_group( ) cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) - run_manager = await cm.on_chain_start({"name": group_name}, {}, run_id=run_id) + run_manager = await cm.on_chain_start( + {"name": group_name}, inputs or {}, run_id=run_id + ) + child_cm = run_manager.get_child() + group_cm = AsyncCallbackManagerForChainGroup( + child_cm.handlers, + child_cm.inheritable_handlers, + child_cm.parent_run_id, + parent_run_manager=run_manager, + tags=child_cm.tags, + inheritable_tags=child_cm.inheritable_tags, + metadata=child_cm.metadata, + inheritable_metadata=child_cm.inheritable_metadata, + ) try: - yield run_manager.get_child() - finally: - await run_manager.on_chain_end({}) + yield group_cm + except Exception as e: + if not group_cm.ended: + await run_manager.on_chain_error(e) + raise e + else: + if not group_cm.ended: + await run_manager.on_chain_end({}) def _handle_event( @@ -1342,6 +1392,48 @@ class CallbackManager(BaseCallbackManager): ) +class CallbackManagerForChainGroup(CallbackManager): + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler] | None = None, + parent_run_id: UUID | None = None, + *, + parent_run_manager: CallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + return self.parent_run_manager.on_chain_end(outputs, **kwargs) + + def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + return self.parent_run_manager.on_chain_error(error, **kwargs) + + class AsyncCallbackManager(BaseCallbackManager): """Async callback manager that handles callbacks from LangChain.""" @@ -1634,6 +1726,50 @@ class AsyncCallbackManager(BaseCallbackManager): ) +class AsyncCallbackManagerForChainGroup(AsyncCallbackManager): + def __init__( + self, + handlers: List[BaseCallbackHandler], + inheritable_handlers: List[BaseCallbackHandler] | None = None, + parent_run_id: UUID | None = None, + *, + parent_run_manager: AsyncCallbackManagerForChainRun, + **kwargs: Any, + ) -> None: + super().__init__( + handlers, + inheritable_handlers, + parent_run_id, + **kwargs, + ) + self.parent_run_manager = parent_run_manager + self.ended = False + + async def on_chain_end( + self, outputs: Union[Dict[str, Any], Any], **kwargs: Any + ) -> None: + """Run when traced chain group ends. + + Args: + outputs (Union[Dict[str, Any], Any]): The outputs of the chain. + """ + self.ended = True + await self.parent_run_manager.on_chain_end(outputs, **kwargs) + + async def on_chain_error( + self, + error: BaseException, + **kwargs: Any, + ) -> None: + """Run when chain errors. + + Args: + error (Exception or KeyboardInterrupt): The error. + """ + self.ended = True + await self.parent_run_manager.on_chain_error(error, **kwargs) + + T = TypeVar("T", CallbackManager, AsyncCallbackManager) diff --git a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py b/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py index dde02bbd78..e84eae5aa5 100644 --- a/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py +++ b/libs/langchain/tests/integration_tests/callbacks/test_langchain_tracer.py @@ -222,13 +222,15 @@ def test_trace_as_group() -> None: template="What is a good name for a company that makes {product}?", ) chain = LLMChain(llm=llm, prompt=prompt) - with trace_as_chain_group("my_group") as group_manager: + with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager: chain.run(product="cars", callbacks=group_manager) chain.run(product="computers", callbacks=group_manager) - chain.run(product="toys", callbacks=group_manager) + final_res = chain.run(product="toys", callbacks=group_manager) + group_manager.on_chain_end({"output": final_res}) - with trace_as_chain_group("my_group_2") as group_manager: - chain.run(product="toys", callbacks=group_manager) + with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager: + final_res = chain.run(product="toys", callbacks=group_manager) + group_manager.on_chain_end({"output": final_res}) def test_trace_as_group_with_env_set() -> None: @@ -239,13 +241,19 @@ def test_trace_as_group_with_env_set() -> None: template="What is a good name for a company that makes {product}?", ) chain = LLMChain(llm=llm, prompt=prompt) - with trace_as_chain_group("my_group") as group_manager: + with trace_as_chain_group( + "my_group_env_set", inputs={"input": "cars"} + ) as group_manager: chain.run(product="cars", callbacks=group_manager) chain.run(product="computers", callbacks=group_manager) - chain.run(product="toys", callbacks=group_manager) + final_res = chain.run(product="toys", callbacks=group_manager) + group_manager.on_chain_end({"output": final_res}) - with trace_as_chain_group("my_group_2") as group_manager: - chain.run(product="toys", callbacks=group_manager) + with trace_as_chain_group( + "my_group_2_env_set", inputs={"input": "toys"} + ) as group_manager: + final_res = chain.run(product="toys", callbacks=group_manager) + group_manager.on_chain_end({"output": final_res}) @pytest.mark.asyncio @@ -256,16 +264,19 @@ async def test_trace_as_group_async() -> None: template="What is a good name for a company that makes {product}?", ) chain = LLMChain(llm=llm, prompt=prompt) - async with atrace_as_chain_group("my_group") as group_manager: + async with atrace_as_chain_group("my_async_group") as group_manager: await chain.arun(product="cars", callbacks=group_manager) await chain.arun(product="computers", callbacks=group_manager) await chain.arun(product="toys", callbacks=group_manager) - async with atrace_as_chain_group("my_group_2") as group_manager: - await asyncio.gather( + async with atrace_as_chain_group( + "my_async_group_2", inputs={"input": "toys"} + ) as group_manager: + res = await asyncio.gather( *[ chain.arun(product="toys", callbacks=group_manager), chain.arun(product="computers", callbacks=group_manager), chain.arun(product="cars", callbacks=group_manager), ] ) + await group_manager.on_chain_end({"output": res})