|
|
|
@ -1,10 +1,13 @@
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import copy
|
|
|
|
|
import threading
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
|
|
|
|
from itertools import tee
|
|
|
|
|
from typing import (
|
|
|
|
|
TYPE_CHECKING,
|
|
|
|
|
Any,
|
|
|
|
|
AsyncIterator,
|
|
|
|
|
Awaitable,
|
|
|
|
@ -23,15 +26,25 @@ from typing import (
|
|
|
|
|
cast,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
|
CallbackManagerForChainRun,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.base import BaseCallbackManager
|
|
|
|
|
from langchain.load.dump import dumpd
|
|
|
|
|
from langchain.load.serializable import Serializable
|
|
|
|
|
from langchain.pydantic_v1 import Field
|
|
|
|
|
from langchain.schema.runnable.config import RunnableConfig
|
|
|
|
|
from langchain.schema.runnable.utils import (
|
|
|
|
|
accepts_run_manager,
|
|
|
|
|
accepts_run_manager_and_config,
|
|
|
|
|
gather_with_concurrency,
|
|
|
|
|
)
|
|
|
|
|
from langchain.utils.aiter import atee, py_anext
|
|
|
|
|
from langchain.utils.iter import safetee
|
|
|
|
|
|
|
|
|
|
Input = TypeVar("Input")
|
|
|
|
|
# Output type should implement __concat__, as eg str, list, dict do
|
|
|
|
@ -48,7 +61,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
other: Union[
|
|
|
|
|
Runnable[Any, Other],
|
|
|
|
|
Callable[[Any], Other],
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
|
|
|
|
],
|
|
|
|
|
) -> RunnableSequence[Input, Other]:
|
|
|
|
|
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
|
|
|
@ -58,7 +71,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
other: Union[
|
|
|
|
|
Runnable[Other, Any],
|
|
|
|
|
Callable[[Any], Other],
|
|
|
|
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
|
|
|
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
|
|
|
|
],
|
|
|
|
|
) -> RunnableSequence[Other, Output]:
|
|
|
|
|
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
|
|
|
@ -135,7 +148,10 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
yield await self.ainvoke(input, config)
|
|
|
|
|
|
|
|
|
|
def transform(
|
|
|
|
|
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
|
|
|
|
self,
|
|
|
|
|
input: Iterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Optional[Any],
|
|
|
|
|
) -> Iterator[Output]:
|
|
|
|
|
"""
|
|
|
|
|
Default implementation of transform, which buffers input and then calls stream.
|
|
|
|
@ -152,10 +168,13 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
# This method should throw an error if gathering fails.
|
|
|
|
|
final += chunk # type: ignore[operator]
|
|
|
|
|
if final:
|
|
|
|
|
yield from self.stream(final, config)
|
|
|
|
|
yield from self.stream(final, config, **kwargs)
|
|
|
|
|
|
|
|
|
|
async def atransform(
|
|
|
|
|
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
|
|
|
|
self,
|
|
|
|
|
input: AsyncIterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Optional[Any],
|
|
|
|
|
) -> AsyncIterator[Output]:
|
|
|
|
|
"""
|
|
|
|
|
Default implementation of atransform, which buffers input and calls astream.
|
|
|
|
@ -173,7 +192,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
final += chunk # type: ignore[operator]
|
|
|
|
|
|
|
|
|
|
if final:
|
|
|
|
|
async for output in self.astream(final, config):
|
|
|
|
|
async for output in self.astream(final, config, **kwargs):
|
|
|
|
|
yield output
|
|
|
|
|
|
|
|
|
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
|
|
|
@ -217,7 +236,11 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
|
|
|
|
|
def _call_with_config(
|
|
|
|
|
self,
|
|
|
|
|
func: Callable[[Input], Output],
|
|
|
|
|
func: Union[
|
|
|
|
|
Callable[[Input], Output],
|
|
|
|
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
|
|
|
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
|
|
|
|
],
|
|
|
|
|
input: Input,
|
|
|
|
|
config: Optional[RunnableConfig],
|
|
|
|
|
run_type: Optional[str] = None,
|
|
|
|
@ -238,7 +261,16 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
run_type=run_type,
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
output = func(input)
|
|
|
|
|
if accepts_run_manager_and_config(func):
|
|
|
|
|
output = func(
|
|
|
|
|
input,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
config=config,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
elif accepts_run_manager(func):
|
|
|
|
|
output = func(input, run_manager=run_manager) # type: ignore[call-arg]
|
|
|
|
|
else:
|
|
|
|
|
output = func(input) # type: ignore[call-arg]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
run_manager.on_chain_error(e)
|
|
|
|
|
raise
|
|
|
|
@ -253,7 +285,14 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
|
|
|
|
|
async def _acall_with_config(
|
|
|
|
|
self,
|
|
|
|
|
func: Callable[[Input], Awaitable[Output]],
|
|
|
|
|
func: Union[
|
|
|
|
|
Callable[[Input], Awaitable[Output]],
|
|
|
|
|
Callable[[Input, AsyncCallbackManagerForChainRun], Awaitable[Output]],
|
|
|
|
|
Callable[
|
|
|
|
|
[Input, AsyncCallbackManagerForChainRun, RunnableConfig],
|
|
|
|
|
Awaitable[Output],
|
|
|
|
|
],
|
|
|
|
|
],
|
|
|
|
|
input: Input,
|
|
|
|
|
config: Optional[RunnableConfig],
|
|
|
|
|
run_type: Optional[str] = None,
|
|
|
|
@ -274,7 +313,19 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
run_type=run_type,
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
output = await func(input)
|
|
|
|
|
if accepts_run_manager_and_config(func):
|
|
|
|
|
output = await func(
|
|
|
|
|
input,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
config=config,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
elif accepts_run_manager(func):
|
|
|
|
|
output = await func(
|
|
|
|
|
input,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
else:
|
|
|
|
|
output = await func(input) # type: ignore[call-arg]
|
|
|
|
|
except Exception as e:
|
|
|
|
|
await run_manager.on_chain_error(e)
|
|
|
|
|
raise
|
|
|
|
@ -290,7 +341,18 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
def _transform_stream_with_config(
|
|
|
|
|
self,
|
|
|
|
|
input: Iterator[Input],
|
|
|
|
|
transformer: Callable[[Iterator[Input]], Iterator[Output]],
|
|
|
|
|
transformer: Union[
|
|
|
|
|
Callable[[Iterator[Input]], Iterator[Output]],
|
|
|
|
|
Callable[[Iterator[Input], CallbackManagerForChainRun], Iterator[Output]],
|
|
|
|
|
Callable[
|
|
|
|
|
[
|
|
|
|
|
Iterator[Input],
|
|
|
|
|
CallbackManagerForChainRun,
|
|
|
|
|
RunnableConfig,
|
|
|
|
|
],
|
|
|
|
|
Iterator[Output],
|
|
|
|
|
],
|
|
|
|
|
],
|
|
|
|
|
config: Optional[RunnableConfig],
|
|
|
|
|
run_type: Optional[str] = None,
|
|
|
|
|
) -> Iterator[Output]:
|
|
|
|
@ -319,7 +381,20 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
run_type=run_type,
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
for chunk in transformer(input_for_transform):
|
|
|
|
|
if accepts_run_manager_and_config(transformer):
|
|
|
|
|
iterator = transformer(
|
|
|
|
|
input_for_transform,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
config=config,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
elif accepts_run_manager(transformer):
|
|
|
|
|
iterator = transformer(
|
|
|
|
|
input_for_transform,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
else:
|
|
|
|
|
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
|
|
|
|
for chunk in iterator:
|
|
|
|
|
yield chunk
|
|
|
|
|
if final_output_supported:
|
|
|
|
|
if final_output is None:
|
|
|
|
@ -361,7 +436,21 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
async def _atransform_stream_with_config(
|
|
|
|
|
self,
|
|
|
|
|
input: AsyncIterator[Input],
|
|
|
|
|
transformer: Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
|
|
|
|
transformer: Union[
|
|
|
|
|
Callable[[AsyncIterator[Input]], AsyncIterator[Output]],
|
|
|
|
|
Callable[
|
|
|
|
|
[AsyncIterator[Input], AsyncCallbackManagerForChainRun],
|
|
|
|
|
AsyncIterator[Output],
|
|
|
|
|
],
|
|
|
|
|
Callable[
|
|
|
|
|
[
|
|
|
|
|
AsyncIterator[Input],
|
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
|
RunnableConfig,
|
|
|
|
|
],
|
|
|
|
|
AsyncIterator[Output],
|
|
|
|
|
],
|
|
|
|
|
],
|
|
|
|
|
config: Optional[RunnableConfig],
|
|
|
|
|
run_type: Optional[str] = None,
|
|
|
|
|
) -> AsyncIterator[Output]:
|
|
|
|
@ -390,7 +479,22 @@ class Runnable(Generic[Input, Output], ABC):
|
|
|
|
|
run_type=run_type,
|
|
|
|
|
)
|
|
|
|
|
try:
|
|
|
|
|
async for chunk in transformer(input_for_transform):
|
|
|
|
|
# mypy can't quite work out thew type guard here, but this is safe,
|
|
|
|
|
# check implementations of the accepts_* functions
|
|
|
|
|
if accepts_run_manager_and_config(transformer):
|
|
|
|
|
iterator = transformer(
|
|
|
|
|
input_for_transform,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
config=config,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
elif accepts_run_manager(transformer):
|
|
|
|
|
iterator = transformer(
|
|
|
|
|
input_for_transform,
|
|
|
|
|
run_manager=run_manager,
|
|
|
|
|
) # type: ignore[call-arg]
|
|
|
|
|
else:
|
|
|
|
|
iterator = transformer(input_for_transform) # type: ignore[call-arg]
|
|
|
|
|
async for chunk in iterator:
|
|
|
|
|
yield chunk
|
|
|
|
|
if final_output_supported:
|
|
|
|
|
if final_output is None:
|
|
|
|
@ -700,7 +804,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|
|
|
|
other: Union[
|
|
|
|
|
Runnable[Any, Other],
|
|
|
|
|
Callable[[Any], Other],
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
|
|
|
|
],
|
|
|
|
|
) -> RunnableSequence[Input, Other]:
|
|
|
|
|
if isinstance(other, RunnableSequence):
|
|
|
|
@ -721,7 +825,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|
|
|
|
other: Union[
|
|
|
|
|
Runnable[Other, Any],
|
|
|
|
|
Callable[[Any], Other],
|
|
|
|
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
|
|
|
|
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
|
|
|
|
],
|
|
|
|
|
) -> RunnableSequence[Other, Output]:
|
|
|
|
|
if isinstance(other, RunnableSequence):
|
|
|
|
@ -875,7 +979,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|
|
|
|
) -> List[Output]:
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManager,
|
|
|
|
|
AsyncCallbackManagerForChainRun,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# setup callbacks
|
|
|
|
@ -1085,6 +1188,21 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RunnableMapChunk(Dict[str, Any]):
|
|
|
|
|
"""
|
|
|
|
|
Partial output from a RunnableMap
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __add__(self, other: RunnableMapChunk) -> RunnableMapChunk:
|
|
|
|
|
chunk = copy.deepcopy(self)
|
|
|
|
|
for key in other:
|
|
|
|
|
if key not in chunk or chunk[key] is None:
|
|
|
|
|
chunk[key] = other[key]
|
|
|
|
|
elif other[key] is not None:
|
|
|
|
|
chunk[key] += other[key]
|
|
|
|
|
return chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|
|
|
|
"""
|
|
|
|
|
A runnable that runs a mapping of runnables in parallel,
|
|
|
|
@ -1134,7 +1252,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|
|
|
|
local_metadata=None,
|
|
|
|
|
)
|
|
|
|
|
# start the root run
|
|
|
|
|
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
|
|
|
|
|
run_manager = callback_manager.on_chain_start(
|
|
|
|
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# gather results from all steps
|
|
|
|
|
try:
|
|
|
|
@ -1177,7 +1297,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|
|
|
|
)
|
|
|
|
|
# start the root run
|
|
|
|
|
run_manager = await callback_manager.on_chain_start(
|
|
|
|
|
dumpd(self), {"input": input}
|
|
|
|
|
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# gather results from all steps
|
|
|
|
@ -1203,6 +1323,134 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
|
|
|
|
|
await run_manager.on_chain_end(output)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
def _transform(
|
|
|
|
|
self,
|
|
|
|
|
input: Iterator[Input],
|
|
|
|
|
run_manager: CallbackManagerForChainRun,
|
|
|
|
|
config: RunnableConfig,
|
|
|
|
|
) -> Iterator[RunnableMapChunk]:
|
|
|
|
|
# Shallow copy steps to ignore mutations while in progress
|
|
|
|
|
steps = dict(self.steps)
|
|
|
|
|
# Each step gets a copy of the input iterator,
|
|
|
|
|
# which is consumed in parallel in a separate thread.
|
|
|
|
|
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
|
|
|
|
|
with ThreadPoolExecutor() as executor:
|
|
|
|
|
# Create the transform() generator for each step
|
|
|
|
|
named_generators = [
|
|
|
|
|
(
|
|
|
|
|
name,
|
|
|
|
|
step.transform(
|
|
|
|
|
input_copies.pop(),
|
|
|
|
|
patch_config(config, run_manager.get_child()),
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
for name, step in steps.items()
|
|
|
|
|
]
|
|
|
|
|
# Start the first iteration of each generator
|
|
|
|
|
futures = {
|
|
|
|
|
executor.submit(next, generator): (step_name, generator)
|
|
|
|
|
for step_name, generator in named_generators
|
|
|
|
|
}
|
|
|
|
|
# Yield chunks from each as they become available,
|
|
|
|
|
# and start the next iteration of that generator that yielded it.
|
|
|
|
|
# When all generators are exhausted, stop.
|
|
|
|
|
while futures:
|
|
|
|
|
completed_futures, _ = wait(futures, return_when=FIRST_COMPLETED)
|
|
|
|
|
for future in completed_futures:
|
|
|
|
|
(step_name, generator) = futures.pop(future)
|
|
|
|
|
try:
|
|
|
|
|
chunk = RunnableMapChunk({step_name: future.result()})
|
|
|
|
|
yield chunk
|
|
|
|
|
futures[executor.submit(next, generator)] = (
|
|
|
|
|
step_name,
|
|
|
|
|
generator,
|
|
|
|
|
)
|
|
|
|
|
except StopIteration:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def transform(
|
|
|
|
|
self,
|
|
|
|
|
input: Iterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Iterator[Dict[str, Any]]:
|
|
|
|
|
yield from self._transform_stream_with_config(
|
|
|
|
|
input, self._transform, config, **kwargs
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def stream(
|
|
|
|
|
self, input: Input, config: Optional[RunnableConfig] = None
|
|
|
|
|
) -> Iterator[Dict[str, Any]]:
|
|
|
|
|
yield from self.transform(iter([input]), config)
|
|
|
|
|
|
|
|
|
|
async def _atransform(
|
|
|
|
|
self,
|
|
|
|
|
input: AsyncIterator[Input],
|
|
|
|
|
run_manager: AsyncCallbackManagerForChainRun,
|
|
|
|
|
config: RunnableConfig,
|
|
|
|
|
) -> AsyncIterator[RunnableMapChunk]:
|
|
|
|
|
# Shallow copy steps to ignore mutations while in progress
|
|
|
|
|
steps = dict(self.steps)
|
|
|
|
|
# Each step gets a copy of the input iterator,
|
|
|
|
|
# which is consumed in parallel in a separate thread.
|
|
|
|
|
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
|
|
|
|
|
# Create the transform() generator for each step
|
|
|
|
|
named_generators = [
|
|
|
|
|
(
|
|
|
|
|
name,
|
|
|
|
|
step.atransform(
|
|
|
|
|
input_copies.pop(), patch_config(config, run_manager.get_child())
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
for name, step in steps.items()
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Wrap in a coroutine to satisfy linter
|
|
|
|
|
async def get_next_chunk(generator: AsyncIterator) -> Optional[Output]:
|
|
|
|
|
return await py_anext(generator)
|
|
|
|
|
|
|
|
|
|
# Start the first iteration of each generator
|
|
|
|
|
tasks = {
|
|
|
|
|
asyncio.create_task(get_next_chunk(generator)): (step_name, generator)
|
|
|
|
|
for step_name, generator in named_generators
|
|
|
|
|
}
|
|
|
|
|
# Yield chunks from each as they become available,
|
|
|
|
|
# and start the next iteration of the generator that yielded it.
|
|
|
|
|
# When all generators are exhausted, stop.
|
|
|
|
|
while tasks:
|
|
|
|
|
completed_tasks, _ = await asyncio.wait(
|
|
|
|
|
tasks, return_when=asyncio.FIRST_COMPLETED
|
|
|
|
|
)
|
|
|
|
|
for task in completed_tasks:
|
|
|
|
|
(step_name, generator) = tasks.pop(task)
|
|
|
|
|
try:
|
|
|
|
|
chunk = RunnableMapChunk({step_name: task.result()})
|
|
|
|
|
yield chunk
|
|
|
|
|
new_task = asyncio.create_task(get_next_chunk(generator))
|
|
|
|
|
tasks[new_task] = (step_name, generator)
|
|
|
|
|
except StopAsyncIteration:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
async def atransform(
|
|
|
|
|
self,
|
|
|
|
|
input: AsyncIterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> AsyncIterator[Dict[str, Any]]:
|
|
|
|
|
async for chunk in self._atransform_stream_with_config(
|
|
|
|
|
input, self._atransform, config, **kwargs
|
|
|
|
|
):
|
|
|
|
|
yield chunk
|
|
|
|
|
|
|
|
|
|
async def astream(
|
|
|
|
|
self, input: Input, config: Optional[RunnableConfig] = None
|
|
|
|
|
) -> AsyncIterator[Dict[str, Any]]:
|
|
|
|
|
async def input_aiter() -> AsyncIterator[Input]:
|
|
|
|
|
yield input
|
|
|
|
|
|
|
|
|
|
async for chunk in self.atransform(input_aiter(), config):
|
|
|
|
|
yield chunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RunnableLambda(Runnable[Input, Output]):
|
|
|
|
|
"""
|
|
|
|
@ -1293,14 +1541,22 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
|
|
|
|
|
yield item
|
|
|
|
|
|
|
|
|
|
def transform(
|
|
|
|
|
self, input: Iterator[Input], config: Optional[RunnableConfig] = None
|
|
|
|
|
self,
|
|
|
|
|
input: Iterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Iterator[Output]:
|
|
|
|
|
yield from self.bound.transform(input, config, **self.kwargs)
|
|
|
|
|
yield from self.bound.transform(input, config, **{**self.kwargs, **kwargs})
|
|
|
|
|
|
|
|
|
|
async def atransform(
|
|
|
|
|
self, input: AsyncIterator[Input], config: Optional[RunnableConfig] = None
|
|
|
|
|
self,
|
|
|
|
|
input: AsyncIterator[Input],
|
|
|
|
|
config: Optional[RunnableConfig] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> AsyncIterator[Output]:
|
|
|
|
|
async for item in self.bound.atransform(input, config, **self.kwargs):
|
|
|
|
|
async for item in self.bound.atransform(
|
|
|
|
|
input, config, **{**self.kwargs, **kwargs}
|
|
|
|
|
):
|
|
|
|
|
yield item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1316,7 +1572,7 @@ def coerce_to_runnable(
|
|
|
|
|
thing: Union[
|
|
|
|
|
Runnable[Input, Output],
|
|
|
|
|
Callable[[Input], Output],
|
|
|
|
|
Mapping[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
|
|
|
|
|
Mapping[str, Any],
|
|
|
|
|
]
|
|
|
|
|
) -> Runnable[Input, Output]:
|
|
|
|
|
if isinstance(thing, Runnable):
|
|
|
|
@ -1324,7 +1580,9 @@ def coerce_to_runnable(
|
|
|
|
|
elif callable(thing):
|
|
|
|
|
return RunnableLambda(thing)
|
|
|
|
|
elif isinstance(thing, dict):
|
|
|
|
|
runnables = {key: coerce_to_runnable(r) for key, r in thing.items()}
|
|
|
|
|
runnables: Mapping[str, Runnable[Any, Any]] = {
|
|
|
|
|
key: coerce_to_runnable(r) for key, r in thing.items()
|
|
|
|
|
}
|
|
|
|
|
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(
|
|
|
|
|