pull/21638/head
Eugene Yurtsev 2 weeks ago
parent 209814c91d
commit 807cd5409d

@ -1093,12 +1093,13 @@ class Runnable(Generic[Input, Output], ABC):
'Only version "v1" of the schema is currently supported.'
)
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.runnables import ensure_config
from langchain_core.tracers.event_stream import (
_AstreamEventHandler,
_event_stream_implementation,
)
handler = _AstreamEventHandler(
event_streamer = _AstreamEventHandler(
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
@ -1109,10 +1110,67 @@ class Runnable(Generic[Input, Output], ABC):
config = ensure_config(config)
async for event in _event_stream_implementation(
self, input, config=config, stream=handler
):
yield event
first_event_sent = False
first_event_run_id = None
# Assign the stream handler to the config
config = ensure_config(config)
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [event_streamer]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [event_streamer]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.add_handler(event_streamer, inherit=True)
config["callbacks"] = callbacks
else:
raise ValueError(
f"Unexpected type for callbacks: {callbacks}."
"Expected None, list or AsyncCallbackManager."
)
# Call the runnable in streaming mode,
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for _ in self.astream(input, config, **kwargs):
# All the content will be picked up
pass
finally:
await event_streamer.send_stream.aclose()
# Start the runnable in a task, so we can start consuming output
task = asyncio.create_task(consume_astream())
try:
async for event in event_streamer:
if not first_event_sent:
first_event_sent = True
# This is a work-around an issue where the inputs into the
# chain are not available until the entire input is consumed.
# As a temporary solution, we'll modify the input to be the input
# that was passed into the chain.
event["data"]["input"] = input
first_event_run_id = event["run_id"]
yield event
continue
if event["run_id"] == first_event_run_id and event["event"].endswith(
"_end"
):
# If it's the end event corresponding to the root runnable
# we want to include the input in the event since it's guaranteed
# to be included in the first event.
if "input" in event["data"]:
del event["data"]["input"]
yield event
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass
def transform(
self,

@ -19,7 +19,7 @@ from typing import (
from uuid import UUID
from tenacity import RetryCallState
from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import BaseMessage
@ -30,15 +30,12 @@ from langchain_core.outputs import (
)
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import (
Input,
Output,
_RootEventFilter,
)
from langchain_core.tracers.memory_stream import _MemoryStream
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.runnables import Runnable, RunnableConfig
logger = logging.getLogger(__name__)
@ -46,13 +43,14 @@ logger = logging.getLogger(__name__)
class RunInfo(TypedDict):
"""Information about a run."""
name: Optional[str]
name: str
tags: List[str]
metadata: Dict[str, Any]
run_type: str
inputs: NotRequired[Any]
def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> Optional[str]:
def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> str:
"""Assign a name to a run."""
if name is not None:
return name
@ -60,7 +58,7 @@ def _assign_name(name: Optional[str], serialized: Dict[str, Any]) -> Optional[st
return serialized["name"]
elif "id" in serialized:
return serialized["id"][-1]
return None
return "Unnamed"
class _AstreamEventHandler(AsyncCallbackHandler):
@ -378,6 +376,7 @@ class _AstreamEventHandler(AsyncCallbackHandler):
"metadata": metadata or {},
"name": name_,
"run_type": "tool",
"inputs": inputs,
}
await self._send(
@ -397,11 +396,19 @@ class _AstreamEventHandler(AsyncCallbackHandler):
async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
"""End a trace for a tool run."""
run_info = self.run_map.pop(run_id)
if "inputs" not in run_info:
raise AssertionError(
f"Run ID {run_id} is a tool call and is expected to have "
f"inputs associated with it."
)
inputs = run_info["inputs"]
await self._send(
{
"event": "on_tool_end",
"data": {
"output": output,
"input": inputs,
},
"run_id": str(run_id),
"name": run_info["name"],
@ -455,7 +462,7 @@ class _AstreamEventHandler(AsyncCallbackHandler):
{
"event": "on_retriever_end",
"data": {
"documents": documents,
"output": documents,
},
"run_id": str(run_id),
"name": run_info["name"],
@ -472,59 +479,3 @@ class _AstreamEventHandler(AsyncCallbackHandler):
def __copy__(self) -> _AstreamEventHandler:
"""Copy the tracer."""
return self
async def _event_stream_implementation(
runnable: Runnable[Input, Output],
input: Any,
config: Optional[RunnableConfig] = None,
*,
stream: _AstreamEventHandler,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
"""Implementation of astream_log for a given runnable.
The implementation has been factored out (at least temporarily) as both
astream_log and astream_events relies on it.
"""
from langchain_core.callbacks.base import BaseCallbackManager
from langchain_core.runnables import ensure_config
# Assign the stream handler to the config
config = ensure_config(config)
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [stream]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [stream]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.add_handler(stream, inherit=True)
config["callbacks"] = callbacks
else:
raise ValueError(
f"Unexpected type for callbacks: {callbacks}."
"Expected None, list or AsyncCallbackManager."
)
# Call the runnable in streaming mode,
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for _ in runnable.astream(input, config, **kwargs):
# All the content will be picked up
pass
finally:
await stream.send_stream.aclose()
# Start the runnable in a task, so we can start consuming output
task = asyncio.create_task(consume_astream())
try:
async for log in stream:
yield log
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass

@ -31,7 +31,6 @@ from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import tool
from langchain_core.tracers.event_stream import (
_AstreamEventHandler,
_event_stream_implementation,
)
from tests.unit_tests.stubs import AnyStr

Loading…
Cancel
Save