Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
WIP
Add test, Lint
This commit is contained in:
parent c184be5511
commit 93bbf67afc
@@ -217,6 +217,12 @@ class Runnable(Generic[Input, Output], ABC):
         """
         return RunnableBinding(bound=self, kwargs=kwargs)
 
+    def each(self) -> Runnable[List[Input], List[Output]]:
+        """
+        Wrap a Runnable to run it on each element of the input sequence.
+        """
+        return RunnableEach(bound=self)
+
     def with_fallbacks(
         self,
         fallbacks: Sequence[Runnable[Input, Output]],
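For orientation, here is a minimal, dependency-free sketch of the pattern the new each() method introduces above: wrap a single-item callable so that a list input is processed element by element. The names here (MapEach, and a plain callable standing in for a Runnable) are illustrative only and not part of LangChain.

from typing import Any, Callable, List


class MapEach:
    """Run a one-item callable over every element of a list input."""

    def __init__(self, bound: Callable[[Any], Any]) -> None:
        self.bound = bound

    def invoke(self, inputs: List[Any]) -> List[Any]:
        # RunnableEach defers to bound.batch(inputs, config); a plain
        # sequential map is the simplest equivalent.
        return [self.bound(x) for x in inputs]


assert MapEach(str.upper).invoke(["a", "b"]) == ["A", "B"]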
@@ -1360,6 +1366,41 @@ class RunnableLambda(Runnable[Input, Output]):
         return self._call_with_config(self.func, input, config)
 
 
+class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
+    """
+    A runnable that delegates calls to another runnable with each element of the input sequence.
+    """
+
+    bound: Runnable[Input, Output]
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    @property
+    def lc_namespace(self) -> List[str]:
+        return self.__class__.__module__.split(".")[:-1]
+
+    def each(self) -> RunnableEach[Input, Output]:  # type: ignore[override]
+        return self
+
+    def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
+        return RunnableEach(bound=self.bound.bind(**kwargs))
+
+    def invoke(
+        self, input: List[Input], config: Optional[RunnableConfig] = None
+    ) -> List[Output]:
+        return self.bound.batch(input, config)
+
+    async def ainvoke(
+        self, input: List[Input], config: Optional[RunnableConfig] = None
+    ) -> List[Output]:
+        return await self.bound.abatch(input, config)
+
+
 class RunnableBinding(Serializable, Runnable[Input, Output]):
     """
     A runnable that delegates calls to another runnable with a set of kwargs.
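A hedged usage sketch of RunnableEach as added above, assuming RunnableLambda can be imported from langchain.schema.runnable in this revision of the codebase; the asserted result follows from the invoke -> bound.batch delegation shown in the diff, not from output captured from the repo.

from langchain.schema.runnable import RunnableLambda

double = RunnableLambda(lambda x: x * 2)

# double.each() wraps the runnable in RunnableEach(bound=double); invoking the
# wrapper with a list delegates to double.batch(...), one call per element.
assert double.each().invoke([1, 2, 3]) == [2, 4, 6]

Note that RunnableEach.bind() pushes its kwargs down to the bound runnable, returning RunnableEach(bound=self.bound.bind(**kwargs)), so any bound configuration applies to every per-element call rather than to the list-level wrapper itself.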
File diff suppressed because one or more lines are too long
@@ -1,5 +1,6 @@
 from typing import Any, Dict, List, Optional
 from uuid import UUID
+from xml.dom import ValidationErr
 
 import pytest
 from freezegun import freeze_time
@@ -20,6 +21,7 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
+from langchain.pydantic_v1 import ValidationError
 from langchain.schema.document import Document
 from langchain.schema.messages import (
     AIMessage,
@@ -1086,3 +1088,18 @@ async def test_llm_with_fallbacks(
     assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(await runnable.ainvoke("hello")) == list("bar")
     assert dumps(runnable, pretty=True) == snapshot
+
+
+def test_each(snapshot: SnapshotAssertion) -> None:
+    prompt = (
+        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+        + "{question}"
+    )
+    first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
+    second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
+
+    chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
+
+    assert dumps(chain, pretty=True) == snapshot
+    output = chain.invoke({"question": "What up"})
+    assert output == ["this", "is", "a"]
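The expected output follows from the chain: first_llm returns the single string "first item, second item, third item", CommaSeparatedListOutputParser splits it into three items, and second_llm.each() then invokes second_llm once per item. Since FakeStreamingListLLM appears to serve its canned responses in order, the three calls yield "this", "is", and "a", with "test" left unused.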