|
|
|
@ -4,13 +4,15 @@ import asyncio
|
|
|
|
|
import functools
|
|
|
|
|
import logging
|
|
|
|
|
import uuid
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
from contextlib import asynccontextmanager, contextmanager
|
|
|
|
|
from contextvars import Context, copy_context
|
|
|
|
|
from contextvars import copy_context
|
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Any,
|
|
|
|
|
AsyncGenerator,
|
|
|
|
|
Callable,
|
|
|
|
|
Coroutine,
|
|
|
|
|
Dict,
|
|
|
|
|
Generator,
|
|
|
|
@ -272,25 +274,14 @@ def handle_event(
|
|
|
|
|
# we end up in a deadlock, as we'd have gotten here from a
|
|
|
|
|
# running coroutine, which we cannot interrupt to run this one.
|
|
|
|
|
# The solution is to create a new loop in a new thread.
|
|
|
|
|
with _executor_w_context(1) as executor:
|
|
|
|
|
executor.submit(_run_coros, coros).result()
|
|
|
|
|
with ThreadPoolExecutor(1) as executor:
|
|
|
|
|
executor.submit(
|
|
|
|
|
cast(Callable, copy_context().run), _run_coros, coros
|
|
|
|
|
).result()
|
|
|
|
|
else:
|
|
|
|
|
_run_coros(coros)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _set_context(context: Context) -> None:
|
|
|
|
|
for var, value in context.items():
|
|
|
|
|
var.set(value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _executor_w_context(max_workers: Optional[int] = None) -> ThreadPoolExecutor:
|
|
|
|
|
return ThreadPoolExecutor(
|
|
|
|
|
max_workers=max_workers,
|
|
|
|
|
initializer=_set_context,
|
|
|
|
|
initargs=(copy_context(),),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
|
|
|
|
if hasattr(asyncio, "Runner"):
|
|
|
|
|
# Python 3.11+
|
|
|
|
@ -315,7 +306,6 @@ def _run_coros(coros: List[Coroutine[Any, Any, Any]]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _ahandle_event_for_handler(
|
|
|
|
|
executor: ThreadPoolExecutor,
|
|
|
|
|
handler: BaseCallbackHandler,
|
|
|
|
|
event_name: str,
|
|
|
|
|
ignore_condition_name: Optional[str],
|
|
|
|
@ -332,13 +322,18 @@ async def _ahandle_event_for_handler(
|
|
|
|
|
event(*args, **kwargs)
|
|
|
|
|
else:
|
|
|
|
|
await asyncio.get_event_loop().run_in_executor(
|
|
|
|
|
executor, functools.partial(event, *args, **kwargs)
|
|
|
|
|
None,
|
|
|
|
|
cast(
|
|
|
|
|
Callable,
|
|
|
|
|
functools.partial(
|
|
|
|
|
copy_context().run, event, *args, **kwargs
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
except NotImplementedError as e:
|
|
|
|
|
if event_name == "on_chat_model_start":
|
|
|
|
|
message_strings = [get_buffer_string(m) for m in args[1]]
|
|
|
|
|
await _ahandle_event_for_handler(
|
|
|
|
|
executor,
|
|
|
|
|
handler,
|
|
|
|
|
"on_llm_start",
|
|
|
|
|
"ignore_llm",
|
|
|
|
@ -380,25 +375,23 @@ async def ahandle_event(
|
|
|
|
|
*args: The arguments to pass to the event handler
|
|
|
|
|
**kwargs: The keyword arguments to pass to the event handler
|
|
|
|
|
"""
|
|
|
|
|
with _executor_w_context() as executor:
|
|
|
|
|
for handler in [h for h in handlers if h.run_inline]:
|
|
|
|
|
await _ahandle_event_for_handler(
|
|
|
|
|
executor, handler, event_name, ignore_condition_name, *args, **kwargs
|
|
|
|
|
)
|
|
|
|
|
await asyncio.gather(
|
|
|
|
|
*(
|
|
|
|
|
_ahandle_event_for_handler(
|
|
|
|
|
executor,
|
|
|
|
|
handler,
|
|
|
|
|
event_name,
|
|
|
|
|
ignore_condition_name,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
for handler in handlers
|
|
|
|
|
if not handler.run_inline
|
|
|
|
|
for handler in [h for h in handlers if h.run_inline]:
|
|
|
|
|
await _ahandle_event_for_handler(
|
|
|
|
|
handler, event_name, ignore_condition_name, *args, **kwargs
|
|
|
|
|
)
|
|
|
|
|
await asyncio.gather(
|
|
|
|
|
*(
|
|
|
|
|
_ahandle_event_for_handler(
|
|
|
|
|
handler,
|
|
|
|
|
event_name,
|
|
|
|
|
ignore_condition_name,
|
|
|
|
|
*args,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
for handler in handlers
|
|
|
|
|
if not handler.run_inline
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
BRM = TypeVar("BRM", bound="BaseRunManager")
|
|
|
|
@ -526,9 +519,17 @@ class ParentRunManager(RunManager):
|
|
|
|
|
return manager
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncRunManager(BaseRunManager):
|
|
|
|
|
class AsyncRunManager(BaseRunManager, ABC):
|
|
|
|
|
"""Async Run Manager."""
|
|
|
|
|
|
|
|
|
|
@abstractmethod
|
|
|
|
|
def get_sync(self) -> RunManager:
|
|
|
|
|
"""Get the equivalent sync RunManager.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
RunManager: The sync RunManager.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
async def on_text(
|
|
|
|
|
self,
|
|
|
|
|
text: str,
|
|
|
|
@ -664,6 +665,23 @@ class CallbackManagerForLLMRun(RunManager, LLMManagerMixin):
|
|
|
|
|
class AsyncCallbackManagerForLLMRun(AsyncRunManager, LLMManagerMixin):
|
|
|
|
|
"""Async callback manager for LLM run."""
|
|
|
|
|
|
|
|
|
|
def get_sync(self) -> CallbackManagerForLLMRun:
|
|
|
|
|
"""Get the equivalent sync RunManager.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
CallbackManagerForLLMRun: The sync RunManager.
|
|
|
|
|
"""
|
|
|
|
|
return CallbackManagerForLLMRun(
|
|
|
|
|
run_id=self.run_id,
|
|
|
|
|
handlers=self.handlers,
|
|
|
|
|
inheritable_handlers=self.inheritable_handlers,
|
|
|
|
|
parent_run_id=self.parent_run_id,
|
|
|
|
|
tags=self.tags,
|
|
|
|
|
inheritable_tags=self.inheritable_tags,
|
|
|
|
|
metadata=self.metadata,
|
|
|
|
|
inheritable_metadata=self.inheritable_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def on_llm_new_token(
|
|
|
|
|
self,
|
|
|
|
|
token: str,
|
|
|
|
@ -818,6 +836,23 @@ class CallbackManagerForChainRun(ParentRunManager, ChainManagerMixin):
|
|
|
|
|
class AsyncCallbackManagerForChainRun(AsyncParentRunManager, ChainManagerMixin):
|
|
|
|
|
"""Async callback manager for chain run."""
|
|
|
|
|
|
|
|
|
|
def get_sync(self) -> CallbackManagerForChainRun:
|
|
|
|
|
"""Get the equivalent sync RunManager.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
CallbackManagerForChainRun: The sync RunManager.
|
|
|
|
|
"""
|
|
|
|
|
return CallbackManagerForChainRun(
|
|
|
|
|
run_id=self.run_id,
|
|
|
|
|
handlers=self.handlers,
|
|
|
|
|
inheritable_handlers=self.inheritable_handlers,
|
|
|
|
|
parent_run_id=self.parent_run_id,
|
|
|
|
|
tags=self.tags,
|
|
|
|
|
inheritable_tags=self.inheritable_tags,
|
|
|
|
|
metadata=self.metadata,
|
|
|
|
|
inheritable_metadata=self.inheritable_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def on_chain_end(
|
|
|
|
|
self, outputs: Union[Dict[str, Any], Any], **kwargs: Any
|
|
|
|
|
) -> None:
|
|
|
|
@ -948,6 +983,23 @@ class CallbackManagerForToolRun(ParentRunManager, ToolManagerMixin):
|
|
|
|
|
class AsyncCallbackManagerForToolRun(AsyncParentRunManager, ToolManagerMixin):
|
|
|
|
|
"""Async callback manager for tool run."""
|
|
|
|
|
|
|
|
|
|
def get_sync(self) -> CallbackManagerForToolRun:
|
|
|
|
|
"""Get the equivalent sync RunManager.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
CallbackManagerForToolRun: The sync RunManager.
|
|
|
|
|
"""
|
|
|
|
|
return CallbackManagerForToolRun(
|
|
|
|
|
run_id=self.run_id,
|
|
|
|
|
handlers=self.handlers,
|
|
|
|
|
inheritable_handlers=self.inheritable_handlers,
|
|
|
|
|
parent_run_id=self.parent_run_id,
|
|
|
|
|
tags=self.tags,
|
|
|
|
|
inheritable_tags=self.inheritable_tags,
|
|
|
|
|
metadata=self.metadata,
|
|
|
|
|
inheritable_metadata=self.inheritable_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def on_tool_end(self, output: str, **kwargs: Any) -> None:
|
|
|
|
|
"""Run when tool ends running.
|
|
|
|
|
|
|
|
|
@ -1031,6 +1083,23 @@ class AsyncCallbackManagerForRetrieverRun(
|
|
|
|
|
):
|
|
|
|
|
"""Async callback manager for retriever run."""
|
|
|
|
|
|
|
|
|
|
def get_sync(self) -> CallbackManagerForRetrieverRun:
|
|
|
|
|
"""Get the equivalent sync RunManager.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
CallbackManagerForRetrieverRun: The sync RunManager.
|
|
|
|
|
"""
|
|
|
|
|
return CallbackManagerForRetrieverRun(
|
|
|
|
|
run_id=self.run_id,
|
|
|
|
|
handlers=self.handlers,
|
|
|
|
|
inheritable_handlers=self.inheritable_handlers,
|
|
|
|
|
parent_run_id=self.parent_run_id,
|
|
|
|
|
tags=self.tags,
|
|
|
|
|
inheritable_tags=self.inheritable_tags,
|
|
|
|
|
metadata=self.metadata,
|
|
|
|
|
inheritable_metadata=self.inheritable_metadata,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
async def on_retriever_end(
|
|
|
|
|
self, documents: Sequence[Document], **kwargs: Any
|
|
|
|
|
) -> None:
|
|
|
|
|