forked from Archives/langchain
Add async support to routing chains (#5373)
# Add async support for (LLM) routing chains <!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. --> <!-- Remove if not applicable --> Add asynchronous LLM calls support for the routing chains. More specifically: - Add async `aroute` function (i.e. async version of `route`) to the `RouterChain` which calls the routing LLM asynchronously - Implement the async `_acall` for the `LLMRouterChain` - Implement the async `_acall` function for `MultiRouteChain` which first calls asynchronously the routing chain with its new `aroute` function, and then calls asynchronously the relevant destination chain. <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> ## Who can review? - @agola11 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Async - @agola11 -->
This commit is contained in:
parent
8b7721ebbb
commit
e455ba4ed5
@ -6,7 +6,11 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
@ -26,6 +30,12 @@ class RouterChain(Chain, ABC):
|
||||
result = self(inputs, callbacks=callbacks)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
async def aroute(
|
||||
self, inputs: Dict[str, Any], callbacks: Callbacks = None
|
||||
) -> Route:
|
||||
result = await self.acall(inputs, callbacks=callbacks)
|
||||
return Route(result["destination"], result["next_inputs"])
|
||||
|
||||
|
||||
class MultiRouteChain(Chain):
|
||||
"""Use a single chain to route an input to one of multiple candidate chains."""
|
||||
@ -86,3 +96,32 @@ class MultiRouteChain(Chain):
|
||||
raise ValueError(
|
||||
f"Received invalid destination chain name '{route.destination}'"
|
||||
)
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
route = await self.router_chain.aroute(inputs, callbacks=callbacks)
|
||||
|
||||
_run_manager.on_text(
|
||||
str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
|
||||
)
|
||||
if not route.destination:
|
||||
return await self.default_chain.acall(
|
||||
route.next_inputs, callbacks=callbacks
|
||||
)
|
||||
elif route.destination in self.destination_chains:
|
||||
return await self.destination_chains[route.destination].acall(
|
||||
route.next_inputs, callbacks=callbacks
|
||||
)
|
||||
elif self.silent_errors:
|
||||
return await self.default_chain.acall(
|
||||
route.next_inputs, callbacks=callbacks
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Received invalid destination chain name '{route.destination}'"
|
||||
)
|
||||
|
@ -6,7 +6,10 @@ from typing import Any, Dict, List, Optional, Type, cast
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.chains.router.base import RouterChain
|
||||
from langchain.output_parsers.json import parse_and_check_json_markdown
|
||||
@ -58,6 +61,19 @@ class LLMRouterChain(RouterChain):
|
||||
)
|
||||
return output
|
||||
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: Dict[str, Any],
|
||||
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
|
||||
) -> Dict[str, Any]:
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
output = cast(
|
||||
Dict[str, Any],
|
||||
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
|
||||
)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
|
||||
|
Loading…
Reference in New Issue
Block a user