Expose handle_event and ahandle_events as public API (#12181)

Expose functionality to handle generic events.
pull/12232/head
Eugene Yurtsev 9 months ago committed by GitHub
parent 67c4fd0ad0
commit 079d1f3b8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -374,14 +374,25 @@ async def atrace_as_chain_group(
await run_manager.on_chain_end({})
def _handle_event(
def handle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for CallbackManager."""
"""Generic event handler for CallbackManager.
Note: This function is used by langserve to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
coros: List[Coroutine[Any, Any, Any]] = []
try:
@ -398,7 +409,7 @@ def _handle_event(
if event_name == "on_chat_model_start":
if message_strings is None:
message_strings = [get_buffer_string(m) for m in args[1]]
_handle_event(
handle_event(
[handler],
"on_llm_start",
"ignore_llm",
@ -508,14 +519,25 @@ async def _ahandle_event_for_handler(
raise e
async def _ahandle_event(
async def ahandle_event(
handlers: List[BaseCallbackHandler],
event_name: str,
ignore_condition_name: Optional[str],
*args: Any,
**kwargs: Any,
) -> None:
"""Generic event handler for AsyncCallbackManager."""
"""Generic event handler for AsyncCallbackManager.
Note: This function is used by langserve to handle events.
Args:
handlers: The list of handlers that will handle the event
event_name: The name of the event (e.g., "on_llm_start")
ignore_condition_name: Name of the attribute defined on handler
that if True will cause the handler to be skipped for the given event
*args: The arguments to pass to the event handler
**kwargs: The keyword arguments to pass to the event handler
"""
for handler in [h for h in handlers if h.run_inline]:
await _ahandle_event_for_handler(
handler, event_name, ignore_condition_name, *args, **kwargs
@ -606,7 +628,7 @@ class RunManager(BaseRunManager):
Returns:
Any: The result of the callback.
"""
_handle_event(
handle_event(
self.handlers,
"on_text",
None,
@ -622,7 +644,7 @@ class RunManager(BaseRunManager):
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
_handle_event(
handle_event(
self.handlers,
"on_retry",
"ignore_retry",
@ -672,7 +694,7 @@ class AsyncRunManager(BaseRunManager):
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_text",
None,
@ -688,7 +710,7 @@ class AsyncRunManager(BaseRunManager):
retry_state: RetryCallState,
**kwargs: Any,
) -> None:
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_retry",
"ignore_retry",
@ -737,7 +759,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
token (str): The new token.
"""
_handle_event(
handle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
@ -755,7 +777,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
response (LLMResult): The LLM result.
"""
_handle_event(
handle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
@ -776,7 +798,7 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
handle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
@ -803,7 +825,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
token (str): The new token.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_llm_new_token",
"ignore_llm",
@ -821,7 +843,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
response (LLMResult): The LLM result.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_llm_end",
"ignore_llm",
@ -842,7 +864,7 @@ class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_llm_error",
"ignore_llm",
@ -863,7 +885,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
_handle_event(
handle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
@ -884,7 +906,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
handle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
@ -904,7 +926,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
_handle_event(
handle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
@ -924,7 +946,7 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
_handle_event(
handle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
@ -947,7 +969,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
outputs (Union[Dict[str, Any], Any]): The outputs of the chain.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_chain_end",
"ignore_chain",
@ -968,7 +990,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_chain_error",
"ignore_chain",
@ -988,7 +1010,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_agent_action",
"ignore_agent",
@ -1008,7 +1030,7 @@ class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
Returns:
Any: The result of the callback.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_agent_finish",
"ignore_agent",
@ -1033,7 +1055,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args:
output (str): The output of the tool.
"""
_handle_event(
handle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
@ -1054,7 +1076,7 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
_handle_event(
handle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
@ -1075,7 +1097,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args:
output (str): The output of the tool.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_tool_end",
"ignore_agent",
@ -1096,7 +1118,7 @@ class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
Args:
error (Exception or KeyboardInterrupt): The error.
"""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_tool_error",
"ignore_agent",
@ -1117,7 +1139,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
**kwargs: Any,
) -> None:
"""Run when retriever ends running."""
_handle_event(
handle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
@ -1134,7 +1156,7 @@ class CallbackManagerForRetrieverRun(ParentRunManager, RetrieverManagerMixin):
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
_handle_event(
handle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
@ -1156,7 +1178,7 @@ class AsyncCallbackManagerForRetrieverRun(
self, documents: Sequence[Document], **kwargs: Any
) -> None:
"""Run when retriever ends running."""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_retriever_end",
"ignore_retriever",
@ -1173,7 +1195,7 @@ class AsyncCallbackManagerForRetrieverRun(
**kwargs: Any,
) -> None:
"""Run when retriever errors."""
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_retriever_error",
"ignore_retriever",
@ -1208,7 +1230,7 @@ class CallbackManager(BaseCallbackManager):
managers = []
for prompt in prompts:
run_id_ = uuid.uuid4()
_handle_event(
handle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
@ -1257,7 +1279,7 @@ class CallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid.uuid4()
_handle_event(
handle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
@ -1304,7 +1326,7 @@ class CallbackManager(BaseCallbackManager):
"""
if run_id is None:
run_id = uuid.uuid4()
_handle_event(
handle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
@ -1350,7 +1372,7 @@ class CallbackManager(BaseCallbackManager):
if run_id is None:
run_id = uuid.uuid4()
_handle_event(
handle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
@ -1386,7 +1408,7 @@ class CallbackManager(BaseCallbackManager):
if run_id is None:
run_id = uuid.uuid4()
_handle_event(
handle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",
@ -1531,7 +1553,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id_ = uuid.uuid4()
tasks.append(
_ahandle_event(
ahandle_event(
self.handlers,
"on_llm_start",
"ignore_llm",
@ -1587,7 +1609,7 @@ class AsyncCallbackManager(BaseCallbackManager):
run_id_ = uuid.uuid4()
tasks.append(
_ahandle_event(
ahandle_event(
self.handlers,
"on_chat_model_start",
"ignore_chat_model",
@ -1638,7 +1660,7 @@ class AsyncCallbackManager(BaseCallbackManager):
if run_id is None:
run_id = uuid.uuid4()
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_chain_start",
"ignore_chain",
@ -1686,7 +1708,7 @@ class AsyncCallbackManager(BaseCallbackManager):
if run_id is None:
run_id = uuid.uuid4()
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_tool_start",
"ignore_agent",
@ -1722,7 +1744,7 @@ class AsyncCallbackManager(BaseCallbackManager):
if run_id is None:
run_id = uuid.uuid4()
await _ahandle_event(
await ahandle_event(
self.handlers,
"on_retriever_start",
"ignore_retriever",

Loading…
Cancel
Save