mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
core[patch]: RunnablePassthrough transform to autoupgrade to AddableDict (#19051)
Follow up on https://github.com/langchain-ai/langchain/pull/18743 which missed RunnablePassthrough Issues: https://github.com/langchain-ai/langchain/issues/18741 https://github.com/langchain-ai/langgraph/issues/136 https://github.com/langchain-ai/langserve/issues/504
This commit is contained in:
parent
41e2f60cd2
commit
06165efb5b
@ -69,6 +69,7 @@ from langchain_core.runnables.utils import (
|
||||
accepts_config,
|
||||
accepts_context,
|
||||
accepts_run_manager,
|
||||
adapt_first_streaming_chunk,
|
||||
create_model,
|
||||
gather_with_concurrency,
|
||||
get_function_first_arg_dict_keys,
|
||||
@ -1207,7 +1208,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
for chunk in input:
|
||||
if not got_first_val:
|
||||
final = _adapt_first_streaming_chunk(chunk) # type: ignore
|
||||
final = adapt_first_streaming_chunk(chunk) # type: ignore
|
||||
got_first_val = True
|
||||
else:
|
||||
# Make a best effort to gather, for any type that supports `+`
|
||||
@ -1240,7 +1241,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async for chunk in input:
|
||||
if not got_first_val:
|
||||
final = _adapt_first_streaming_chunk(chunk) # type: ignore
|
||||
final = adapt_first_streaming_chunk(chunk) # type: ignore
|
||||
got_first_val = True
|
||||
else:
|
||||
# Make a best effort to gather, for any type that supports `+`
|
||||
@ -3731,7 +3732,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
final: Optional[Input] = None
|
||||
for ichunk in input:
|
||||
if final is None:
|
||||
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
|
||||
final = adapt_first_streaming_chunk(ichunk) # type: ignore
|
||||
else:
|
||||
try:
|
||||
final = final + ichunk # type: ignore[operator]
|
||||
@ -3815,7 +3816,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
final: Optional[Input] = None
|
||||
async for ichunk in input:
|
||||
if final is None:
|
||||
final = _adapt_first_streaming_chunk(ichunk)
|
||||
final = adapt_first_streaming_chunk(ichunk)
|
||||
else:
|
||||
try:
|
||||
final = final + ichunk # type: ignore[operator]
|
||||
@ -4727,11 +4728,3 @@ def chain(
|
||||
yield chunk
|
||||
"""
|
||||
return RunnableLambda(func)
|
||||
|
||||
|
||||
def _adapt_first_streaming_chunk(chunk: Any) -> Any:
|
||||
"""This might transform the first chunk of a stream into an AddableDict."""
|
||||
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
|
||||
return AddableDict(chunk)
|
||||
else:
|
||||
return chunk
|
||||
|
@ -40,6 +40,7 @@ from langchain_core.runnables.graph import Graph
|
||||
from langchain_core.runnables.utils import (
|
||||
AddableDict,
|
||||
ConfigurableFieldSpec,
|
||||
adapt_first_streaming_chunk,
|
||||
create_model,
|
||||
)
|
||||
from langchain_core.utils.aiter import atee, py_anext
|
||||
@ -248,7 +249,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
for chunk in self._transform_stream_with_config(input, identity, config):
|
||||
yield chunk
|
||||
if final is None:
|
||||
final = chunk
|
||||
final = adapt_first_streaming_chunk(chunk)
|
||||
else:
|
||||
final = final + chunk
|
||||
|
||||
@ -276,7 +277,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
):
|
||||
yield chunk
|
||||
if final is None:
|
||||
final = chunk
|
||||
final = adapt_first_streaming_chunk(chunk)
|
||||
else:
|
||||
final = final + chunk
|
||||
|
||||
|
@ -521,3 +521,11 @@ def _create_model_cached(
|
||||
return _create_model_base(
|
||||
__model_name, __config__=_SchemaConfig, **field_definitions
|
||||
)
|
||||
|
||||
|
||||
def adapt_first_streaming_chunk(chunk: Any) -> Any:
|
||||
"""This might transform the first chunk of a stream into an AddableDict."""
|
||||
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
|
||||
return AddableDict(chunk)
|
||||
else:
|
||||
return chunk
|
||||
|
@ -5324,7 +5324,7 @@ def test_default_transform_with_dicts() -> None:
|
||||
assert list(runnable.transform(chunks)) == [{"foo": "an"}]
|
||||
|
||||
|
||||
async def test_defualt_atransform_with_dicts() -> None:
|
||||
async def test_default_atransform_with_dicts() -> None:
|
||||
"""Test that default transform works with dicts."""
|
||||
|
||||
class CustomRunnable(RunnableSerializable[Input, Output]):
|
||||
@ -5342,3 +5342,22 @@ async def test_defualt_atransform_with_dicts() -> None:
|
||||
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||
|
||||
assert chunks == [{"foo": "an"}]
|
||||
|
||||
|
||||
def test_passthrough_transform_with_dicts() -> None:
|
||||
"""Test that default transform works with dicts."""
|
||||
runnable = RunnablePassthrough(lambda x: x)
|
||||
chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
|
||||
assert chunks == [{"foo": "a"}, {"foo": "n"}]
|
||||
|
||||
|
||||
async def test_passthrough_atransform_with_dicts() -> None:
|
||||
"""Test that default transform works with dicts."""
|
||||
runnable = RunnablePassthrough(lambda x: x)
|
||||
|
||||
async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
|
||||
yield {"foo": "a"}
|
||||
yield {"foo": "n"}
|
||||
|
||||
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
|
||||
assert chunks == [{"foo": "a"}, {"foo": "n"}]
|
||||
|
Loading…
Reference in New Issue
Block a user