Add `Runnable.astream_log()` (#10374)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10901/head
Nuno Campos 11 months ago committed by GitHub
parent a1ade48e8f
commit fcb5aba9f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union
from uuid import UUID
from tenacity import RetryCallState
@ -502,6 +502,9 @@ class AsyncCallbackHandler(BaseCallbackHandler):
"""Run on retriever error."""
T = TypeVar("T", bound="BaseCallbackManager")
class BaseCallbackManager(CallbackManagerMixin):
"""Base callback manager that handles callbacks from LangChain."""
@ -527,6 +530,18 @@ class BaseCallbackManager(CallbackManagerMixin):
self.metadata = metadata or {}
self.inheritable_metadata = inheritable_metadata or {}
def copy(self: T) -> T:
"""Copy the callback manager."""
return self.__class__(
handlers=self.handlers,
inheritable_handlers=self.inheritable_handlers,
parent_run_id=self.parent_run_id,
tags=self.tags,
inheritable_tags=self.inheritable_tags,
metadata=self.metadata,
inheritable_metadata=self.inheritable_metadata,
)
@property
def is_async(self) -> bool:
"""Whether the callback manager is async."""

@ -58,6 +58,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
else:
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
self.run_map[str(run.id)] = run
self._on_run_create(run)
def _end_trace(self, run: Run) -> None:
"""End a trace for a run."""
@ -74,6 +75,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
):
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(str(run.id))
self._on_run_update(run)
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
"""Get the execution order for a run."""
@ -101,7 +103,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""Start a trace for an LLM run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
@ -123,6 +125,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(llm_run)
self._on_llm_start(llm_run)
return llm_run
def on_llm_new_token(
self,
@ -132,7 +135,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
parent_run_id: Optional[UUID] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""Run on new LLM token. Only available when streaming is enabled."""
if not run_id:
raise TracerException("No run_id provided for on_llm_new_token callback.")
@ -151,6 +154,8 @@ class BaseTracer(BaseCallbackHandler, ABC):
"kwargs": event_kwargs,
},
)
self._on_llm_new_token(llm_run, token, chunk)
return llm_run
def on_retry(
self,
@ -158,7 +163,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
**kwargs: Any,
) -> None:
) -> Run:
if not run_id:
raise TracerException("No run_id provided for on_retry callback.")
run_id_ = str(run_id)
@ -186,8 +191,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
"kwargs": retry_d,
},
)
return llm_run
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None:
def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_end callback.")
@ -208,6 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run.events.append({"name": "end", "time": llm_run.end_time})
self._end_trace(llm_run)
self._on_llm_end(llm_run)
return llm_run
def on_llm_error(
self,
@ -215,7 +222,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
**kwargs: Any,
) -> None:
) -> Run:
"""Handle an error for an LLM run."""
if not run_id:
raise TracerException("No run_id provided for on_llm_error callback.")
@ -229,6 +236,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
llm_run.events.append({"name": "error", "time": llm_run.end_time})
self._end_trace(llm_run)
self._on_chain_error(llm_run)
return llm_run
def on_chain_start(
self,
@ -242,7 +250,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_type: Optional[str] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""Start a trace for a chain run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
@ -266,6 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(chain_run)
self._on_chain_start(chain_run)
return chain_run
def on_chain_end(
self,
@ -274,7 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id: UUID,
inputs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""End a trace for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
@ -291,6 +300,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_end(chain_run)
return chain_run
def on_chain_error(
self,
@ -299,7 +309,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
inputs: Optional[Dict[str, Any]] = None,
run_id: UUID,
**kwargs: Any,
) -> None:
) -> Run:
"""Handle an error for a chain run."""
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
@ -314,6 +324,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs}
self._end_trace(chain_run)
self._on_chain_error(chain_run)
return chain_run
def on_tool_start(
self,
@ -325,7 +336,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""Start a trace for a tool run."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
@ -348,8 +359,9 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(tool_run)
self._on_tool_start(tool_run)
return tool_run
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run:
"""End a trace for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
@ -362,6 +374,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tool_run.events.append({"name": "end", "time": tool_run.end_time})
self._end_trace(tool_run)
self._on_tool_end(tool_run)
return tool_run
def on_tool_error(
self,
@ -369,7 +382,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
**kwargs: Any,
) -> None:
) -> Run:
"""Handle an error for a tool run."""
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
@ -382,6 +395,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tool_run.events.append({"name": "error", "time": tool_run.end_time})
self._end_trace(tool_run)
self._on_tool_error(tool_run)
return tool_run
def on_retriever_start(
self,
@ -393,7 +407,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:
) -> Run:
"""Run when Retriever starts running."""
parent_run_id_ = str(parent_run_id) if parent_run_id else None
execution_order = self._get_execution_order(parent_run_id_)
@ -417,6 +431,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
return retrieval_run
def on_retriever_error(
self,
@ -424,7 +439,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
*,
run_id: UUID,
**kwargs: Any,
) -> None:
) -> Run:
"""Run when Retriever errors."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
@ -437,10 +452,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time})
self._end_trace(retrieval_run)
self._on_retriever_error(retrieval_run)
return retrieval_run
def on_retriever_end(
self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
) -> None:
) -> Run:
"""Run when Retriever ends running."""
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
@ -452,6 +468,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time})
self._end_trace(retrieval_run)
self._on_retriever_end(retrieval_run)
return retrieval_run
def __deepcopy__(self, memo: dict) -> BaseTracer:
"""Deepcopy the tracer."""
@ -461,9 +478,23 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Copy the tracer."""
return self
def _on_run_create(self, run: Run) -> None:
"""Process a run upon creation."""
def _on_run_update(self, run: Run) -> None:
"""Process a run upon update."""
def _on_llm_start(self, run: Run) -> None:
"""Process the LLM Run upon start."""
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> None:
"""Process new LLM token."""
def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run."""

@ -0,0 +1,289 @@
from __future__ import annotations
import math
import threading
from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
Sequence,
TypedDict,
Union,
)
from uuid import UUID
import jsonpatch
from anyio import create_memory_object_stream
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
class LogEntry(TypedDict):
id: str
"""ID of the sub-run."""
name: str
"""Name of the object being run."""
type: str
"""Type of the object being run, eg. prompt, chain, llm, etc."""
tags: List[str]
"""List of tags for the run."""
metadata: Dict[str, Any]
"""Key-value pairs of metadata for the run."""
start_time: str
"""ISO-8601 timestamp of when the run started."""
streamed_output_str: List[str]
"""List of LLM tokens streamed by this run, if applicable."""
final_output: Optional[Any]
"""Final output of this run.
Only available after the run has finished successfully."""
end_time: Optional[str]
"""ISO-8601 timestamp of when the run ended.
Only available after the run has finished."""
class RunState(TypedDict):
id: str
"""ID of the run."""
streamed_output: List[Any]
"""List of output chunks streamed by Runnable.stream()"""
final_output: Optional[Any]
"""Final output of the run, usually the result of aggregating streamed_output.
Only available after the run has finished successfully."""
logs: list[LogEntry]
"""List of sub-runs contained in this run, if any, in the order they were started.
If filters were supplied, this list will contain only the runs that matched the
filters."""
class RunLogPatch:
ops: List[Dict[str, Any]]
"""List of jsonpatch operations, which describe how to create the run state
from an empty dict. This is the minimal representation of the log, designed to
be serialized as JSON and sent over the wire to reconstruct the log on the other
side. Reconstruction of the state can be done with any jsonpatch-compliant library,
see https://jsonpatch.com for more information."""
def __init__(self, *ops: Dict[str, Any]) -> None:
self.ops = list(ops)
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
if type(other) == RunLogPatch:
ops = self.ops + other.ops
state = jsonpatch.apply_patch(None, ops)
return RunLog(*ops, state=state)
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
def __repr__(self) -> str:
from pprint import pformat
return f"RunLogPatch(ops={pformat(self.ops)})"
def __eq__(self, other: object) -> bool:
return isinstance(other, RunLogPatch) and self.ops == other.ops
class RunLog(RunLogPatch):
state: RunState
"""Current state of the log, obtained from applying all ops in sequence."""
def __init__(self, *ops: Dict[str, Any], state: RunState) -> None:
super().__init__(*ops)
self.state = state
def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch:
if type(other) == RunLogPatch:
ops = self.ops + other.ops
state = jsonpatch.apply_patch(self.state, other.ops)
return RunLog(*ops, state=state)
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
def __repr__(self) -> str:
from pprint import pformat
return f"RunLog(state={pformat(self.state)})"
class LogStreamCallbackHandler(BaseTracer):
def __init__(
self,
*,
auto_close: bool = True,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
) -> None:
super().__init__()
self.auto_close = auto_close
self.include_names = include_names
self.include_types = include_types
self.include_tags = include_tags
self.exclude_names = exclude_names
self.exclude_types = exclude_types
self.exclude_tags = exclude_tags
send_stream, receive_stream = create_memory_object_stream(
math.inf, item_type=RunLogPatch
)
self.lock = threading.Lock()
self.send_stream = send_stream
self.receive_stream = receive_stream
self._index_map: Dict[UUID, int] = {}
def __aiter__(self) -> AsyncIterator[RunLogPatch]:
return self.receive_stream.__aiter__()
def include_run(self, run: Run) -> bool:
if run.parent_run_id is None:
return False
run_tags = run.tags or []
if (
self.include_names is None
and self.include_types is None
and self.include_tags is None
):
include = True
else:
include = False
if self.include_names is not None:
include = include or run.name in self.include_names
if self.include_types is not None:
include = include or run.run_type in self.include_types
if self.include_tags is not None:
include = include or any(tag in self.include_tags for tag in run_tags)
if self.exclude_names is not None:
include = include and run.name not in self.exclude_names
if self.exclude_types is not None:
include = include and run.run_type not in self.exclude_types
if self.exclude_tags is not None:
include = include and all(tag not in self.exclude_tags for tag in run_tags)
return include
def _persist_run(self, run: Run) -> None:
# This is a legacy method only called once for an entire run tree
# therefore not useful here
pass
def _on_run_create(self, run: Run) -> None:
"""Start a run."""
if run.parent_run_id is None:
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "replace",
"path": "",
"value": RunState(
id=run.id,
streamed_output=[],
final_output=None,
logs=[],
),
}
)
)
if not self.include_run(run):
return
# Determine previous index, increment by 1
with self.lock:
self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1
# Add the run to the stream
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{self._index_map[run.id]}",
"value": LogEntry(
id=str(run.id),
name=run.name,
type=run.run_type,
tags=run.tags or [],
metadata=run.extra.get("metadata", {}),
start_time=run.start_time.isoformat(timespec="milliseconds"),
streamed_output_str=[],
final_output=None,
end_time=None,
),
}
)
)
def _on_run_update(self, run: Run) -> None:
"""Finish a run."""
try:
index = self._index_map.get(run.id)
if index is None:
return
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{index}/final_output",
"value": run.outputs,
},
{
"op": "add",
"path": f"/logs/{index}/end_time",
"value": run.end_time.isoformat(timespec="milliseconds"),
},
)
)
finally:
if run.parent_run_id is None:
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "replace",
"path": "/final_output",
"value": run.outputs,
}
)
)
if self.auto_close:
self.send_stream.close()
def _on_llm_new_token(
self,
run: Run,
token: str,
chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]],
) -> None:
"""Process new LLM token."""
index = self._index_map.get(run.id)
if index is None:
return
self.send_stream.send_nowait(
RunLogPatch(
{
"op": "add",
"path": f"/logs/{index}/streamed_output_str/-",
"value": token,
}
)
)

@ -34,6 +34,8 @@ if TYPE_CHECKING:
)
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.pydantic_v1 import Field
@ -190,6 +192,89 @@ class Runnable(Generic[Input, Output], ABC):
"""
yield await self.ainvoke(input, config, **kwargs)
async def astream_log(
self,
input: Any,
config: Optional[RunnableConfig] = None,
*,
include_names: Optional[Sequence[str]] = None,
include_types: Optional[Sequence[str]] = None,
include_tags: Optional[Sequence[str]] = None,
exclude_names: Optional[Sequence[str]] = None,
exclude_types: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
**kwargs: Optional[Any],
) -> AsyncIterator[RunLogPatch]:
"""
Stream all output from a runnable, as reported to the callback system.
This includes all inner runs of LLMs, Retrievers, Tools, etc.
Output is streamed as Log objects, which include a list of
jsonpatch ops that describe how the state of the run has changed in each
step, and the final state of the run.
The jsonpatch ops can be applied in order to construct state.
"""
# Create a stream handler that will emit Log objects
stream = LogStreamCallbackHandler(
auto_close=False,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_names=exclude_names,
exclude_types=exclude_types,
exclude_tags=exclude_tags,
)
# Assign the stream handler to the config
config = config or {}
callbacks = config.get("callbacks")
if callbacks is None:
config["callbacks"] = [stream]
elif isinstance(callbacks, list):
config["callbacks"] = callbacks + [stream]
elif isinstance(callbacks, BaseCallbackManager):
callbacks = callbacks.copy()
callbacks.inheritable_handlers.append(stream)
config["callbacks"] = callbacks
else:
raise ValueError(
f"Unexpected type for callbacks: {callbacks}."
"Expected None, list or AsyncCallbackManager."
)
# Call the runnable in streaming mode,
# add each chunk to the output stream
async def consume_astream() -> None:
try:
async for chunk in self.astream(input, config, **kwargs):
await stream.send_stream.send(
RunLogPatch(
{
"op": "add",
"path": "/streamed_output/-",
"value": chunk,
}
)
)
finally:
await stream.send_stream.aclose()
# Start the runnable in a task, so we can start consuming output
task = asyncio.create_task(consume_astream())
try:
# Yield each chunk from the output stream
async for log in stream:
yield log
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass
def transform(
self,
input: Iterator[Input],

@ -3610,6 +3610,20 @@ files = [
[package.dependencies]
attrs = ">=19.2.0"
[[package]]
name = "jsonpatch"
version = "1.33"
description = "Apply JSON-Patches (RFC 6902)"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"},
{file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"},
]
[package.dependencies]
jsonpointer = ">=1.9"
[[package]]
name = "jsonpointer"
version = "2.4"
@ -10608,4 +10622,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "11ce1c967a78f79a922b9bbbc1c00541703185e28c63b7a0a02aa5c562c36ee3"
content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85"

@ -129,6 +129,8 @@ markdownify = {version = "^0.11.6", optional = true}
assemblyai = {version = "^0.17.0", optional = true}
dashvector = {version = "^1.0.1", optional = true}
sqlite-vss = {version = "^0.1.2", optional = true}
anyio = "<4.0"
jsonpatch = "^1.33"
timescale-vector = {version = "^0.0.1", optional = true}

@ -1,5 +1,5 @@
from operator import itemgetter
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union, cast
from uuid import UUID
import pytest
@ -9,6 +9,7 @@ from syrupy import SnapshotAssertion
from langchain.callbacks.manager import Callbacks, collect_runs
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch
from langchain.callbacks.tracers.schemas import Run
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.chat_models.fake import FakeListChatModel
@ -368,6 +369,62 @@ async def test_prompt() -> None:
part async for part in prompt.astream({"question": "What is your name?"})
] == [expected]
stream_log = [
part async for part in prompt.astream_log({"question": "What is your name?"})
]
assert len(stream_log[0].ops) == 1
assert stream_log[0].ops[0]["op"] == "replace"
assert stream_log[0].ops[0]["path"] == ""
assert stream_log[0].ops[0]["value"]["logs"] == []
assert stream_log[0].ops[0]["value"]["final_output"] is None
assert stream_log[0].ops[0]["value"]["streamed_output"] == []
assert type(stream_log[0].ops[0]["value"]["id"]) == UUID
assert stream_log[1:] == [
RunLogPatch(
{
"op": "replace",
"path": "/final_output",
"value": {
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
"kwargs": {
"messages": [
{
"id": [
"langchain",
"schema",
"messages",
"SystemMessage",
],
"kwargs": {"content": "You are a nice " "assistant."},
"lc": 1,
"type": "constructor",
},
{
"id": [
"langchain",
"schema",
"messages",
"HumanMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "What is your " "name?",
},
"lc": 1,
"type": "constructor",
},
]
},
"lc": 1,
"type": "constructor",
},
}
),
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}),
]
def test_prompt_template_params() -> None:
prompt = ChatPromptTemplate.from_template(
@ -560,7 +617,7 @@ async def test_prompt_with_llm(
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test stream#
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "astream")
tracer = FakeTracer()
@ -578,6 +635,136 @@ async def test_prompt_with_llm(
]
)
prompt_spy.reset_mock()
llm_spy.reset_mock()
stream_log = [
part async for part in chain.astream_log({"question": "What is your name?"})
]
# remove ids from logs
for part in stream_log:
for op in part.ops:
if (
isinstance(op["value"], dict)
and "id" in op["value"]
and not isinstance(op["value"]["id"], list) # serialized lc id
):
del op["value"]["id"]
assert stream_log == [
RunLogPatch(
{
"op": "replace",
"path": "",
"value": {
"logs": [],
"final_output": None,
"streamed_output": [],
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/0",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "ChatPromptTemplate",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["seq:step:1"],
"type": "prompt",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/0/final_output",
"value": {
"id": ["langchain", "prompts", "chat", "ChatPromptValue"],
"kwargs": {
"messages": [
{
"id": [
"langchain",
"schema",
"messages",
"SystemMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "You are a nice " "assistant.",
},
"lc": 1,
"type": "constructor",
},
{
"id": [
"langchain",
"schema",
"messages",
"HumanMessage",
],
"kwargs": {
"additional_kwargs": {},
"content": "What is your " "name?",
},
"lc": 1,
"type": "constructor",
},
]
},
"lc": 1,
"type": "constructor",
},
},
{
"op": "add",
"path": "/logs/0/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch(
{
"op": "add",
"path": "/logs/1",
"value": {
"end_time": None,
"final_output": None,
"metadata": {},
"name": "FakeListLLM",
"start_time": "2023-01-01T00:00:00.000",
"streamed_output_str": [],
"tags": ["seq:step:2"],
"type": "llm",
},
}
),
RunLogPatch(
{
"op": "add",
"path": "/logs/1/final_output",
"value": {
"generations": [[{"generation_info": None, "text": "foo"}]],
"llm_output": None,
"run": None,
},
},
{
"op": "add",
"path": "/logs/1/end_time",
"value": "2023-01-01T00:00:00.000",
},
),
RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}),
RunLogPatch(
{"op": "replace", "path": "/final_output", "value": {"output": "foo"}}
),
]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
@ -1213,6 +1400,74 @@ async def test_map_astream() -> None:
{"question": "What is your name?"}
)
# Test astream_log state accumulation
final_state = None
streamed_ops = []
async for chunk in chain.astream_log({"question": "What is your name?"}):
streamed_ops.extend(chunk.ops)
if final_state is None:
final_state = chunk
else:
final_state += chunk
final_state = cast(RunLog, final_state)
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert isinstance(final_state.state["id"], UUID)
assert len(final_state.ops) == len(streamed_ops)
assert len(final_state.state["logs"]) == 5
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
assert final_state.state["logs"][0]["final_output"] == dumpd(
prompt.invoke({"question": "What is your name?"})
)
assert final_state.state["logs"][1]["name"] == "RunnableMap"
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
"FakeListChatModel",
"FakeStreamingListLLM",
"RunnablePassthrough",
]
# Test astream_log with include filters
final_state = None
async for chunk in chain.astream_log(
{"question": "What is your name?"}, include_names=["FakeListChatModel"]
):
if final_state is None:
final_state = chunk
else:
final_state += chunk
final_state = cast(RunLog, final_state)
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert len(final_state.state["logs"]) == 1
assert final_state.state["logs"][0]["name"] == "FakeListChatModel"
# Test astream_log with exclude filters
final_state = None
async for chunk in chain.astream_log(
{"question": "What is your name?"}, exclude_names=["FakeListChatModel"]
):
if final_state is None:
final_state = chunk
else:
final_state += chunk
final_state = cast(RunLog, final_state)
assert final_state.state["final_output"] == final_value
assert len(final_state.state["streamed_output"]) == len(streamed_chunks)
assert len(final_state.state["logs"]) == 4
assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate"
assert final_state.state["logs"][0]["final_output"] == dumpd(
prompt.invoke({"question": "What is your name?"})
)
assert final_state.state["logs"][1]["name"] == "RunnableMap"
assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [
"FakeStreamingListLLM",
"RunnablePassthrough",
]
@pytest.mark.asyncio
async def test_map_astream_iterator_input() -> None:

@ -39,8 +39,10 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None:
"PyYAML",
"SQLAlchemy",
"aiohttp",
"anyio",
"async-timeout",
"dataclasses-json",
"jsonpatch",
"langsmith",
"numexpr",
"numpy",

Loading…
Cancel
Save