From fbcd8e02f25259abec03c4ff1008547d44854f77 Mon Sep 17 00:00:00 2001 From: Michael Kim <59414764+xcellentbird@users.noreply.github.com> Date: Fri, 29 Sep 2023 07:59:25 +0900 Subject: [PATCH] Change type annotations from LLMChain to Chain in MultiPromptChain (#11082) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** The types of 'destination_chains' and 'default_chain' in 'MultiPromptChain' were changed from 'LLMChain' to 'Chain'. and removed variables declared overlapping with the parent class - **Issue:** When a class that inherits only Chain and not LLMChain, such as 'SequentialChain' or 'RetrievalQA', is entered in 'destination_chains' and 'default_chain', a pydantic validation error is raised. - - codes ``` retrieval_chain = ConversationalRetrievalChain( retriever=doc_retriever, combine_docs_chain=combine_docs_chain, question_generator=question_gen_chain, ) destination_chains = { 'retrieval': retrieval_chain, } main_chain = MultiPromptChain( router_chain=router_chain, destination_chains=destination_chains, default_chain=default_chain, verbose=True, ) ``` ✅ `make format`, `make lint` and `make test` --- .../langchain/chains/router/multi_prompt.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/libs/langchain/langchain/chains/router/multi_prompt.py b/libs/langchain/langchain/chains/router/multi_prompt.py index 240d9018c8..f4031b968f 100644 --- a/libs/langchain/langchain/chains/router/multi_prompt.py +++ b/libs/langchain/langchain/chains/router/multi_prompt.py @@ -1,11 +1,12 @@ """Use a single chain to route an input to one of multiple llm chains.""" from __future__ import annotations -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Optional from langchain.chains import ConversationChain +from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.router.base import MultiRouteChain, RouterChain +from langchain.chains.router.base import MultiRouteChain from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE from langchain.prompts import PromptTemplate @@ -15,13 +16,6 @@ from langchain.schema.language_model import BaseLanguageModel class MultiPromptChain(MultiRouteChain): """A multi-route chain that uses an LLM router chain to choose amongst prompts.""" - router_chain: RouterChain - """Chain for deciding a destination chain and the input to it.""" - destination_chains: Mapping[str, LLMChain] - """Map of name to candidate chains that inputs can be routed to.""" - default_chain: LLMChain - """Default chain to use when router doesn't map input to one of the destinations.""" - @property def output_keys(self) -> List[str]: return ["text"] @@ -31,7 +25,7 @@ class MultiPromptChain(MultiRouteChain): cls, llm: BaseLanguageModel, prompt_infos: List[Dict[str, str]], - default_chain: Optional[LLMChain] = None, + default_chain: Optional[Chain] = None, **kwargs: Any, ) -> MultiPromptChain: """Convenience constructor for instantiating from destination prompts."""