From 0f255bb6c46f64f106910a8c955ec5765875eec2 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 28 Nov 2023 21:50:41 +0000 Subject: [PATCH] In Runnable.stream_log build up final_output from adding output chunks (#12781) Add arg to omit streamed_output list, in cases where final_output is enough this saves bandwidth --- .../output_parsers/transform.py | 6 +- .../core/langchain_core/runnables/__init__.py | 4 + libs/core/langchain_core/runnables/base.py | 47 +++++++++-- .../core/langchain_core/tracers/log_stream.py | 15 +--- .../unit_tests/runnables/test_imports.py | 2 + .../unit_tests/runnables/test_runnable.py | 78 +++++++++++++++++-- 6 files changed, 125 insertions(+), 27 deletions(-) diff --git a/libs/core/langchain_core/output_parsers/transform.py b/libs/core/langchain_core/output_parsers/transform.py index 7d37ccf90d..96688174d4 100644 --- a/libs/core/langchain_core/output_parsers/transform.py +++ b/libs/core/langchain_core/output_parsers/transform.py @@ -92,7 +92,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): if acc_gen is None: acc_gen = chunk_gen else: - acc_gen += chunk_gen + acc_gen = acc_gen + chunk_gen parsed = self.parse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: @@ -120,9 +120,9 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]): if acc_gen is None: acc_gen = chunk_gen else: - acc_gen += chunk_gen + acc_gen = acc_gen + chunk_gen - parsed = self.parse_result([acc_gen], partial=True) + parsed = await self.aparse_result([acc_gen], partial=True) if parsed is not None and parsed != prev_parsed: if self.diff: yield self._diff(prev_parsed, parsed) diff --git a/libs/core/langchain_core/runnables/__init__.py b/libs/core/langchain_core/runnables/__init__.py index b5940562a1..b51a94eea3 100644 --- a/libs/core/langchain_core/runnables/__init__.py +++ b/libs/core/langchain_core/runnables/__init__.py @@ -34,13 +34,16 @@ from langchain_core.runnables.fallbacks import RunnableWithFallbacks from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.router import RouterInput, RouterRunnable from langchain_core.runnables.utils import ( + AddableDict, ConfigurableField, ConfigurableFieldMultiOption, ConfigurableFieldSingleOption, + aadd, add, ) __all__ = [ + "AddableDict", "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", @@ -60,5 +63,6 @@ __all__ = [ "RunnableSequence", "RunnableWithFallbacks", "get_config_list", + "aadd", "add", ] diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3796d4cc7b..1943de8285 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -5,6 +5,7 @@ import inspect import threading from abc import ABC, abstractmethod from concurrent.futures import FIRST_COMPLETED, wait +from copy import deepcopy from functools import partial from itertools import tee from operator import itemgetter @@ -31,7 +32,7 @@ from typing import ( from typing_extensions import Literal, get_args -from langchain_core.load.dump import dumpd +from langchain_core.load.dump import dumpd, dumps from langchain_core.load.serializable import Serializable from langchain_core.pydantic_v1 import BaseModel, Field, create_model from langchain_core.runnables.config import ( @@ -507,6 +508,7 @@ class Runnable(Generic[Input, Output], ABC): config: Optional[RunnableConfig] = None, *, diff: Literal[True] = True, + with_streamed_output_list: bool = True, include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -524,6 +526,7 @@ class Runnable(Generic[Input, Output], ABC): config: Optional[RunnableConfig] = None, *, diff: Literal[False], + with_streamed_output_list: bool = True, include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -540,6 +543,7 @@ class Runnable(Generic[Input, Output], ABC): config: Optional[RunnableConfig] = None, *, diff: bool = True, + with_streamed_output_list: bool = True, include_names: Optional[Sequence[str]] = None, include_types: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None, @@ -557,7 +561,20 @@ class Runnable(Generic[Input, Output], ABC): step, and the final state of the run. The jsonpatch ops can be applied in order to construct state. + + Args: + input: The input to the runnable. + config: The config to use for the runnable. + diff: Whether to yield diffs between each step, or the current state. + with_streamed_output_list: Whether to yield the streamed_output list. + include_names: Only include logs with these names. + include_types: Only include logs with these types. + include_tags: Only include logs with these tags. + exclude_names: Exclude logs with these names. + exclude_types: Exclude logs with these types. + exclude_tags: Exclude logs with these tags. """ + import jsonpatch # type: ignore[import] from langchain_core.callbacks.base import BaseCallbackManager from langchain_core.tracers.log_stream import ( @@ -598,16 +615,36 @@ class Runnable(Generic[Input, Output], ABC): # add each chunk to the output stream async def consume_astream() -> None: try: + prev_final_output: Optional[Output] = None + final_output: Optional[Output] = None + async for chunk in self.astream(input, config, **kwargs): - await stream.send_stream.send( - RunLogPatch( + prev_final_output = final_output + if final_output is None: + final_output = chunk + else: + try: + final_output = final_output + chunk # type: ignore + except TypeError: + final_output = chunk + patches: List[Dict[str, Any]] = [] + if with_streamed_output_list: + patches.append( { "op": "add", "path": "/streamed_output/-", - "value": chunk, + # chunk cannot be shared between + # streamed_output and final_output + # otherwise jsonpatch.apply will + # modify both + "value": deepcopy(chunk), } ) - ) + for op in jsonpatch.JsonPatch.from_diff( + prev_final_output, final_output, dumps=dumps + ): + patches.append({**op, "path": f"/final_output{op['path']}"}) + await stream.send_stream.send(RunLogPatch(*patches)) finally: await stream.send_stream.aclose() diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 7bd030691a..98189ad19c 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -59,7 +59,7 @@ class RunState(TypedDict): """List of output chunks streamed by Runnable.stream()""" final_output: Optional[Any] """Final output of the run, usually the result of aggregating (`+`) streamed_output. - Only available after the run has finished successfully.""" + Updated throughout the run when supported by the Runnable.""" logs: Dict[str, LogEntry] """Map of run names to sub-runs. If filters were supplied, this list will @@ -151,9 +151,7 @@ class LogStreamCallbackHandler(BaseTracer): send_stream: Any receive_stream: Any - send_stream, receive_stream = create_memory_object_stream( - math.inf, item_type=RunLogPatch - ) + send_stream, receive_stream = create_memory_object_stream(math.inf) self.lock = threading.Lock() self.send_stream = send_stream self.receive_stream = receive_stream @@ -278,15 +276,6 @@ class LogStreamCallbackHandler(BaseTracer): ) finally: if run.id == self.root_id: - self.send_stream.send_nowait( - RunLogPatch( - { - "op": "replace", - "path": "/final_output", - "value": load(run.outputs), - } - ) - ) if self.auto_close: self.send_stream.close() diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index bd873ff57b..935571ed12 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -1,6 +1,7 @@ from langchain_core.runnables import __all__ EXPECTED_ALL = [ + "AddableDict", "ConfigurableField", "ConfigurableFieldSingleOption", "ConfigurableFieldMultiOption", @@ -20,6 +21,7 @@ EXPECTED_ALL = [ "RunnableSequence", "RunnableWithFallbacks", "get_config_list", + "aadd", "add", ] diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 0ee3915966..161279945d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -49,6 +49,7 @@ from langchain_core.prompts import ( from langchain_core.pydantic_v1 import BaseModel from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( + AddableDict, ConfigurableField, ConfigurableFieldMultiOption, ConfigurableFieldSingleOption, @@ -1542,6 +1543,7 @@ async def test_prompt() -> None: assert stream_log[1:] == [ RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": expected}, { "op": "replace", "path": "/final_output", @@ -1551,9 +1553,8 @@ async def test_prompt() -> None: HumanMessage(content="What is your name?"), ] ), - } + }, ), - RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}), ] stream_log_state = [ @@ -1612,6 +1613,7 @@ async def test_prompt() -> None: assert stream_log_nested[1:] == [ RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": expected}, { "op": "replace", "path": "/final_output", @@ -1621,9 +1623,8 @@ async def test_prompt() -> None: HumanMessage(content="What is your name?"), ] ), - } + }, ), - RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": expected}), ] @@ -2107,9 +2108,9 @@ async def test_prompt_with_llm( "value": "2023-01-01T00:00:00.000", }, ), - RunLogPatch({"op": "add", "path": "/streamed_output/-", "value": "foo"}), RunLogPatch( - {"op": "replace", "path": "/final_output", "value": {"output": "foo"}} + {"op": "add", "path": "/streamed_output/-", "value": "foo"}, + {"op": "replace", "path": "/final_output", "value": "foo"}, ), ] @@ -2154,6 +2155,71 @@ async def test_stream_log_retriever() -> None: ] +@freeze_time("2023-01-01") +async def test_stream_log_lists() -> None: + async def list_producer(input: AsyncIterator[Any]) -> AsyncIterator[AddableDict]: + for i in range(4): + yield AddableDict(alist=[str(i)]) + + chain: Runnable = RunnableGenerator(list_producer) + + stream_log = [ + part async for part in chain.astream_log({"question": "What is your name?"}) + ] + + # remove ids from logs + for part in stream_log: + for op in part.ops: + if ( + isinstance(op["value"], dict) + and "id" in op["value"] + and not isinstance(op["value"]["id"], list) # serialized lc id + ): + del op["value"]["id"] + + assert stream_log == [ + RunLogPatch( + { + "op": "replace", + "path": "", + "value": {"final_output": None, "logs": {}, "streamed_output": []}, + } + ), + RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["0"]}}, + {"op": "replace", "path": "/final_output", "value": {"alist": ["0"]}}, + ), + RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["1"]}}, + {"op": "add", "path": "/final_output/alist/1", "value": "1"}, + ), + RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["2"]}}, + {"op": "add", "path": "/final_output/alist/2", "value": "2"}, + ), + RunLogPatch( + {"op": "add", "path": "/streamed_output/-", "value": {"alist": ["3"]}}, + {"op": "add", "path": "/final_output/alist/3", "value": "3"}, + ), + ] + + state = add(stream_log) + + assert isinstance(state, RunLog) + + assert state.state == { + "final_output": {"alist": ["0", "1", "2", "3"]}, + "logs": {}, + "streamed_output": [ + {"alist": ["0"]}, + {"alist": ["1"]}, + {"alist": ["2"]}, + {"alist": ["3"]}, + ], + } + + +@pytest.mark.asyncio @freeze_time("2023-01-01") async def test_prompt_with_llm_and_async_lambda( mocker: MockerFixture, snapshot: SnapshotAssertion