|
|
|
@ -19,7 +19,7 @@ from typing import (
|
|
|
|
|
from uuid import UUID
|
|
|
|
|
|
|
|
|
|
from tenacity import RetryCallState
|
|
|
|
|
from typing_extensions import TypedDict
|
|
|
|
|
from typing_extensions import NotRequired, TypedDict
|
|
|
|
|
|
|
|
|
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
|
|
|
|
from langchain_core.messages import BaseMessage
|
|
|
|
@ -30,15 +30,12 @@ from langchain_core.outputs import (
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.runnables.schema import StreamEvent
|
|
|
|
|
from langchain_core.runnables.utils import (
|
|
|
|
|
Input,
|
|
|
|
|
Output,
|
|
|
|
|
_RootEventFilter,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.tracers.memory_stream import _MemoryStream
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from langchain_core.documents import Document
|
|
|
|
|
from langchain_core.runnables import Runnable, RunnableConfig
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
@ -46,13 +43,14 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
class RunInfo(TypedDict):
|
|
|
|
|
"""Information about a run."""
|
|
|
|
|
|
|
|
|
|
name: Optional[str]
|
|
|
|
|
name: str
|
|
|
|
|
tags: List[str]
|
|
|
|
|
metadata: Dict[str, Any]
|
|
|
|
|
run_type: str
|
|
|
|
|
inputs: NotRequired[Any]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> Optional[str]:
|
|
|
|
|
def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> str:
|
|
|
|
|
"""Assign a name to a run."""
|
|
|
|
|
if name is not None:
|
|
|
|
|
return name
|
|
|
|
@ -60,7 +58,7 @@ def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> Optional[st
|
|
|
|
|
return serialized["name"]
|
|
|
|
|
elif "id" in serialized:
|
|
|
|
|
return serialized["id"][-1]
|
|
|
|
|
return None
|
|
|
|
|
return "Unnamed"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AstreamEventHandler(AsyncCallbackHandler):
|
|
|
|
@ -378,6 +376,7 @@ class _AstreamEventHandler(AsyncCallbackHandler):
|
|
|
|
|
"metadata": metadata or {},
|
|
|
|
|
"name": name_,
|
|
|
|
|
"run_type": "tool",
|
|
|
|
|
"inputs": inputs,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
await self._send(
|
|
|
|
@ -397,11 +396,19 @@ class _AstreamEventHandler(AsyncCallbackHandler):
|
|
|
|
|
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
|
|
|
|
"""End a trace for a tool run."""
|
|
|
|
|
run_info = self.run_map.pop(run_id)
|
|
|
|
|
if "inputs" not in run_info:
|
|
|
|
|
raise AssertionError(
|
|
|
|
|
f"Run ID {run_id} is a tool call and is expected to have "
|
|
|
|
|
f"inputs associated with it."
|
|
|
|
|
)
|
|
|
|
|
inputs = run_info["inputs"]
|
|
|
|
|
|
|
|
|
|
await self._send(
|
|
|
|
|
{
|
|
|
|
|
"event": "on_tool_end",
|
|
|
|
|
"data": {
|
|
|
|
|
"output": output,
|
|
|
|
|
"input": inputs,
|
|
|
|
|
},
|
|
|
|
|
"run_id": str(run_id),
|
|
|
|
|
"name": run_info["name"],
|
|
|
|
@ -455,7 +462,7 @@ class _AstreamEventHandler(AsyncCallbackHandler):
|
|
|
|
|
{
|
|
|
|
|
"event": "on_retriever_end",
|
|
|
|
|
"data": {
|
|
|
|
|
"documents": documents,
|
|
|
|
|
"output": documents,
|
|
|
|
|
},
|
|
|
|
|
"run_id": str(run_id),
|
|
|
|
|
"name": run_info["name"],
|
|
|
|
@ -472,59 +479,3 @@ class _AstreamEventHandler(AsyncCallbackHandler):
|
|
|
|
|
def __copy__(self) -> _AstreamEventHandler:
|
|
|
|
|
"""Copy the tracer."""
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _event_stream_implementation(
|
|
|
|
|
runnable: Runnable[Input, Output],
|
|
|
|
|
input: Any,
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
*,
|
|
|
|
|
stream: _AstreamEventHandler,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> AsyncIterator[StreamEvent]:
|
|
|
|
|
"""Implementation of astream_log for a given runnable.
|
|
|
|
|
|
|
|
|
|
The implementation has been factored out (at least temporarily) as both
|
|
|
|
|
astream_log and astream_events relies on it.
|
|
|
|
|
"""
|
|
|
|
|
from langchain_core.callbacks.base import BaseCallbackManager
|
|
|
|
|
from langchain_core.runnables import ensure_config
|
|
|
|
|
|
|
|
|
|
# Assign the stream handler to the config
|
|
|
|
|
config = ensure_config(config)
|
|
|
|
|
callbacks = config.get("callbacks")
|
|
|
|
|
if callbacks is None:
|
|
|
|
|
config["callbacks"] = [stream]
|
|
|
|
|
elif isinstance(callbacks, list):
|
|
|
|
|
config["callbacks"] = callbacks + [stream]
|
|
|
|
|
elif isinstance(callbacks, BaseCallbackManager):
|
|
|
|
|
callbacks = callbacks.copy()
|
|
|
|
|
callbacks.add_handler(stream, inherit=True)
|
|
|
|
|
config["callbacks"] = callbacks
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Unexpected type for callbacks: {callbacks}."
|
|
|
|
|
"Expected None, list or AsyncCallbackManager."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Call the runnable in streaming mode,
|
|
|
|
|
# add each chunk to the output stream
|
|
|
|
|
async def consume_astream() -> None:
|
|
|
|
|
try:
|
|
|
|
|
async for _ in runnable.astream(input, config, **kwargs):
|
|
|
|
|
# All the content will be picked up
|
|
|
|
|
pass
|
|
|
|
|
finally:
|
|
|
|
|
await stream.send_stream.aclose()
|
|
|
|
|
|
|
|
|
|
# Start the runnable in a task, so we can start consuming output
|
|
|
|
|
task = asyncio.create_task(consume_astream())
|
|
|
|
|
try:
|
|
|
|
|
async for log in stream:
|
|
|
|
|
yield log
|
|
|
|
|
finally:
|
|
|
|
|
# Wait for the runnable to finish, if not cancelled (eg. by break)
|
|
|
|
|
try:
|
|
|
|
|
await task
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
|
|
pass
|
|
|
|
|