Add test

Add test

Lint
This commit is contained in:
Nuno Campos 2023-08-18 14:08:54 +01:00
parent c184be5511
commit 93bbf67afc
3 changed files with 191 additions and 0 deletions

View File

@ -217,6 +217,12 @@ class Runnable(Generic[Input, Output], ABC):
"""
return RunnableBinding(bound=self, kwargs=kwargs)
def each(self) -> Runnable[List[Input], List[Output]]:
"""
Wrap a Runnable to run it on each element of the input sequence.
"""
return RunnableEach(bound=self)
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],
@ -1360,6 +1366,41 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(self.func, input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
"""
A runnable that delegates calls to another runnable with each element of the input sequence.
"""
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def each(self) -> RunnableEach[Input, Output]: # type: ignore[override]
return self
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None
) -> List[Output]:
return self.bound.batch(input, config)
async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None
) -> List[Output]:
return await self.bound.abatch(input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]):
"""
A runnable that delegates calls to another runnable with a set of kwargs.

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional
from uuid import UUID
from xml.dom import ValidationErr
import pytest
from freezegun import freeze_time
@ -20,6 +21,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import ValidationError
from langchain.schema.document import Document
from langchain.schema.messages import (
AIMessage,
@ -1086,3 +1088,18 @@ async def test_llm_with_fallbacks(
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot
def test_each(snapshot: SnapshotAssertion) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"})
assert output == ["this", "is", "a"]