core[patch]: RunnablePassthrough transform to autoupgrade to AddableDict (#19051)

Follow up on https://github.com/langchain-ai/langchain/pull/18743 which
missed RunnablePassthrough

Issues:

https://github.com/langchain-ai/langchain/issues/18741
https://github.com/langchain-ai/langgraph/issues/136
https://github.com/langchain-ai/langserve/issues/504
This commit is contained in:
Eugene Yurtsev 2024-03-14 16:59:46 -04:00 committed by GitHub
parent 41e2f60cd2
commit 06165efb5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 15 deletions

View File

@ -69,6 +69,7 @@ from langchain_core.runnables.utils import (
accepts_config,
accepts_context,
accepts_run_manager,
adapt_first_streaming_chunk,
create_model,
gather_with_concurrency,
get_function_first_arg_dict_keys,
@ -1207,7 +1208,7 @@ class Runnable(Generic[Input, Output], ABC):
for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
@ -1240,7 +1241,7 @@ class Runnable(Generic[Input, Output], ABC):
async for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
@ -3731,7 +3732,7 @@ class RunnableLambda(Runnable[Input, Output]):
final: Optional[Input] = None
for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
final = adapt_first_streaming_chunk(ichunk) # type: ignore
else:
try:
final = final + ichunk # type: ignore[operator]
@ -3815,7 +3816,7 @@ class RunnableLambda(Runnable[Input, Output]):
final: Optional[Input] = None
async for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk)
final = adapt_first_streaming_chunk(ichunk)
else:
try:
final = final + ichunk # type: ignore[operator]
@ -4727,11 +4728,3 @@ def chain(
yield chunk
"""
return RunnableLambda(func)
def _adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk

View File

@ -40,6 +40,7 @@ from langchain_core.runnables.graph import Graph
from langchain_core.runnables.utils import (
AddableDict,
ConfigurableFieldSpec,
adapt_first_streaming_chunk,
create_model,
)
from langchain_core.utils.aiter import atee, py_anext
@ -248,7 +249,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk
@ -276,7 +277,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk

View File

@ -521,3 +521,11 @@ def _create_model_cached(
return _create_model_base(
__model_name, __config__=_SchemaConfig, **field_definitions
)
def adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk

View File

@ -5324,7 +5324,7 @@ def test_default_transform_with_dicts() -> None:
assert list(runnable.transform(chunks)) == [{"foo": "an"}]
async def test_defualt_atransform_with_dicts() -> None:
async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
class CustomRunnable(RunnableSerializable[Input, Output]):
@ -5342,3 +5342,22 @@ async def test_defualt_atransform_with_dicts() -> None:
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
assert chunks == [{"foo": "an"}]
def test_passthrough_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)
chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
assert chunks == [{"foo": "a"}, {"foo": "n"}]
async def test_passthrough_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)
async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
yield {"foo": "a"}
yield {"foo": "n"}
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
assert chunks == [{"foo": "a"}, {"foo": "n"}]