Add async support for transform chain (#8205)

This commit is contained in:
William FH 2023-07-24 17:45:17 -07:00 committed by GitHub
parent 8f158b72fc
commit 3662aca7d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,16 @@
"""Chain that runs an arbitrary python function.""" """Chain that runs an arbitrary python function."""
from typing import Callable, Dict, List, Optional import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain from langchain.chains.base import Chain
logger = logging.getLogger(__name__)
class TransformChain(Chain): class TransformChain(Chain):
"""Chain transform chain output. """Chain transform chain output.
@ -17,8 +24,22 @@ class TransformChain(Chain):
""" """
input_variables: List[str] input_variables: List[str]
"""The keys expected by the transform's input dictionary."""
output_variables: List[str] output_variables: List[str]
"""The keys returned by the transform's output dictionary."""
transform: Callable[[Dict[str, str]], Dict[str, str]] transform: Callable[[Dict[str, str]], Dict[str, str]]
"""The transform function."""
atransform: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = None
"""The async coroutine transform function."""
@staticmethod
@functools.lru_cache
def _log_once(msg: str) -> None:
"""Log a message once.
:meta private:
"""
logger.warning(msg)
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
@ -42,3 +63,17 @@ class TransformChain(Chain):
run_manager: Optional[CallbackManagerForChainRun] = None, run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]: ) -> Dict[str, str]:
return self.transform(inputs) return self.transform(inputs)
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
if self.atransform is not None:
return await self.atransform(inputs)
else:
self._log_once(
"TransformChain's atransform is not provided, falling"
" back to synchronous transform"
)
return self.transform(inputs)