From 3662aca7d45b684f09f3c3655aa848253b016b75 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Mon, 24 Jul 2023 17:45:17 -0700 Subject: [PATCH] Add async support for transform chain (#8205) --- libs/langchain/langchain/chains/transform.py | 39 +++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index 90947b2b69..13e1e65aaa 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -1,9 +1,16 @@ """Chain that runs an arbitrary python function.""" -from typing import Callable, Dict, List, Optional +import functools +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional -from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains.base import Chain +logger = logging.getLogger(__name__) + class TransformChain(Chain): """Chain transform chain output. @@ -17,8 +24,22 @@ class TransformChain(Chain): """ input_variables: List[str] + """The keys expected by the transform's input dictionary.""" output_variables: List[str] + """The keys returned by the transform's output dictionary.""" transform: Callable[[Dict[str, str]], Dict[str, str]] + """The transform function.""" + atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None + """The async coroutine transform function.""" + + @staticmethod + @functools.lru_cache + def _log_once(msg: str) -> None: + """Log a message once. + + :meta private: + """ + logger.warning(msg) @property def input_keys(self) -> List[str]: @@ -42,3 +63,17 @@ class TransformChain(Chain): run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: return self.transform(inputs) + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + if self.atransform is not None: + return await self.atransform(inputs) + else: + self._log_once( + "TransformChain's atransform is not provided, falling" + " back to synchronous transform" + ) + return self.transform(inputs)