pull/11214/head
Nuno Campos 1 year ago
parent 0318cdd33c
commit 2387647d30

@ -123,6 +123,7 @@ class Runnable(Generic[Input, Output], ABC):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
@ -132,7 +133,8 @@ class Runnable(Generic[Input, Output], ABC):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
@ -353,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]
if got_first_val:
yield from self.stream(final, config, **kwargs)
@ -379,7 +381,7 @@ class Runnable(Generic[Input, Output], ABC):
else:
# Make a best effort to gather, for any type that supports `+`
# This method should throw an error if gathering fails.
final += chunk # type: ignore[operator]
final = final + chunk # type: ignore[operator]
if got_first_val:
async for output in self.astream(final, config, **kwargs):
@ -710,7 +712,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
@ -720,7 +722,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore
except TypeError:
final_input = None
final_input_supported = False
@ -788,7 +790,7 @@ class Runnable(Generic[Input, Output], ABC):
final_output = chunk
else:
try:
final_output += chunk # type: ignore[operator]
final_output = final_output + chunk # type: ignore
except TypeError:
final_output = None
final_output_supported = False
@ -798,7 +800,7 @@ class Runnable(Generic[Input, Output], ABC):
final_input = ichunk
else:
try:
final_input += ichunk # type: ignore[operator]
final_input = final_input + ichunk # type: ignore[operator]
except TypeError:
final_input = None
final_input_supported = False
@ -1315,6 +1317,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
@ -1335,7 +1338,8 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Callable[[Other], Any],
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
@ -1755,7 +1759,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None:
chunk[key] = other[key]
elif other[key] is not None:
chunk[key] += other[key]
chunk[key] = chunk[key] + other[key]
return chunk
def __radd__(self, other: RunnableMapChunk) -> RunnableMapChunk:
@ -1764,7 +1768,7 @@ class RunnableMapChunk(Dict[str, Any]):
if key not in chunk or chunk[key] is None:
chunk[key] = self[key]
elif self[key] is not None:
chunk[key] += self[key]
chunk[key] = chunk[key] + self[key]
return chunk
@ -2107,7 +2111,7 @@ class RunnableGenerator(Runnable[Input, Output]):
return Any
@property
def OutputType(self) -> Type[Output]:
def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or getattr(self, "_atransform")
try:
sig = inspect.signature(func)
@ -2137,7 +2141,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self,
input: Iterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
**kwargs: Any,
) -> Iterator[Output]:
return self._transform_stream_with_config(
input, self._transform, config, **kwargs
@ -2147,7 +2151,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
**kwargs: Any,
) -> Iterator[Output]:
return self.transform(iter([input]), config, **kwargs)
@ -2159,14 +2163,14 @@ class RunnableGenerator(Runnable[Input, Output]):
if final is None:
final = output
else:
final += output
return final
final = final + output
return cast(Output, final)
def atransform(
self,
input: AsyncIterator[Input],
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
**kwargs: Any,
) -> AsyncIterator[Output]:
if not hasattr(self, "_atransform"):
raise NotImplementedError("This runnable does not support async methods.")
@ -2179,7 +2183,7 @@ class RunnableGenerator(Runnable[Input, Output]):
self,
input: Input,
config: Optional[RunnableConfig] = None,
**kwargs: Any | None,
**kwargs: Any,
) -> AsyncIterator[Output]:
async def input_aiter() -> AsyncIterator[Input]:
yield input
@ -2187,15 +2191,15 @@ class RunnableGenerator(Runnable[Input, Output]):
return self.atransform(input_aiter(), config, **kwargs)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
final = None
async for output in self.astream(input, config):
async for output in self.astream(input, config, **kwargs):
if final is None:
final = output
else:
final += output
return final
final = final + output
return cast(Output, final)
class RunnableLambda(Runnable[Input, Output]):
@ -2687,7 +2691,7 @@ def coerce_to_runnable(thing: RunnableLike) -> Runnable[Input, Output]:
elif inspect.isasyncgenfunction(thing) or inspect.isgeneratorfunction(thing):
return RunnableGenerator(thing)
elif callable(thing):
return RunnableLambda(thing)
return RunnableLambda(cast(Callable[[Input], Output], thing))
elif isinstance(thing, dict):
runnables: Mapping[str, Runnable[Any, Any]] = {
key: coerce_to_runnable(r) for key, r in thing.items()

@ -15,14 +15,7 @@ from typing import (
from typing_extensions import TypedDict
from langchain.load.serializable import Serializable
from langchain.schema.runnable.base import (
Input,
Other,
Output,
Runnable,
RunnableSequence,
coerce_to_runnable,
)
from langchain.schema.runnable.base import Input, Output, Runnable, coerce_to_runnable
from langchain.schema.runnable.config import (
RunnableConfig,
get_config_list,
@ -71,28 +64,6 @@ class RouterRunnable(Serializable, Runnable[RouterInput, Output]):
def get_lc_namespace(cls) -> List[str]:
return cls.__module__.split(".")[:-1]
def __or__(
self,
other: Union[
Runnable[Any, Other],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
Mapping[str, Any],
],
) -> RunnableSequence[RouterInput, Other]:
return RunnableSequence(first=self, last=coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Callable[[Any], Other],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
Mapping[str, Any],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=coerce_to_runnable(other), last=self)
def invoke(
self, input: RouterInput, config: Optional[RunnableConfig] = None
) -> Output:

@ -12,7 +12,6 @@ from typing import (
cast,
)
from uuid import UUID
from langchain.schema.runnable.base import RunnableGenerator
import pytest
from freezegun import freeze_time
@ -57,6 +56,7 @@ from langchain.schema.runnable import (
RunnableSequence,
RunnableWithFallbacks,
)
from langchain.schema.runnable.base import RunnableGenerator
from langchain.tools.base import BaseTool, tool
from langchain.tools.json.tool import JsonListKeysTool, JsonSpec
@ -2876,7 +2876,7 @@ async def test_runnable_gen_transform() -> None:
async for i in input:
yield i + 1
chain = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
chain: Runnable = RunnableGenerator(gen_indexes, agen_indexes) | plus_one
achain = RunnableGenerator(gen_indexes, agen_indexes) | aplus_one
assert chain.input_schema.schema() == {

Loading…
Cancel
Save