Add async support to routing chains (#5373)

# Add async support for (LLM) routing chains

<!--
Thank you for contributing to LangChain! Your PR will appear in our
release under the title you set. Please make sure it highlights your
valuable contribution.

Replace this with a description of the change, the issue it fixes (if
applicable), and relevant context. List any dependencies required for
this change.

After you're done, someone will review your PR. They may suggest
improvements. If no one reviews your PR within a few days, feel free to
@-mention the same people again, as notifications can get lost.
-->

<!-- Remove if not applicable -->

Add asynchronous LLM calls support for the routing chains. More
specifically:
- Add async `aroute` function (i.e. async version of `route`) to the
`RouterChain` which calls the routing LLM asynchronously
- Implement the async `_acall` for the `LLMRouterChain`
- Implement the async `_acall` function for `MultiRouteChain` which
first calls asynchronously the routing chain with its new `aroute`
function, and then calls asynchronously the relevant destination chain.

<!-- If you're adding a new integration, please include:

1. a test for the integration - favor unit tests that does not rely on
network access.
2. an example notebook showing its use


See contribution guidelines for more information on how to write tests,
lint
etc:


https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
-->

## Who can review?

- @agola11

<!-- For a quicker response, figure out the right person to tag with @

  @hwchase17 - project lead
  Async
  - @agola11
        
 -->
This commit is contained in:
Louis Amaudruz 2023-05-29 15:37:26 +02:00 committed by GitHub
parent 8b7721ebbb
commit e455ba4ed5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 2 deletions

View File

@ -6,7 +6,11 @@ from typing import Any, Dict, List, Mapping, NamedTuple, Optional
from pydantic import Extra
from langchain.callbacks.manager import CallbackManagerForChainRun, Callbacks
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
Callbacks,
)
from langchain.chains.base import Chain
@ -26,6 +30,12 @@ class RouterChain(Chain, ABC):
result = self(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
async def aroute(
self, inputs: Dict[str, Any], callbacks: Callbacks = None
) -> Route:
result = await self.acall(inputs, callbacks=callbacks)
return Route(result["destination"], result["next_inputs"])
class MultiRouteChain(Chain):
"""Use a single chain to route an input to one of multiple candidate chains."""
@ -86,3 +96,32 @@ class MultiRouteChain(Chain):
raise ValueError(
f"Received invalid destination chain name '{route.destination}'"
)
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
route = await self.router_chain.aroute(inputs, callbacks=callbacks)
_run_manager.on_text(
str(route.destination) + ": " + str(route.next_inputs), verbose=self.verbose
)
if not route.destination:
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
elif route.destination in self.destination_chains:
return await self.destination_chains[route.destination].acall(
route.next_inputs, callbacks=callbacks
)
elif self.silent_errors:
return await self.default_chain.acall(
route.next_inputs, callbacks=callbacks
)
else:
raise ValueError(
f"Received invalid destination chain name '{route.destination}'"
)

View File

@ -6,7 +6,10 @@ from typing import Any, Dict, List, Optional, Type, cast
from pydantic import root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain
from langchain.output_parsers.json import parse_and_check_json_markdown
@ -58,6 +61,19 @@ class LLMRouterChain(RouterChain):
)
return output
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
output = cast(
Dict[str, Any],
await self.llm_chain.apredict_and_parse(callbacks=callbacks, **inputs),
)
return output
@classmethod
def from_llm(
cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any