mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Add async support for transform chain (#8205)
This commit is contained in:
parent
8f158b72fc
commit
3662aca7d4
@ -1,9 +1,16 @@
|
|||||||
"""Chain that runs an arbitrary python function."""
|
"""Chain that runs an arbitrary python function."""
|
||||||
from typing import Callable, Dict, List, Optional
|
import functools
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
from langchain.callbacks.manager import (
|
||||||
|
AsyncCallbackManagerForChainRun,
|
||||||
|
CallbackManagerForChainRun,
|
||||||
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TransformChain(Chain):
|
class TransformChain(Chain):
|
||||||
"""Chain transform chain output.
|
"""Chain transform chain output.
|
||||||
@ -17,8 +24,22 @@ class TransformChain(Chain):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
input_variables: List[str]
|
input_variables: List[str]
|
||||||
|
"""The keys expected by the transform's input dictionary."""
|
||||||
output_variables: List[str]
|
output_variables: List[str]
|
||||||
|
"""The keys returned by the transform's output dictionary."""
|
||||||
transform: Callable[[Dict[str, str]], Dict[str, str]]
|
transform: Callable[[Dict[str, str]], Dict[str, str]]
|
||||||
|
"""The transform function."""
|
||||||
|
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
|
||||||
|
"""The async coroutine transform function."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@functools.lru_cache
|
||||||
|
def _log_once(msg: str) -> None:
|
||||||
|
"""Log a message once.
|
||||||
|
|
||||||
|
:meta private:
|
||||||
|
"""
|
||||||
|
logger.warning(msg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_keys(self) -> List[str]:
|
def input_keys(self) -> List[str]:
|
||||||
@ -42,3 +63,17 @@ class TransformChain(Chain):
|
|||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
run_manager: Optional[CallbackManagerForChainRun] = None,
|
||||||
) -> Dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
return self.transform(inputs)
|
return self.transform(inputs)
|
||||||
|
|
||||||
|
async def _acall(
|
||||||
|
self,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if self.atransform is not None:
|
||||||
|
return await self.atransform(inputs)
|
||||||
|
else:
|
||||||
|
self._log_once(
|
||||||
|
"TransformChain's atransform is not provided, falling"
|
||||||
|
" back to synchronous transform"
|
||||||
|
)
|
||||||
|
return self.transform(inputs)
|
||||||
|
Loading…
Reference in New Issue
Block a user