mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Implement .transform() in RunnablePassthrough() (#9032)
- This ensures passthrough doesnt break streaming --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
206f809366
commit
3bdc273ab3
@ -1,9 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import AsyncIterator, Iterator, List, Optional
|
||||
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema.runnable.base import Input, Runnable, RunnableConfig
|
||||
from langchain.schema.runnable.base import Input, Runnable
|
||||
from langchain.schema.runnable.config import RunnableConfig
|
||||
|
||||
|
||||
def identity(x: Input) -> Input:
|
||||
return x
|
||||
|
||||
|
||||
async def aidentity(x: Input) -> Input:
|
||||
return x
|
||||
|
||||
|
||||
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
@ -20,4 +29,19 @@ class RunnablePassthrough(Serializable, Runnable[Input, Input]):
|
||||
return self.__class__.__module__.split(".")[:-1]
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
return self._call_with_config(lambda x: x, input, config)
|
||||
return self._call_with_config(identity, input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: RunnableConfig | None = None
|
||||
) -> Input:
|
||||
return await self._acall_with_config(aidentity, input, config)
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
||||
) -> Iterator[Input]:
|
||||
return self._transform_stream_with_config(input, identity, config)
|
||||
|
||||
def atransform(
|
||||
self, input: AsyncIterator[Input], config: RunnableConfig | None = None
|
||||
) -> AsyncIterator[Input]:
|
||||
return self._atransform_stream_with_config(input, identity, config)
|
||||
|
@ -784,6 +784,13 @@ def test_deep_stream() -> None:
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
chunks = []
|
||||
for chunk in (chain | RunnablePassthrough()).stream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deep_astream() -> None:
|
||||
@ -804,6 +811,13 @@ async def test_deep_astream() -> None:
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
chunks = []
|
||||
async for chunk in (chain | RunnablePassthrough()).astream({"question": "What up"}):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == len("foo-lish")
|
||||
assert "".join(chunks) == "foo-lish"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_with_fallbacks() -> RunnableWithFallbacks:
|
||||
|
Loading…
Reference in New Issue
Block a user