diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py index f9431baf..c2bbd4df 100644 --- a/langchain/chains/chat_vector_db/base.py +++ b/langchain/chains/chat_vector_db/base.py @@ -32,6 +32,7 @@ class ChatVectorDBChain(Chain, BaseModel): question_generator: LLMChain output_key: str = "answer" return_source_documents: bool = False + top_k_docs_for_context: int = 4 """Return the source documents.""" @property @@ -88,7 +89,9 @@ class ChatVectorDBChain(Chain, BaseModel): ) else: new_question = question - docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs) + docs = self.vectorstore.similarity_search( + new_question, k=self.top_k_docs_for_context, **vectordbkwargs + ) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str @@ -109,7 +112,9 @@ class ChatVectorDBChain(Chain, BaseModel): else: new_question = question # TODO: This blocks the event loop, but it's not clear how to avoid it. - docs = self.vectorstore.similarity_search(new_question, k=4, **vectordbkwargs) + docs = self.vectorstore.similarity_search( + new_question, k=self.top_k_docs_for_context, **vectordbkwargs + ) new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str