Change type annotations from LLMChain to Chain in MultiPromptChain (#11082)

- **Description:** The types of 'destination_chains' and 'default_chain'
in 'MultiPromptChain' were changed from 'LLMChain' to 'Chain'. and
removed variables declared overlapping with the parent class
- **Issue:** When a class that inherits only Chain and not LLMChain,
such as 'SequentialChain' or 'RetrievalQA', is entered in
'destination_chains' and 'default_chain', a pydantic validation error is
raised.
-  -  codes
```
retrieval_chain = ConversationalRetrievalChain(
        retriever=doc_retriever,
        combine_docs_chain=combine_docs_chain,
        question_generator=question_gen_chain,
    )
    
    destination_chains = {
        'retrieval': retrieval_chain,
    }
    
    main_chain = MultiPromptChain(
        router_chain=router_chain,
        destination_chains=destination_chains,
        default_chain=default_chain,
        verbose=True,
    )
```

 `make format`, `make lint` and `make test`
pull/11202/head
Michael Kim 11 months ago committed by GitHub
parent 8ed013d278
commit fbcd8e02f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,11 +1,12 @@
"""Use a single chain to route an input to one of multiple llm chains."""
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Optional
from langchain.chains import ConversationChain
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.router.base import MultiRouteChain, RouterChain
from langchain.chains.router.base import MultiRouteChain
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
from langchain.prompts import PromptTemplate
@ -15,13 +16,6 @@ from langchain.schema.language_model import BaseLanguageModel
class MultiPromptChain(MultiRouteChain):
"""A multi-route chain that uses an LLM router chain to choose amongst prompts."""
router_chain: RouterChain
"""Chain for deciding a destination chain and the input to it."""
destination_chains: Mapping[str, LLMChain]
"""Map of name to candidate chains that inputs can be routed to."""
default_chain: LLMChain
"""Default chain to use when router doesn't map input to one of the destinations."""
@property
def output_keys(self) -> List[str]:
return ["text"]
@ -31,7 +25,7 @@ class MultiPromptChain(MultiRouteChain):
cls,
llm: BaseLanguageModel,
prompt_infos: List[Dict[str, str]],
default_chain: Optional[LLMChain] = None,
default_chain: Optional[Chain] = None,
**kwargs: Any,
) -> MultiPromptChain:
"""Convenience constructor for instantiating from destination prompts."""

Loading…
Cancel
Save