From e455ba4ed5c683fb9903fb19c980ae25010c91bb Mon Sep 17 00:00:00 2001 From: Louis Amaudruz Date: Mon, 29 May 2023 15:37:26 +0200 Subject: [PATCH] Add async support to routing chains (#5373) # Add async support for (LLM) routing chains Add asynchronous LLM calls support for the routing chains. More specifically: - Add async `aroute` function (i.e. async version of `route`) to the `RouterChain` which calls the routing LLM asynchronously - Implement the async `_acall` for the `LLMRouterChain` - Implement the async `_acall` function for `MultiRouteChain` which first calls asynchronously the routing chain with its new `aroute` function, and then calls asynchronously the relevant destination chain. ## Who can review? - @agola11 --- langchain/chains/router/base.py | 41 ++++++++++++++++++++++++++- langchain/chains/router/llm_router.py | 18 +++++++++++- 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/langchain/chains/router/base.py b/langchain/chains/router/base.py index 9cb44d51..fe948a56 100644 --- a/langchain/chains/router/base.py +++ b/langchain/chains/router/base.py @@ -6,7 +6,11 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional from pydantic import Extra -from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, + Callbacks, +) from langchain.chains.base import Chain @@ -26,6 +30,12 @@ class RouterChain(Chain, ABC): result = self(inputs, callbacks=callbacks) return Route(result["destination"], result["next_inputs"]) + async def aroute( + self, inputs: Dict[str, Any], callbacks: Callbacks = None + ) -> Route: + result = await self.acall(inputs, callbacks=callbacks) + return Route(result["destination"], result["next_inputs"]) + class MultiRouteChain(Chain): """Use a single chain to route an input to one of multiple candidate chains.""" @@ -86,3 +96,32 @@ class MultiRouteChain(Chain): raise ValueError( f"Received invalid destination chain name '{route.destination}'" ) + + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + route = await self.router_chain.aroute(inputs, callbacks=callbacks) + + _run_manager.on_text( + str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose + ) + if not route.destination: + return await self.default_chain.acall( + route.next_inputs, callbacks=callbacks + ) + elif route.destination in self.destination_chains: + return await self.destination_chains[route.destination].acall( + route.next_inputs, callbacks=callbacks + ) + elif self.silent_errors: + return await self.default_chain.acall( + route.next_inputs, callbacks=callbacks + ) + else: + raise ValueError( + f"Received invalid destination chain name '{route.destination}'" + ) diff --git a/langchain/chains/router/llm_router.py b/langchain/chains/router/llm_router.py index 3276324f..cf8392c1 100644 --- a/langchain/chains/router/llm_router.py +++ b/langchain/chains/router/llm_router.py @@ -6,7 +6,10 @@ from typing import Any, Dict, List, Optional, Type, cast from pydantic import root_validator from langchain.base_language import BaseLanguageModel -from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.callbacks.manager import ( + AsyncCallbackManagerForChainRun, + CallbackManagerForChainRun, +) from langchain.chains import LLMChain from langchain.chains.router.base import RouterChain from langchain.output_parsers.json import parse_and_check_json_markdown @@ -58,6 +61,19 @@ class LLMRouterChain(RouterChain): ) return output + async def _acall( + self, + inputs: Dict[str, Any], + run_manager: Optional[AsyncCallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + output = cast( + Dict[str, Any], + await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs), + ) + return output + @classmethod def from_llm( cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any