mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
cr
This commit is contained in:
parent
bd80cad6db
commit
c447e9a854
@ -243,8 +243,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
def _call_with_config(
|
||||
self,
|
||||
func: Callable[[Input], Output],
|
||||
input: Input,
|
||||
func: Callable[[Any], Output],
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Output:
|
||||
@ -273,8 +273,8 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
|
||||
async def _acall_with_config(
|
||||
self,
|
||||
func: Callable[[Input], Awaitable[Output]],
|
||||
input: Input,
|
||||
func: Callable[[Any], Awaitable[Output]],
|
||||
input: Any,
|
||||
config: Optional[RunnableConfig],
|
||||
run_type: Optional[str] = None,
|
||||
) -> Output:
|
||||
|
@ -16,7 +16,7 @@ class PutLocalVar(RunnablePassthrough):
|
||||
stored in local state under the map keys.
|
||||
"""
|
||||
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
def __init__(self, key: Union[str, Mapping[str, str]], **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def _put(self, input: Input, *, config: Optional[RunnableConfig] = None) -> None:
|
||||
@ -63,13 +63,13 @@ class PutLocalVar(RunnablePassthrough):
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
|
||||
self._put(input, config=config)
|
||||
return super().invoke(input, config)
|
||||
return super().invoke(input, config=config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: RunnableConfig | None = None
|
||||
) -> Input:
|
||||
self._put(input, config=config)
|
||||
return await super().ainvoke(input, config)
|
||||
return await super().ainvoke(input, config=config)
|
||||
|
||||
def transform(
|
||||
self, input: Iterator[Input], config: RunnableConfig | None = None
|
||||
@ -102,14 +102,40 @@ class GetLocalVar(
|
||||
def __init__(self, key: str, **kwargs: Any) -> None:
|
||||
super().__init__(key=key, **kwargs)
|
||||
|
||||
def _get(self, full_input: Dict) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
if self.passthrough_key:
|
||||
return {
|
||||
self.key: full_input["locals"][self.key],
|
||||
self.passthrough_key: full_input["input"],
|
||||
}
|
||||
else:
|
||||
return full_input["locals"][self.key]
|
||||
|
||||
async def _aget(
|
||||
self, full_input: Dict
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
return self._get(full_input)
|
||||
|
||||
def invoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"PutLocalVar should only be used in a RunnableSequence, and should "
|
||||
"GetLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
if self.passthrough_key is not None:
|
||||
return {self.key: config["_locals"][self.key], self.passthrough_key: input}
|
||||
return config["_locals"][self.key]
|
||||
|
||||
log_input = {"input": input, "locals": config["_locals"]}
|
||||
return self._call_with_config(self._get, log_input, config)
|
||||
|
||||
async def ainvoke(
|
||||
self, input: Input, config: Optional[RunnableConfig] = None
|
||||
) -> Union[Output, Dict[str, Union[Input, Output]]]:
|
||||
if config is None:
|
||||
raise ValueError(
|
||||
"GetLocalVar should only be used in a RunnableSequence, and should "
|
||||
"therefore always receive a non-null config."
|
||||
)
|
||||
|
||||
log_input = {"input": input, "locals": config["_locals"]}
|
||||
return await self._acall_with_config(self._aget, log_input, config)
|
||||
|
Loading…
Reference in New Issue
Block a user