# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 204 lines
# 6.5 KiB

from enum import Enum
from typing import Any, Dict, List, Optional, Union
import numpy as np
from pydantic import BaseModel
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance
class SearchType(str, Enum):
    """Search strategies a retriever can be configured with.

    Members are also ``str`` instances, so they compare equal to their plain
    string values (``SearchType.mmr == "mmr"``).
    """

    # Return the documents ranked closest to the query embedding.
    similarity = "similarity"
    # Maximal marginal relevance: trade relevance off against diversity.
    mmr = "mmr"
class DocArrayRetriever(BaseRetriever, BaseModel):
    """Retriever class for DocArray Document Indices.

    Currently, supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Attributes:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document
            schema. Will be used as a `page_content`. Everything else will go
            into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_relevant_documents(self, query: str) -> List[Document]:
        """Get documents relevant for a query.

        Args:
            query: string to find relevant documents for

        Returns:
            List of relevant documents

        Raises:
            ValueError: If ``search_type`` is neither similarity nor mmr.
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            return self._similarity_search(query_emb)
        if self.search_type == SearchType.mmr:
            return self._mmr_search(query_emb)

        # Only reached for an unknown search type.
        raise ValueError(
            f"Search type {self.search_type} does not exist. "
            f"Choose either 'similarity' or 'mmr'."
        )

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """Perform a search using the query embedding and return top_k documents.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """
        # Imported lazily so docarray is only required when a search runs.
        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        # Each backend expects the filter under a different keyword argument.
        filter_args: Dict[str, Any] = {}
        search_field = self.search_field
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_args["where_filter"] = self.filters
            # NOTE(review): presumably Weaviate takes the vector field from its
            # own schema, hence the empty search field - confirm.
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_args["query"] = self.filters
        else:
            filter_args["filter_query"] = self.filters

        if self.filters:
            query = (
                self.index.build_query()  # get empty query object
                .find(
                    query=query_emb, search_field=search_field
                )  # add vector similarity search
                .filter(**filter_args)  # add filter search
                .build(limit=top_k)  # build the query
            )
            # execute the combined query and return the results
            docs = self.index.execute_query(query)
            if hasattr(docs, "documents"):
                docs = docs.documents
            # Trim explicitly in case the backend returned more than the limit.
            docs = docs[:top_k]
        else:
            docs = self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            )
            # Unwrap a result wrapper the same way as the filtered path does.
            if hasattr(docs, "documents"):
                docs = docs.documents
        return docs

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        return [self._docarray_to_langchain_doc(doc) for doc in docs]

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """Perform a maximal marginal relevance (mmr) search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        # Fetch a pool of candidates for MMR to choose from. The pool is never
        # smaller than top_k, otherwise a top_k above 20 could not be satisfied.
        fetch_k = max(self.top_k, 20)
        docs = self._search(query_emb=query_emb, top_k=fetch_k)

        candidate_embs = [
            doc[self.search_field]
            if isinstance(doc, dict)
            else getattr(doc, self.search_field)
            for doc in docs
        ]
        mmr_selected = maximal_marginal_relevance(
            query_emb,
            candidate_embs,
            k=self.top_k,
        )
        # mmr_selected holds indices into docs.
        return [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        is_mapping = isinstance(doc, dict)
        fields = doc.keys() if is_mapping else doc.__fields__

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        def read(name: str) -> Any:
            # Dict results are indexed by key, document objects by attribute.
            return doc[name] if is_mapping else getattr(doc, name)

        lc_doc = Document(page_content=read(self.content_field))

        for name in fields:
            if name == self.content_field:
                continue
            value = read(name)
            # Only plain scalar values are copied into the metadata.
            if isinstance(value, (str, int, float, bool)):
                lc_doc.metadata[name] = value

        return lc_doc

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Asynchronous retrieval is not supported.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError