Add async support for transform chain (#8205)

pull/8206/head^2
William FH 1 year ago committed by GitHub
parent 8f158b72fc
commit 3662aca7d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,9 +1,16 @@
"""Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List, Optional
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
logger = logging.getLogger(__name__)
class TransformChain(Chain):
"""Chain transform chain output.
@ -17,8 +24,22 @@ class TransformChain(Chain):
"""
input_variables: List[str]
"""The keys expected by the transform's input dictionary."""
output_variables: List[str]
"""The keys returned by the transform's output dictionary."""
transform: Callable[[Dict[str, str]], Dict[str, str]]
"""The transform function."""
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
"""The async coroutine transform function."""
@staticmethod
@functools.lru_cache
def _log_once(msg: str) -> None:
"""Log a message once.
:meta private:
"""
logger.warning(msg)
@property
def input_keys(self) -> List[str]:
@ -42,3 +63,17 @@ class TransformChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
return self.transform(inputs)
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self.atransform is not None:
return await self.atransform(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"
)
return self.transform(inputs)

Loading…
Cancel
Save