From 2cbd41145c9f00128bac1ac5b5624487b9a82aed Mon Sep 17 00:00:00 2001
From: Davis Chase <130488702+dev2049@users.noreply.github.com>
Date: Mon, 24 Apr 2023 12:13:06 -0700
Subject: [PATCH] Bugfix: Not all combine docs chains takes kwargs `prompt` (#3462)

Generalize ConversationalRetrievalChain.from_llm kwargs

---------

Co-authored-by: shubham.suneja
---
 langchain/chains/conversational_retrieval/base.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py
index b7fb299e..900d8e7c 100644
--- a/langchain/chains/conversational_retrieval/base.py
+++ b/langchain/chains/conversational_retrieval/base.py
@@ -172,15 +172,16 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
         llm: BaseLanguageModel,
         retriever: BaseRetriever,
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: Optional[BasePromptTemplate] = None,
         chain_type: str = "stuff",
+        combine_docs_chain_kwargs: Optional[Dict] = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
+        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
         doc_chain = load_qa_chain(
             llm,
             chain_type=chain_type,
-            prompt=qa_prompt,
+            **combine_docs_chain_kwargs,
         )
         condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
         return cls(
@@ -226,15 +227,16 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
         llm: BaseLanguageModel,
         vectorstore: VectorStore,
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: Optional[BasePromptTemplate] = None,
         chain_type: str = "stuff",
+        combine_docs_chain_kwargs: Optional[Dict] = None,
         **kwargs: Any,
     ) -> BaseConversationalRetrievalChain:
         """Load chain from LLM."""
+        combine_docs_chain_kwargs = combine_docs_chain_kwargs or {}
        doc_chain = load_qa_chain(
             llm,
             chain_type=chain_type,
-            prompt=qa_prompt,
+            **combine_docs_chain_kwargs,
         )
         condense_question_chain = LLMChain(llm=llm, prompt=condense_question_prompt)
         return cls(
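
Illustrative usage sketch (not part of the patch): the change above replaces the `qa_prompt` argument with a `combine_docs_chain_kwargs` dict that is forwarded verbatim to `load_qa_chain`, so prompt customization works for chain types other than "stuff". The snippet below assumes an already-constructed vectorstore and an OpenAI-backed chat model; those objects, the prompt text, and the map_reduce kwarg names are placeholders, not dictated by this patch.

# Sketch of calling the new API; llm, vectorstore, and prompt contents are placeholders.
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate

llm = ChatOpenAI(temperature=0)
retriever = vectorstore.as_retriever()  # assumes an existing vectorstore

# Old call: from_llm(llm, retriever, qa_prompt=...) -- always forwarded as `prompt`,
# which only the "stuff" combine-docs chain accepts.
# New call: pass whatever kwargs the chosen chain_type's loader understands.
qa_prompt = PromptTemplate(
    input_variables=["context", "question"],
    template="Use the context below to answer.\n\n{context}\n\nQuestion: {question}",
)
stuff_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever,
    chain_type="stuff",
    combine_docs_chain_kwargs={"prompt": qa_prompt},
)

# A map_reduce combine-docs chain takes different prompt kwargs (e.g. question_prompt
# and combine_prompt), which the old single qa_prompt argument could not express.
map_reduce_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever,
    chain_type="map_reduce",
    combine_docs_chain_kwargs={},  # or e.g. {"question_prompt": ..., "combine_prompt": ...}
)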