diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index 76789e6c..f9431baf 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -81,13 +81,14 @@ class ChatVectorDBChain(Chain, BaseModel): def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: question = inputs["question"] chat_history_str = _get_chat_history(inputs["chat_history"]) + vectordbkwargs = inputs.get("vectordbkwargs", {}) if chat_history_str: new_question = self.question_generator.run( question=question, chat_history=chat_history_str ) else: new_question = question - docs = self.vectorstore.similarity_search(new_question, k=4) + docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str @@ -100,6 +101,7 @@ class ChatVectorDBChain(Chain, BaseModel): async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: question = inputs["question"] chat_history_str = _get_chat_history(inputs["chat_history"]) + vectordbkwargs = inputs.get("vectordbkwargs", {}) if chat_history_str: new_question = await self.question_generator.arun( question=question, chat_history=chat_history_str @@ -107,7 +109,7 @@ class ChatVectorDBChain(Chain, BaseModel): else: new_question = question # TODO: This blocks the event loop, but it's not clear how to avoid it. - docs = self.vectorstore.similarity_search(new_question, k=4) + docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str