Change type annotations from LLMChain to Chain in MultiPromptChain (#11082)

- **Description:** The types of 'destination_chains' and 'default_chain'
in 'MultiPromptChain' were changed from 'LLMChain' to 'Chain'. and
removed variables declared overlapping with the parent class
- **Issue:** When a class that inherits only Chain and not LLMChain,
such as 'SequentialChain' or 'RetrievalQA', is entered in
'destination_chains' and 'default_chain', a pydantic validation error is
raised.
-  -  codes
```
retrieval_chain = ConversationalRetrievalChain(
        retriever=doc_retriever,
        combine_docs_chain=combine_docs_chain,
        question_generator=question_gen_chain,
    )
    
    destination_chains = {
        'retrieval': retrieval_chain,
    }
    
    main_chain = MultiPromptChain(
        router_chain=router_chain,
        destination_chains=destination_chains,
        default_chain=default_chain,
        verbose=True,
    )
```

 `make format`, `make lint` and `make test`
This commit is contained in:
Michael Kim 2023-09-29 07:59:25 +09:00 committed by GitHub
parent 8ed013d278
commit fbcd8e02f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,11 +1,12 @@
"""Use a single chain to route an input to one of multiple llm chains.""" """Use a single chain to route an input to one of multiple llm chains."""
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional from typing import Any, Dict, List, Optional
from langchain.chains import ConversationChain from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.router.base import MultiRouteChain, RouterChain from langchain.chains.router.base import MultiRouteChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
@ -15,13 +16,6 @@ from langchain.schema.language_model import BaseLanguageModel
class MultiPromptChain(MultiRouteChain): class MultiPromptChain(MultiRouteChain):
"""A multi-route chain that uses an LLM router chain to choose amongst prompts.""" """A multi-route chain that uses an LLM router chain to choose amongst prompts."""
router_chain: RouterChain
"""Chain for deciding a destination chain and the input to it."""
destination_chains: Mapping[str, LLMChain]
"""Map of name to candidate chains that inputs can be routed to."""
default_chain: LLMChain
"""Default chain to use when router doesn't map input to one of the destinations."""
@property @property
def output_keys(self) -> List[str]: def output_keys(self) -> List[str]:
return ["text"] return ["text"]
@ -31,7 +25,7 @@ class MultiPromptChain(MultiRouteChain):
cls, cls,
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]], prompt_infos: List[Dict[str, str]],
default_chain: Optional[LLMChain] = None, default_chain: Optional[Chain] = None,
**kwargs: Any, **kwargs: Any,
) -> MultiPromptChain: ) -> MultiPromptChain:
"""Convenience constructor for instantiating from destination prompts.""" """Convenience constructor for instantiating from destination prompts."""