add search kwargs (#664)

pull/665/head
Harrison Chase 2 years ago committed by GitHub
parent 65f3a341b0
commit 983b73f47c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
"""Question-answering with sources over a vector database."""
from typing import Any, Dict, List
from pydantic import BaseModel
from pydantic import BaseModel, Field
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
from langchain.docstore.document import Document
@ -15,8 +15,8 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
"""Vector Database to connect to."""
k: int = 4
"""Number of results to return from store"""
search_kwargs: Dict[str, Any] = {}
"""Extra search args"""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
question = inputs[self.question_key]

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -39,6 +39,8 @@ class VectorDBQA(Chain, BaseModel):
output_key: str = "result" #: :meta private:
return_source_documents: bool = False
"""Return the source documents."""
search_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Extra search args."""
class Config:
"""Configuration for this pydantic object."""
@ -127,7 +129,9 @@ class VectorDBQA(Chain, BaseModel):
"""
question = inputs[self.input_key]
docs = self.vectorstore.similarity_search(question, k=self.k)
docs = self.vectorstore.similarity_search(
question, k=self.k, **self.search_kwargs
)
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
if self.return_source_documents:

@ -26,7 +26,9 @@ class VectorStore(ABC):
"""
@abstractmethod
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
def max_marginal_relevance_search(

@ -106,7 +106,9 @@ class ElasticVectorSearch(VectorStore):
self.client.indices.refresh(index=self.index_name)
return ids
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query.
Args:

@ -103,7 +103,9 @@ class FAISS(VectorStore):
docs.append((doc, scores[0][j]))
return docs
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query.
Args:

@ -120,6 +120,7 @@ class Pinecone(VectorStore):
k: int = 5,
filter: Optional[dict] = None,
namespace: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return pinecone documents most similar to query.

@ -71,7 +71,9 @@ class Weaviate(VectorStore):
ids.append(_id)
return ids
def similarity_search(self, query: str, k: int = 4) -> List[Document]:
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Look up similar documents in weaviate."""
content = {"concepts": [query]}
query_obj = self._client.query.get(self._index_name, self._query_attrs)

Loading…
Cancel
Save