diff --git a/libs/langchain/langchain/callbacks/base.py b/libs/langchain/langchain/callbacks/base.py index 519379fc50..b29c8f8b83 100644 --- a/libs/langchain/langchain/callbacks/base.py +++ b/libs/langchain/langchain/callbacks/base.py @@ -1,7 +1,7 @@ """Base callback handler that can be used to handle callbacks in langchain.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, TypeVar, Union from uuid import UUID from tenacity import RetryCallState @@ -502,6 +502,9 @@ class AsyncCallbackHandler(BaseCallbackHandler): """Run on retriever error.""" +T = TypeVar("T", bound="BaseCallbackManager") + + class BaseCallbackManager(CallbackManagerMixin): """Base callback manager that handles callbacks from LangChain.""" @@ -527,6 +530,18 @@ class BaseCallbackManager(CallbackManagerMixin): self.metadata = metadata or {} self.inheritable_metadata = inheritable_metadata or {} + def copy(self: T) -> T: + """Copy the callback manager.""" + return self.__class__( + handlers=self.handlers, + inheritable_handlers=self.inheritable_handlers, + parent_run_id=self.parent_run_id, + tags=self.tags, + inheritable_tags=self.inheritable_tags, + metadata=self.metadata, + inheritable_metadata=self.inheritable_metadata, + ) + @property def is_async(self) -> bool: """Whether the callback manager is async.""" diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py index a527798a91..5b1427be43 100644 --- a/libs/langchain/langchain/callbacks/tracers/base.py +++ b/libs/langchain/langchain/callbacks/tracers/base.py @@ -58,6 +58,7 @@ class BaseTracer(BaseCallbackHandler, ABC): else: logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") self.run_map[str(run.id)] = run + self._on_run_create(run) def _end_trace(self, run: Run) -> None: """End a trace for a run.""" @@ -74,6 +75,7 @@ class BaseTracer(BaseCallbackHandler, ABC): ): parent_run.child_execution_order = run.child_execution_order self.run_map.pop(str(run.id)) + self._on_run_update(run) def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int: """Get the execution order for a run.""" @@ -101,7 +103,7 @@ class BaseTracer(BaseCallbackHandler, ABC): parent_run_id: Optional[UUID] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> None: + ) -> Run: """Start a trace for an LLM run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) @@ -123,6 +125,7 @@ class BaseTracer(BaseCallbackHandler, ABC): ) self._start_trace(llm_run) self._on_llm_start(llm_run) + return llm_run def on_llm_new_token( self, @@ -132,7 +135,7 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, parent_run_id: Optional[UUID] = None, **kwargs: Any, - ) -> None: + ) -> Run: """Run on new LLM token. 
Only available when streaming is enabled.""" if not run_id: raise TracerException("No run_id provided for on_llm_new_token callback.") @@ -151,6 +154,8 @@ class BaseTracer(BaseCallbackHandler, ABC): "kwargs": event_kwargs, }, ) + self._on_llm_new_token(llm_run, token, chunk) + return llm_run def on_retry( self, @@ -158,7 +163,7 @@ class BaseTracer(BaseCallbackHandler, ABC): *, run_id: UUID, **kwargs: Any, - ) -> None: + ) -> Run: if not run_id: raise TracerException("No run_id provided for on_retry callback.") run_id_ = str(run_id) @@ -186,8 +191,9 @@ class BaseTracer(BaseCallbackHandler, ABC): "kwargs": retry_d, }, ) + return llm_run - def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> None: + def on_llm_end(self, response: LLMResult, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_end callback.") @@ -208,6 +214,7 @@ class BaseTracer(BaseCallbackHandler, ABC): llm_run.events.append({"name": "end", "time": llm_run.end_time}) self._end_trace(llm_run) self._on_llm_end(llm_run) + return llm_run def on_llm_error( self, @@ -215,7 +222,7 @@ class BaseTracer(BaseCallbackHandler, ABC): *, run_id: UUID, **kwargs: Any, - ) -> None: + ) -> Run: """Handle an error for an LLM run.""" if not run_id: raise TracerException("No run_id provided for on_llm_error callback.") @@ -229,6 +236,7 @@ class BaseTracer(BaseCallbackHandler, ABC): llm_run.events.append({"name": "error", "time": llm_run.end_time}) self._end_trace(llm_run) self._on_chain_error(llm_run) + return llm_run def on_chain_start( self, @@ -242,7 +250,7 @@ class BaseTracer(BaseCallbackHandler, ABC): run_type: Optional[str] = None, name: Optional[str] = None, **kwargs: Any, - ) -> None: + ) -> Run: """Start a trace for a chain run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) @@ -266,6 +274,7 @@ class BaseTracer(BaseCallbackHandler, ABC): ) self._start_trace(chain_run) self._on_chain_start(chain_run) + return chain_run def on_chain_end( self, @@ -274,7 +283,7 @@ class BaseTracer(BaseCallbackHandler, ABC): run_id: UUID, inputs: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> None: + ) -> Run: """End a trace for a chain run.""" if not run_id: raise TracerException("No run_id provided for on_chain_end callback.") @@ -291,6 +300,7 @@ class BaseTracer(BaseCallbackHandler, ABC): chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} self._end_trace(chain_run) self._on_chain_end(chain_run) + return chain_run def on_chain_error( self, @@ -299,7 +309,7 @@ class BaseTracer(BaseCallbackHandler, ABC): inputs: Optional[Dict[str, Any]] = None, run_id: UUID, **kwargs: Any, - ) -> None: + ) -> Run: """Handle an error for a chain run.""" if not run_id: raise TracerException("No run_id provided for on_chain_error callback.") @@ -314,6 +324,7 @@ class BaseTracer(BaseCallbackHandler, ABC): chain_run.inputs = inputs if isinstance(inputs, dict) else {"input": inputs} self._end_trace(chain_run) self._on_chain_error(chain_run) + return chain_run def on_tool_start( self, @@ -325,7 +336,7 @@ class BaseTracer(BaseCallbackHandler, ABC): parent_run_id: Optional[UUID] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> None: + ) -> Run: """Start a trace for a tool run.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) @@ -348,8 +359,9 @@ class 
BaseTracer(BaseCallbackHandler, ABC): ) self._start_trace(tool_run) self._on_tool_start(tool_run) + return tool_run - def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None: + def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> Run: """End a trace for a tool run.""" if not run_id: raise TracerException("No run_id provided for on_tool_end callback.") @@ -362,6 +374,7 @@ class BaseTracer(BaseCallbackHandler, ABC): tool_run.events.append({"name": "end", "time": tool_run.end_time}) self._end_trace(tool_run) self._on_tool_end(tool_run) + return tool_run def on_tool_error( self, @@ -369,7 +382,7 @@ class BaseTracer(BaseCallbackHandler, ABC): *, run_id: UUID, **kwargs: Any, - ) -> None: + ) -> Run: """Handle an error for a tool run.""" if not run_id: raise TracerException("No run_id provided for on_tool_error callback.") @@ -382,6 +395,7 @@ class BaseTracer(BaseCallbackHandler, ABC): tool_run.events.append({"name": "error", "time": tool_run.end_time}) self._end_trace(tool_run) self._on_tool_error(tool_run) + return tool_run def on_retriever_start( self, @@ -393,7 +407,7 @@ class BaseTracer(BaseCallbackHandler, ABC): tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> None: + ) -> Run: """Run when Retriever starts running.""" parent_run_id_ = str(parent_run_id) if parent_run_id else None execution_order = self._get_execution_order(parent_run_id_) @@ -417,6 +431,7 @@ class BaseTracer(BaseCallbackHandler, ABC): ) self._start_trace(retrieval_run) self._on_retriever_start(retrieval_run) + return retrieval_run def on_retriever_error( self, @@ -424,7 +439,7 @@ class BaseTracer(BaseCallbackHandler, ABC): *, run_id: UUID, **kwargs: Any, - ) -> None: + ) -> Run: """Run when Retriever errors.""" if not run_id: raise TracerException("No run_id provided for on_retriever_error callback.") @@ -437,10 +452,11 @@ class BaseTracer(BaseCallbackHandler, ABC): retrieval_run.events.append({"name": "error", "time": retrieval_run.end_time}) self._end_trace(retrieval_run) self._on_retriever_error(retrieval_run) + return retrieval_run def on_retriever_end( self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any - ) -> None: + ) -> Run: """Run when Retriever ends running.""" if not run_id: raise TracerException("No run_id provided for on_retriever_end callback.") @@ -452,6 +468,7 @@ class BaseTracer(BaseCallbackHandler, ABC): retrieval_run.events.append({"name": "end", "time": retrieval_run.end_time}) self._end_trace(retrieval_run) self._on_retriever_end(retrieval_run) + return retrieval_run def __deepcopy__(self, memo: dict) -> BaseTracer: """Deepcopy the tracer.""" @@ -461,9 +478,23 @@ class BaseTracer(BaseCallbackHandler, ABC): """Copy the tracer.""" return self + def _on_run_create(self, run: Run) -> None: + """Process a run upon creation.""" + + def _on_run_update(self, run: Run) -> None: + """Process a run upon update.""" + def _on_llm_start(self, run: Run) -> None: """Process the LLM Run upon start.""" + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + def _on_llm_end(self, run: Run) -> None: """Process the LLM Run.""" diff --git a/libs/langchain/langchain/callbacks/tracers/log_stream.py b/libs/langchain/langchain/callbacks/tracers/log_stream.py new file mode 100644 index 0000000000..c4978f88bf --- /dev/null +++ b/libs/langchain/langchain/callbacks/tracers/log_stream.py @@ -0,0 +1,289 @@ +from __future__ 
import annotations + +import math +import threading +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, + Sequence, + TypedDict, + Union, +) +from uuid import UUID + +import jsonpatch +from anyio import create_memory_object_stream + +from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.schemas import Run +from langchain.schema.output import ChatGenerationChunk, GenerationChunk + + +class LogEntry(TypedDict): + id: str + """ID of the sub-run.""" + name: str + """Name of the object being run.""" + type: str + """Type of the object being run, eg. prompt, chain, llm, etc.""" + tags: List[str] + """List of tags for the run.""" + metadata: Dict[str, Any] + """Key-value pairs of metadata for the run.""" + start_time: str + """ISO-8601 timestamp of when the run started.""" + + streamed_output_str: List[str] + """List of LLM tokens streamed by this run, if applicable.""" + final_output: Optional[Any] + """Final output of this run. + Only available after the run has finished successfully.""" + end_time: Optional[str] + """ISO-8601 timestamp of when the run ended. + Only available after the run has finished.""" + + +class RunState(TypedDict): + id: str + """ID of the run.""" + streamed_output: List[Any] + """List of output chunks streamed by Runnable.stream()""" + final_output: Optional[Any] + """Final output of the run, usually the result of aggregating streamed_output. + Only available after the run has finished successfully.""" + + logs: list[LogEntry] + """List of sub-runs contained in this run, if any, in the order they were started. + If filters were supplied, this list will contain only the runs that matched the + filters.""" + + +class RunLogPatch: + ops: List[Dict[str, Any]] + """List of jsonpatch operations, which describe how to create the run state + from an empty dict. This is the minimal representation of the log, designed to + be serialized as JSON and sent over the wire to reconstruct the log on the other + side. 
Reconstruction of the state can be done with any jsonpatch-compliant library, + see https://jsonpatch.com for more information.""" + + def __init__(self, *ops: Dict[str, Any]) -> None: + self.ops = list(ops) + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(None, ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + return f"RunLogPatch(ops={pformat(self.ops)})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, RunLogPatch) and self.ops == other.ops + + +class RunLog(RunLogPatch): + state: RunState + """Current state of the log, obtained from applying all ops in sequence.""" + + def __init__(self, *ops: Dict[str, Any], state: RunState) -> None: + super().__init__(*ops) + self.state = state + + def __add__(self, other: Union[RunLogPatch, Any]) -> RunLogPatch: + if type(other) == RunLogPatch: + ops = self.ops + other.ops + state = jsonpatch.apply_patch(self.state, other.ops) + return RunLog(*ops, state=state) + + raise TypeError( + f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" + ) + + def __repr__(self) -> str: + from pprint import pformat + + return f"RunLog(state={pformat(self.state)})" + + +class LogStreamCallbackHandler(BaseTracer): + def __init__( + self, + *, + auto_close: bool = True, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + ) -> None: + super().__init__() + + self.auto_close = auto_close + self.include_names = include_names + self.include_types = include_types + self.include_tags = include_tags + self.exclude_names = exclude_names + self.exclude_types = exclude_types + self.exclude_tags = exclude_tags + + send_stream, receive_stream = create_memory_object_stream( + math.inf, item_type=RunLogPatch + ) + self.lock = threading.Lock() + self.send_stream = send_stream + self.receive_stream = receive_stream + self._index_map: Dict[UUID, int] = {} + + def __aiter__(self) -> AsyncIterator[RunLogPatch]: + return self.receive_stream.__aiter__() + + def include_run(self, run: Run) -> bool: + if run.parent_run_id is None: + return False + + run_tags = run.tags or [] + + if ( + self.include_names is None + and self.include_types is None + and self.include_tags is None + ): + include = True + else: + include = False + + if self.include_names is not None: + include = include or run.name in self.include_names + if self.include_types is not None: + include = include or run.run_type in self.include_types + if self.include_tags is not None: + include = include or any(tag in self.include_tags for tag in run_tags) + + if self.exclude_names is not None: + include = include and run.name not in self.exclude_names + if self.exclude_types is not None: + include = include and run.run_type not in self.exclude_types + if self.exclude_tags is not None: + include = include and all(tag not in self.exclude_tags for tag in run_tags) + + return include + + def _persist_run(self, run: Run) -> None: + # This is a legacy method only called once for an entire run tree + # therefore not useful here + pass + + def _on_run_create(self, run: Run) -> None: + """Start a 
run.""" + if run.parent_run_id is None: + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "", + "value": RunState( + id=run.id, + streamed_output=[], + final_output=None, + logs=[], + ), + } + ) + ) + + if not self.include_run(run): + return + + # Determine previous index, increment by 1 + with self.lock: + self._index_map[run.id] = max(self._index_map.values(), default=-1) + 1 + + # Add the run to the stream + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{self._index_map[run.id]}", + "value": LogEntry( + id=str(run.id), + name=run.name, + type=run.run_type, + tags=run.tags or [], + metadata=run.extra.get("metadata", {}), + start_time=run.start_time.isoformat(timespec="milliseconds"), + streamed_output_str=[], + final_output=None, + end_time=None, + ), + } + ) + ) + + def _on_run_update(self, run: Run) -> None: + """Finish a run.""" + try: + index = self._index_map.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/final_output", + "value": run.outputs, + }, + { + "op": "add", + "path": f"/logs/{index}/end_time", + "value": run.end_time.isoformat(timespec="milliseconds"), + }, + ) + ) + finally: + if run.parent_run_id is None: + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "replace", + "path": "/final_output", + "value": run.outputs, + } + ) + ) + if self.auto_close: + self.send_stream.close() + + def _on_llm_new_token( + self, + run: Run, + token: str, + chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]], + ) -> None: + """Process new LLM token.""" + index = self._index_map.get(run.id) + + if index is None: + return + + self.send_stream.send_nowait( + RunLogPatch( + { + "op": "add", + "path": f"/logs/{index}/streamed_output_str/-", + "value": token, + } + ) + ) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 1473e58b79..709fb53309 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -34,6 +34,8 @@ if TYPE_CHECKING: ) +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.tracers.log_stream import LogStreamCallbackHandler, RunLogPatch from langchain.load.dump import dumpd from langchain.load.serializable import Serializable from langchain.pydantic_v1 import Field @@ -190,6 +192,89 @@ class Runnable(Generic[Input, Output], ABC): """ yield await self.ainvoke(input, config, **kwargs) + async def astream_log( + self, + input: Any, + config: Optional[RunnableConfig] = None, + *, + include_names: Optional[Sequence[str]] = None, + include_types: Optional[Sequence[str]] = None, + include_tags: Optional[Sequence[str]] = None, + exclude_names: Optional[Sequence[str]] = None, + exclude_types: Optional[Sequence[str]] = None, + exclude_tags: Optional[Sequence[str]] = None, + **kwargs: Optional[Any], + ) -> AsyncIterator[RunLogPatch]: + """ + Stream all output from a runnable, as reported to the callback system. + This includes all inner runs of LLMs, Retrievers, Tools, etc. + + Output is streamed as Log objects, which include a list of + jsonpatch ops that describe how the state of the run has changed in each + step, and the final state of the run. + + The jsonpatch ops can be applied in order to construct state. 
+ """ + + # Create a stream handler that will emit Log objects + stream = LogStreamCallbackHandler( + auto_close=False, + include_names=include_names, + include_types=include_types, + include_tags=include_tags, + exclude_names=exclude_names, + exclude_types=exclude_types, + exclude_tags=exclude_tags, + ) + + # Assign the stream handler to the config + config = config or {} + callbacks = config.get("callbacks") + if callbacks is None: + config["callbacks"] = [stream] + elif isinstance(callbacks, list): + config["callbacks"] = callbacks + [stream] + elif isinstance(callbacks, BaseCallbackManager): + callbacks = callbacks.copy() + callbacks.inheritable_handlers.append(stream) + config["callbacks"] = callbacks + else: + raise ValueError( + f"Unexpected type for callbacks: {callbacks}." + "Expected None, list or AsyncCallbackManager." + ) + + # Call the runnable in streaming mode, + # add each chunk to the output stream + async def consume_astream() -> None: + try: + async for chunk in self.astream(input, config, **kwargs): + await stream.send_stream.send( + RunLogPatch( + { + "op": "add", + "path": "/streamed_output/-", + "value": chunk, + } + ) + ) + finally: + await stream.send_stream.aclose() + + # Start the runnable in a task, so we can start consuming output + task = asyncio.create_task(consume_astream()) + + try: + # Yield each chunk from the output stream + async for log in stream: + yield log + finally: + # Wait for the runnable to finish, if not cancelled (eg. by break) + try: + await task + except asyncio.CancelledError: + pass + def transform( self, input: Iterator[Input], diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock index 07fdf0b066..e6d1a54385 100644 --- a/libs/langchain/poetry.lock +++ b/libs/langchain/poetry.lock @@ -3610,6 +3610,20 @@ files = [ [package.dependencies] attrs = ">=19.2.0" +[[package]] +name = "jsonpatch" +version = "1.33" +description = "Apply JSON-Patches (RFC 6902)" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" +files = [ + {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, + {file = "jsonpatch-1.33.tar.gz", hash = "sha256:9fcd4009c41e6d12348b4a0ff2563ba56a2923a7dfee731d004e212e1ee5030c"}, +] + +[package.dependencies] +jsonpointer = ">=1.9" + [[package]] name = "jsonpointer" version = "2.4" @@ -10608,4 +10622,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "11ce1c967a78f79a922b9bbbc1c00541703185e28c63b7a0a02aa5c562c36ee3" +content-hash = "3a3749b3d63be94ef11de23ec7ad40cc20cca78fa7352c5ed7d537988ce90a85" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 9d9aa654f2..324d6b4353 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -129,6 +129,8 @@ markdownify = {version = "^0.11.6", optional = true} assemblyai = {version = "^0.17.0", optional = true} dashvector = {version = "^1.0.1", optional = true} sqlite-vss = {version = "^0.1.2", optional = true} +anyio = "<4.0" +jsonpatch = "^1.33" timescale-vector = {version = "^0.0.1", optional = true} diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index dc59e65cfd..7fed5d6648 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1,5 +1,5 @@ from 
operator import itemgetter -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union, cast from uuid import UUID import pytest @@ -9,6 +9,7 @@ from syrupy import SnapshotAssertion from langchain.callbacks.manager import Callbacks, collect_runs from langchain.callbacks.tracers.base import BaseTracer +from langchain.callbacks.tracers.log_stream import RunLog, RunLogPatch from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler from langchain.chat_models.fake import FakeListChatModel @@ -368,6 +369,62 @@ async def test_prompt() -> None: part async for part in prompt.astream({"question": "What is your name?"}) ] == [expected] + stream_log = [ + part async for part in prompt.astream_log({"question": "What is your name?"}) + ] + + assert len(stream_log[0].ops) == 1 + assert stream_log[0].ops[0]["op"] == "replace" + assert stream_log[0].ops[0]["path"] == "" + assert stream_log[0].ops[0]["value"]["logs"] == [] + assert stream_log[0].ops[0]["value"]["final_output"] is None + assert stream_log[0].ops[0]["value"]["streamed_output"] == [] + assert type(stream_log[0].ops[0]["value"]["id"]) == UUID + + assert stream_log[1:] == [ + RunLogPatch( + { + "op": "replace", + "path": "/final_output", + "value": { + "id": ["langchain", "prompts", "chat", "ChatPromptValue"], + "kwargs": { + "messages": [ + { + "id": [ + "langchain", + "schema", + "messages", + "SystemMessage", + ], + "kwargs": {"content": "You are a nice " "assistant."}, + "lc": 1, + "type": "constructor", + }, + { + "id": [ + "langchain", + "schema", + "messages", + "HumanMessage", + ], + "kwargs": { + "additional_kwargs": {}, + "content": "What is your " "name?", + }, + "lc": 1, + "type": "constructor", + }, + ] + }, + "lc": 1, + "type": "constructor", + }, + } + ), + RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}), + ] + def test_prompt_template_params() -> None: prompt = ChatPromptTemplate.from_template( @@ -560,7 +617,7 @@ async def test_prompt_with_llm( mocker.stop(prompt_spy) mocker.stop(llm_spy) - # Test stream# + # Test stream prompt_spy = mocker.spy(prompt.__class__, "ainvoke") llm_spy = mocker.spy(llm.__class__, "astream") tracer = FakeTracer() @@ -578,6 +635,136 @@ async def test_prompt_with_llm( ] ) + prompt_spy.reset_mock() + llm_spy.reset_mock() + stream_log = [ + part async for part in chain.astream_log({"question": "What is your name?"}) + ] + + # remove ids from logs + for part in stream_log: + for op in part.ops: + if ( + isinstance(op["value"], dict) + and "id" in op["value"] + and not isinstance(op["value"]["id"], list) # serialized lc id + ): + del op["value"]["id"] + + assert stream_log == [ + RunLogPatch( + { + "op": "replace", + "path": "", + "value": { + "logs": [], + "final_output": None, + "streamed_output": [], + }, + } + ), + RunLogPatch( + { + "op": "add", + "path": "/logs/0", + "value": { + "end_time": None, + "final_output": None, + "metadata": {}, + "name": "ChatPromptTemplate", + "start_time": "2023-01-01T00:00:00.000", + "streamed_output_str": [], + "tags": ["seq:step:1"], + "type": "prompt", + }, + } + ), + RunLogPatch( + { + "op": "add", + "path": "/logs/0/final_output", + "value": { + "id": ["langchain", "prompts", "chat", "ChatPromptValue"], + "kwargs": { + "messages": [ + { + "id": [ + "langchain", + "schema", + "messages", + "SystemMessage", + ], + "kwargs": { + "additional_kwargs": {}, + "content": "You are a nice " "assistant.", + }, + "lc": 1, 
+ "type": "constructor", + }, + { + "id": [ + "langchain", + "schema", + "messages", + "HumanMessage", + ], + "kwargs": { + "additional_kwargs": {}, + "content": "What is your " "name?", + }, + "lc": 1, + "type": "constructor", + }, + ] + }, + "lc": 1, + "type": "constructor", + }, + }, + { + "op": "add", + "path": "/logs/0/end_time", + "value": "2023-01-01T00:00:00.000", + }, + ), + RunLogPatch( + { + "op": "add", + "path": "/logs/1", + "value": { + "end_time": None, + "final_output": None, + "metadata": {}, + "name": "FakeListLLM", + "start_time": "2023-01-01T00:00:00.000", + "streamed_output_str": [], + "tags": ["seq:step:2"], + "type": "llm", + }, + } + ), + RunLogPatch( + { + "op": "add", + "path": "/logs/1/final_output", + "value": { + "generations": [[{"generation_info": None, "text": "foo"}]], + "llm_output": None, + "run": None, + }, + }, + { + "op": "add", + "path": "/logs/1/end_time", + "value": "2023-01-01T00:00:00.000", + }, + ), + RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}), + RunLogPatch( + {"op": "replace", "path": "/final_output", "value": {"output": "foo"}} + ), + ] + @pytest.mark.asyncio @freeze_time("2023-01-01") @@ -1213,6 +1400,74 @@ async def test_map_astream() -> None: {"question": "What is your name?"} ) + # Test astream_log state accumulation + + final_state = None + streamed_ops = [] + async for chunk in chain.astream_log({"question": "What is your name?"}): + streamed_ops.extend(chunk.ops) + if final_state is None: + final_state = chunk + else: + final_state += chunk + final_state = cast(RunLog, final_state) + + assert final_state.state["final_output"] == final_value + assert len(final_state.state["streamed_output"]) == len(streamed_chunks) + assert isinstance(final_state.state["id"], UUID) + assert len(final_state.ops) == len(streamed_ops) + assert len(final_state.state["logs"]) == 5 + assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate" + assert final_state.state["logs"][0]["final_output"] == dumpd( + prompt.invoke({"question": "What is your name?"}) + ) + assert final_state.state["logs"][1]["name"] == "RunnableMap" + assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [ + "FakeListChatModel", + "FakeStreamingListLLM", + "RunnablePassthrough", + ] + + # Test astream_log with include filters + final_state = None + async for chunk in chain.astream_log( + {"question": "What is your name?"}, include_names=["FakeListChatModel"] + ): + if final_state is None: + final_state = chunk + else: + final_state += chunk + final_state = cast(RunLog, final_state) + + assert final_state.state["final_output"] == final_value + assert len(final_state.state["streamed_output"]) == len(streamed_chunks) + assert len(final_state.state["logs"]) == 1 + assert final_state.state["logs"][0]["name"] == "FakeListChatModel" + + # Test astream_log with exclude filters + final_state = None + async for chunk in chain.astream_log( + {"question": "What is your name?"}, exclude_names=["FakeListChatModel"] + ): + if final_state is None: + final_state = chunk + else: + final_state += chunk + final_state = cast(RunLog, final_state) + + assert final_state.state["final_output"] == final_value + assert len(final_state.state["streamed_output"]) == len(streamed_chunks) + assert len(final_state.state["logs"]) == 4 + assert final_state.state["logs"][0]["name"] == "ChatPromptTemplate" + assert final_state.state["logs"][0]["final_output"] == dumpd( + prompt.invoke({"question": "What is your name?"}) + ) + assert final_state.state["logs"][1]["name"] 
== "RunnableMap" + assert sorted(log["name"] for log in final_state.state["logs"][2:]) == [ + "FakeStreamingListLLM", + "RunnablePassthrough", + ] + @pytest.mark.asyncio async def test_map_astream_iterator_input() -> None: diff --git a/libs/langchain/tests/unit_tests/test_dependencies.py b/libs/langchain/tests/unit_tests/test_dependencies.py index 2e40a7ab4f..cf21b90ab5 100644 --- a/libs/langchain/tests/unit_tests/test_dependencies.py +++ b/libs/langchain/tests/unit_tests/test_dependencies.py @@ -39,8 +39,10 @@ def test_required_dependencies(poetry_conf: Mapping[str, Any]) -> None: "PyYAML", "SQLAlchemy", "aiohttp", + "anyio", "async-timeout", "dataclasses-json", + "jsonpatch", "langsmith", "numexpr", "numpy",