This commit is contained in:
Bagatur 2023-08-17 15:29:00 -07:00
parent bd80cad6db
commit c447e9a854
2 changed files with 37 additions and 11 deletions

View File

@ -243,8 +243,8 @@ class Runnable(Generic[Input, Output], ABC):
def _call_with_config(
self,
func: Callable[[Input], Output],
input: Input,
func: Callable[[Any], Output],
input: Any,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:
@ -273,8 +273,8 @@ class Runnable(Generic[Input, Output], ABC):
async def _acall_with_config(
self,
func: Callable[[Input], Awaitable[Output]],
input: Input,
func: Callable[[Any], Awaitable[Output]],
input: Any,
config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output:

View File

@ -16,7 +16,7 @@ class PutLocalVar(RunnablePassthrough):
stored in local state under the map keys.
"""
def __init__(self, key: str, **kwargs: Any) -> None:
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
@ -63,13 +63,13 @@ class PutLocalVar(RunnablePassthrough):
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
self._put(input, config=config)
return super().invoke(input, config)
return super().invoke(input, config=config)
async def ainvoke(
self, input: Input, config: RunnableConfig | None = None
) -> Input:
self._put(input, config=config)
return await super().ainvoke(input, config)
return await super().ainvoke(input, config=config)
def transform(
self, input: Iterator[Input], config: RunnableConfig | None = None
@ -102,14 +102,40 @@ class GetLocalVar(
def __init__(self, key: str, **kwargs: Any) -> None:
super().__init__(key=key, **kwargs)
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]:
if self.passthrough_key:
return {
self.key: full_input["locals"][self.key],
self.passthrough_key: full_input["input"],
}
else:
return full_input["locals"][self.key]
async def _aget(
self, full_input: Dict
) -> Union[Output, Dict[str, Union[Input, Output]]]:
return self._get(full_input)
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"PutLocalVar should only be used in a RunnableSequence, and should "
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
if self.passthrough_key is not None:
return {self.key: config["_locals"][self.key], self.passthrough_key: input}
return config["_locals"][self.key]
log_input = {"input": input, "locals": config["_locals"]}
return self._call_with_config(self._get, log_input, config)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Union[Output, Dict[str, Union[Input, Output]]]:
if config is None:
raise ValueError(
"GetLocalVar should only be used in a RunnableSequence, and should "
"therefore always receive a non-null config."
)
log_input = {"input": input, "locals": config["_locals"]}
return await self._acall_with_config(self._aget, log_input, config)