Add support for showing IO to chain group (#10510)

As well as error propagation
pull/10698/head
William FH 1 year ago committed by GitHub
parent 2c957de2fc
commit c5078fb13c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -213,17 +213,20 @@ def trace_as_chain_group(
group_name: str,
callback_manager: Optional[CallbackManager] = None,
*,
inputs: Optional[Dict[str, Any]] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
) -> Generator[CallbackManager, None, None]:
) -> Generator[CallbackManagerForChainGroup, None, None]:
"""Get a callback manager for a chain group in a context manager.
Useful for grouping different calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
callback_manager (CallbackManager, optional): The callback manager to use.
inputs (Dict[str, Any], optional): The inputs to the chain group.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
@ -233,13 +236,17 @@ def trace_as_chain_group(
Defaults to None.
Returns:
CallbackManager: The callback manager for the chain group.
CallbackManagerForChainGroup: The callback manager for the chain group.
Example:
>>> with trace_as_chain_group("group_name") as manager:
... # Use the callback manager for the chain group
... llm.predict("Foo", callbacks=manager)
"""
.. code-block:: python
llm_input = "Foo"
with trace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the callback manager for the chain group
res = llm.predict(llm_input, callbacks=manager)
manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast(
Callbacks,
[
@ -256,9 +263,27 @@ def trace_as_chain_group(
inheritable_tags=tags,
)
run_manager = cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
yield run_manager.get_child()
run_manager.on_chain_end({})
run_manager = cm.on_chain_start({"name": group_name}, inputs or {}, run_id=run_id)
child_cm = run_manager.get_child()
group_cm = CallbackManagerForChainGroup(
child_cm.handlers,
child_cm.inheritable_handlers,
child_cm.parent_run_id,
parent_run_manager=run_manager,
tags=child_cm.tags,
inheritable_tags=child_cm.inheritable_tags,
metadata=child_cm.metadata,
inheritable_metadata=child_cm.inheritable_metadata,
)
try:
yield group_cm
except Exception as e:
if not group_cm.ended:
run_manager.on_chain_error(e)
raise e
else:
if not group_cm.ended:
run_manager.on_chain_end({})
@asynccontextmanager
@ -266,17 +291,20 @@ async def atrace_as_chain_group(
group_name: str,
callback_manager: Optional[AsyncCallbackManager] = None,
*,
inputs: Optional[Dict[str, Any]] = None,
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
run_id: Optional[UUID] = None,
tags: Optional[List[str]] = None,
) -> AsyncGenerator[AsyncCallbackManager, None]:
) -> AsyncGenerator[AsyncCallbackManagerForChainGroup, None]:
"""Get an async callback manager for a chain group in a context manager.
Useful for grouping different async calls together as a single run even if
they aren't composed in a single chain.
Args:
group_name (str): The name of the chain group.
callback_manager (AsyncCallbackManager, optional): The async callback manager to use,
which manages tracing and other callback behavior.
project_name (str, optional): The name of the project.
Defaults to None.
example_id (str or UUID, optional): The ID of the example.
@ -288,10 +316,14 @@ async def atrace_as_chain_group(
AsyncCallbackManager: The async callback manager for the chain group.
Example:
>>> async with atrace_as_chain_group("group_name") as manager:
... # Use the async callback manager for the chain group
... await llm.apredict("Foo", callbacks=manager)
"""
.. code-block:: python
llm_input = "Foo"
async with atrace_as_chain_group("group_name", inputs={"input": llm_input}) as manager:
# Use the async callback manager for the chain group
res = await llm.apredict(llm_input, callbacks=manager)
await manager.on_chain_end({"output": res})
""" # noqa: E501
cb = cast(
Callbacks,
[
@ -305,11 +337,29 @@ async def atrace_as_chain_group(
)
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
run_manager = await cm.on_chain_start({"name": group_name}, {}, run_id=run_id)
run_manager = await cm.on_chain_start(
{"name": group_name}, inputs or {}, run_id=run_id
)
child_cm = run_manager.get_child()
group_cm = AsyncCallbackManagerForChainGroup(
child_cm.handlers,
child_cm.inheritable_handlers,
child_cm.parent_run_id,
parent_run_manager=run_manager,
tags=child_cm.tags,
inheritable_tags=child_cm.inheritable_tags,
metadata=child_cm.metadata,
inheritable_metadata=child_cm.inheritable_metadata,
)
try:
yield run_manager.get_child()
finally:
await run_manager.on_chain_end({})
yield group_cm
except Exception as e:
if not group_cm.ended:
await run_manager.on_chain_error(e)
raise e
else:
if not group_cm.ended:
await run_manager.on_chain_end({})
def _handle_event(
@ -1342,6 +1392,48 @@ class CallbackManager(BaseCallbackManager):
)
class CallbackManagerForChainGroup(CallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
*,
parent_run_manager: CallbackManagerForChainRun,
**kwargs: Any,
) -> None:
super().__init__(
handlers,
inheritable_handlers,
parent_run_id,
**kwargs,
)
self.parent_run_manager = parent_run_manager
self.ended = False
def on_chain_end(self, outputs: Union[Dict[str, Any], Any], **kwargs: Any) -> None:
"""Run when traced chain group ends.
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
self.ended = True
return self.parent_run_manager.on_chain_end(outputs, **kwargs)
def on_chain_error(
self,
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
self.ended = True
return self.parent_run_manager.on_chain_error(error, **kwargs)
class AsyncCallbackManager(BaseCallbackManager):
"""Async callback manager that handles callbacks from LangChain."""
@ -1634,6 +1726,50 @@ class AsyncCallbackManager(BaseCallbackManager):
)
class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
def __init__(
self,
handlers: List[BaseCallbackHandler],
inheritable_handlers: List[BaseCallbackHandler] | None = None,
parent_run_id: UUID | None = None,
*,
parent_run_manager: AsyncCallbackManagerForChainRun,
**kwargs: Any,
) -> None:
super().__init__(
handlers,
inheritable_handlers,
parent_run_id,
**kwargs,
)
self.parent_run_manager = parent_run_manager
self.ended = False
async def on_chain_end(
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
) -> None:
"""Run when traced chain group ends.
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
self.ended = True
await self.parent_run_manager.on_chain_end(outputs, **kwargs)
async def on_chain_error(
self,
error: BaseException,
**kwargs: Any,
) -> None:
"""Run when chain errors.
Args:
error (Exception or KeyboardInterrupt): The error.
"""
self.ended = True
await self.parent_run_manager.on_chain_error(error, **kwargs)
T = TypeVar("T", CallbackManager, AsyncCallbackManager)

@ -222,13 +222,15 @@ def test_trace_as_group() -> None:
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
with trace_as_chain_group("my_group", inputs={"input": "cars"}) as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group("my_group_2", inputs={"input": "toys"}) as group_manager:
final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
def test_trace_as_group_with_env_set() -> None:
@ -239,13 +241,19 @@ def test_trace_as_group_with_env_set() -> None:
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
with trace_as_chain_group("my_group") as group_manager:
with trace_as_chain_group(
"my_group_env_set", inputs={"input": "cars"}
) as group_manager:
chain.run(product="cars", callbacks=group_manager)
chain.run(product="computers", callbacks=group_manager)
chain.run(product="toys", callbacks=group_manager)
final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
with trace_as_chain_group("my_group_2") as group_manager:
chain.run(product="toys", callbacks=group_manager)
with trace_as_chain_group(
"my_group_2_env_set", inputs={"input": "toys"}
) as group_manager:
final_res = chain.run(product="toys", callbacks=group_manager)
group_manager.on_chain_end({"output": final_res})
@pytest.mark.asyncio
@ -256,16 +264,19 @@ async def test_trace_as_group_async() -> None:
template="What is a good name for a company that makes {product}?",
)
chain = LLMChain(llm=llm, prompt=prompt)
async with atrace_as_chain_group("my_group") as group_manager:
async with atrace_as_chain_group("my_async_group") as group_manager:
await chain.arun(product="cars", callbacks=group_manager)
await chain.arun(product="computers", callbacks=group_manager)
await chain.arun(product="toys", callbacks=group_manager)
async with atrace_as_chain_group("my_group_2") as group_manager:
await asyncio.gather(
async with atrace_as_chain_group(
"my_async_group_2", inputs={"input": "toys"}
) as group_manager:
res = await asyncio.gather(
*[
chain.arun(product="toys", callbacks=group_manager),
chain.arun(product="computers", callbacks=group_manager),
chain.arun(product="cars", callbacks=group_manager),
]
)
await group_manager.on_chain_end({"output": res})

Loading…
Cancel
Save