Implement .transform() in RunnablePassthrough() (#9032)

- This ensures passthrough doesnt break streaming
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Nuno Campos 2023-08-10 18:41:19 +01:00 committed by GitHub
parent 206f809366
commit 3bdc273ab3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 3 deletions

View File

@ -1,9 +1,18 @@
from __future__ import annotations
from typing import List, Optional
from typing import AsyncIterator, Iterator, List, Optional
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
from langchain.schema.runnable.base import Input, Runnable
from langchain.schema.runnable.config import RunnableConfig
def identity(x: Input) -> Input:
return x
async def aidentity(x: Input) -> Input:
return x
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
@ -20,4 +29,19 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
return self.__class__.__module__.split(".")[:-1]
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)
return self._call_with_config(identity, input, config)
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None
) -> Input:
return await self._acall_with_config(aidentity, input, config)
def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None
) -> Iterator[Input]:
return self._transform_stream_with_config(input, identity, config)
def atransform(
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
) -> AsyncIterator[Input]:
return self._atransform_stream_with_config(input, identity, config)

View File

@ -784,6 +784,13 @@ def test_deep_stream() -> None:
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
chunks = []
for chunk in (chain | RunnablePassthrough()).stream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.mark.asyncio
async def test_deep_astream() -> None:
@ -804,6 +811,13 @@ async def test_deep_astream() -> None:
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
chunks = []
async for chunk in (chain | RunnablePassthrough()).astream({"question": "What up"}):
chunks.append(chunk)
assert len(chunks) == len("foo-lish")
assert "".join(chunks) == "foo-lish"
@pytest.fixture()
def llm_with_fallbacks() -> RunnableWithFallbacks: