Use passed LLM for default chain in MultiPromptChain (#4418)

Currently, MultiPromptChain instantiates a ChatOpenAI LLM for the default
chain, which is used when none of the passed prompts match. This appears
to be a bug: it means you can't use the LLM of your choice, or configure
how the default LLM is instantiated (e.g. passing an API key that isn't
set in the usual environment variable). Fix by reusing the `llm` argument
passed to `from_prompts` for the default `ConversationChain`.
jrhe authored 2023-05-10 00:15:25 +01:00, committed by GitHub
parent 5c8e12558d
commit 28091c2101


@@ -9,7 +9,6 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.router.base import MultiRouteChain
 from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser
 from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE
-from langchain.chat_models import ChatOpenAI
 from langchain.prompts import PromptTemplate
@@ -54,9 +53,7 @@ class MultiPromptChain(MultiRouteChain):
         prompt = PromptTemplate(template=prompt_template, input_variables=["input"])
         chain = LLMChain(llm=llm, prompt=prompt)
         destination_chains[name] = chain
-        _default_chain = default_chain or ConversationChain(
-            llm=ChatOpenAI(), output_key="text"
-        )
+        _default_chain = default_chain or ConversationChain(llm=llm, output_key="text")
         return cls(
             router_chain=router_chain,
             destination_chains=destination_chains,