mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Change type annotations from LLMChain to Chain in MultiPromptChain (#11082)
- **Description:** The types of 'destination_chains' and 'default_chain'
in 'MultiPromptChain' were changed from 'LLMChain' to 'Chain'. and
removed variables declared overlapping with the parent class
- **Issue:** When a class that inherits only Chain and not LLMChain,
such as 'SequentialChain' or 'RetrievalQA', is entered in
'destination_chains' and 'default_chain', a pydantic validation error is
raised.
- - codes
```
retrieval_chain = ConversationalRetrievalChain(
retriever=doc_retriever,
combine_docs_chain=combine_docs_chain,
question_generator=question_gen_chain,
)
destination_chains = {
'retrieval': retrieval_chain,
}
main_chain = MultiPromptChain(
router_chain=router_chain,
destination_chains=destination_chains,
default_chain=default_chain,
verbose=True,
)
```
✅ `make format`, `make lint` and `make test`
This commit is contained in:
parent
8ed013d278
commit
fbcd8e02f2
@ -1,11 +1,12 @@
|
|||||||
"""Use a single chain to route an input to one of multiple llm chains."""
|
"""Use a single chain to route an input to one of multiple llm chains."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from langchain.chains import ConversationChain
|
from langchain.chains import ConversationChain
|
||||||
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.chains.router.base import MultiRouteChain, RouterChain
|
from langchain.chains.router.base import MultiRouteChain
|
||||||
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
|
from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
|
||||||
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
|
from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
@ -15,13 +16,6 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
class MultiPromptChain(MultiRouteChain):
|
class MultiPromptChain(MultiRouteChain):
|
||||||
"""A multi-route chain that uses an LLM router chain to choose amongst prompts."""
|
"""A multi-route chain that uses an LLM router chain to choose amongst prompts."""
|
||||||
|
|
||||||
router_chain: RouterChain
|
|
||||||
"""Chain for deciding a destination chain and the input to it."""
|
|
||||||
destination_chains: Mapping[str, LLMChain]
|
|
||||||
"""Map of name to candidate chains that inputs can be routed to."""
|
|
||||||
default_chain: LLMChain
|
|
||||||
"""Default chain to use when router doesn't map input to one of the destinations."""
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def output_keys(self) -> List[str]:
|
def output_keys(self) -> List[str]:
|
||||||
return ["text"]
|
return ["text"]
|
||||||
@ -31,7 +25,7 @@ class MultiPromptChain(MultiRouteChain):
|
|||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
prompt_infos: List[Dict[str, str]],
|
prompt_infos: List[Dict[str, str]],
|
||||||
default_chain: Optional[LLMChain] = None,
|
default_chain: Optional[Chain] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> MultiPromptChain:
|
) -> MultiPromptChain:
|
||||||
"""Convenience constructor for instantiating from destination prompts."""
|
"""Convenience constructor for instantiating from destination prompts."""
|
||||||
|
Loading…
Reference in New Issue
Block a user