From b3916f74a754cb7712818986197aa59b229ae05b Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 30 Jan 2023 14:48:24 -0800 Subject: [PATCH] enable mmr search (#807) --- langchain/chains/vector_db_qa/base.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py index 73bd4c57..3a30a67f 100644 --- a/langchain/chains/vector_db_qa/base.py +++ b/langchain/chains/vector_db_qa/base.py @@ -41,6 +41,8 @@ class VectorDBQA(Chain, BaseModel): """Return the source documents.""" search_kwargs: Dict[str, Any] = Field(default_factory=dict) """Extra search args.""" + search_type: str = "similarity" + """Search type to use over vectorstore. `similarity` or `mmr`.""" class Config: """Configuration for this pydantic object.""" @@ -90,6 +92,15 @@ class VectorDBQA(Chain, BaseModel): values["combine_documents_chain"] = combine_documents_chain return values + @root_validator() + def validate_search_type(cls, values: Dict) -> Dict: + """Validate search type.""" + if "search_type" in values: + search_type = values["search_type"] + if search_type not in ("similarity", "mmr"): + raise ValueError(f"search_type of {search_type} not allowed.") + return values + @classmethod def from_llm( cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any @@ -129,9 +140,16 @@ class VectorDBQA(Chain, BaseModel): """ question = inputs[self.input_key] - docs = self.vectorstore.similarity_search( - question, k=self.k, **self.search_kwargs - ) + if self.search_type == "similarity": + docs = self.vectorstore.similarity_search( + question, k=self.k, **self.search_kwargs + ) + elif self.search_type == "mmr": + docs = self.vectorstore.max_marginal_relevance_search( + question, k=self.k, **self.search_kwargs + ) + else: + raise ValueError(f"search_type of {self.search_type} not allowed.") answer, _ = self.combine_documents_chain.combine_docs(docs, question=question) if self.return_source_documents: