mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
208 lines
6.6 KiB
Python
208 lines
6.6 KiB
Python
|
from enum import Enum
|
||
|
from typing import Any, Dict, List, Optional, Union
|
||
|
|
||
|
import numpy as np
|
||
|
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||
|
from langchain_core.documents import Document
|
||
|
from langchain_core.embeddings import Embeddings
|
||
|
from langchain_core.retrievers import BaseRetriever
|
||
|
|
||
|
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||
|
|
||
|
|
||
|
class SearchType(str, Enum):
    """Search strategies understood by the retriever.

    The enum mixes in ``str``, so each member compares equal to its plain
    string value (e.g. ``SearchType.mmr == "mmr"``).
    """

    # Plain nearest-neighbour search by embedding similarity.
    similarity = "similarity"
    # Maximal marginal relevance re-ranking of the candidates.
    mmr = "mmr"
|
||
|
|
||
|
|
||
|
class DocArrayRetriever(BaseRetriever):
    """`DocArray Document Indices` retriever.

    Currently, it supports 5 backends:
    InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
    ElasticDocIndex, and WeaviateDocumentIndex.

    Args:
        index: One of the above-mentioned index instances
        embeddings: Embedding model to represent text as vectors
        search_field: Field to consider for searching in the documents.
            Should be an embedding/vector/tensor.
        content_field: Field that represents the main content in your document
            schema. Will be used as a `page_content`. Everything else will go
            into `metadata`.
        search_type: Type of search to perform (similarity / mmr)
        filters: Filters applied for document retrieval.
        top_k: Number of documents to return
    """

    index: Any
    embeddings: Embeddings
    search_field: str
    content_field: str
    search_type: SearchType = SearchType.similarity
    top_k: int = 1
    filters: Optional[Any] = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant for a query.

        The query text is embedded with ``self.embeddings`` and dispatched to
        the search strategy selected by ``self.search_type``.

        Args:
            query: string to find relevant documents for
            run_manager: callback manager for this retriever run (not used
                by this implementation)

        Returns:
            List of relevant documents

        Raises:
            ValueError: If ``search_type`` is neither similarity nor mmr
        """
        query_emb = np.array(self.embeddings.embed_query(query))

        if self.search_type == SearchType.similarity:
            return self._similarity_search(query_emb)
        if self.search_type == SearchType.mmr:
            return self._mmr_search(query_emb)

        raise ValueError(
            f"Search type {self.search_type} does not exist. "
            "Choose either 'similarity' or 'mmr'."
        )

    def _search(
        self, query_emb: np.ndarray, top_k: int
    ) -> List[Union[Dict[str, Any], Any]]:
        """
        Perform a search using the query embedding and return top_k documents.

        When ``self.filters`` is set, a combined vector + filter query is built
        and executed; otherwise a plain vector ``find`` is issued.

        Args:
            query_emb: Query represented as an embedding
            top_k: Number of documents to return

        Returns:
            A list of top_k documents matching the query
        """
        # Imported lazily so docarray is only required when a search runs.
        from docarray.index import ElasticDocIndex, WeaviateDocumentIndex

        # Each backend names its filter keyword differently.
        search_field = self.search_field
        if isinstance(self.index, WeaviateDocumentIndex):
            filter_key = "where_filter"
            # NOTE(review): Weaviate is searched with an empty search field;
            # presumably that index does not accept a field name - confirm
            # against the docarray documentation.
            search_field = ""
        elif isinstance(self.index, ElasticDocIndex):
            filter_key = "query"
        else:
            filter_key = "filter_query"

        if not self.filters:
            # No filter configured: a plain vector search is enough.
            return self.index.find(
                query=query_emb, search_field=search_field, limit=top_k
            ).documents

        query = (
            self.index.build_query()  # get empty query object
            .find(query=query_emb, search_field=search_field)  # vector search
            .filter(**{filter_key: self.filters})  # add filter search
            .build(limit=top_k)  # build the query
        )
        # execute the combined query and return the results
        docs = self.index.execute_query(query)
        # The result may be a wrapper exposing `.documents` or the documents
        # themselves; unwrap when needed.
        if hasattr(docs, "documents"):
            docs = docs.documents
        # Enforce the limit here as well, regardless of what the query returned.
        return docs[:top_k]

    def _similarity_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a similarity search.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of documents most similar to the query
        """
        docs = self._search(query_emb=query_emb, top_k=self.top_k)
        return [self._docarray_to_langchain_doc(doc) for doc in docs]

    def _mmr_search(self, query_emb: np.ndarray) -> List[Document]:
        """
        Perform a maximal marginal relevance (mmr) search.

        A pool of candidates is fetched by similarity first, then
        ``maximal_marginal_relevance`` picks ``top_k`` of them.

        Args:
            query_emb: Query represented as an embedding

        Returns:
            A list of diverse documents related to the query
        """
        # Fetch at least 20 candidates, but never fewer than top_k: with a
        # fixed pool of 20, a top_k above 20 could never be satisfied.
        fetch_k = max(self.top_k, 20)
        docs = self._search(query_emb=query_emb, top_k=fetch_k)

        candidate_embeddings = [
            doc[self.search_field]
            if isinstance(doc, dict)
            else getattr(doc, self.search_field)
            for doc in docs
        ]
        # The helper returns positions into the candidate list.
        mmr_selected = maximal_marginal_relevance(
            query_emb,
            candidate_embeddings,
            k=self.top_k,
        )
        return [self._docarray_to_langchain_doc(docs[idx]) for idx in mmr_selected]

    def _docarray_to_langchain_doc(self, doc: Union[Dict[str, Any], Any]) -> Document:
        """
        Convert a DocArray document (which also might be a dict)
        to a langchain document format.

        DocArray document can contain arbitrary fields, so the mapping is done
        in the following way:

        page_content <-> content_field
        metadata <-> all other fields excluding
            tensors and embeddings (so float, int, string)

        Args:
            doc: DocArray document

        Returns:
            Document in langchain format

        Raises:
            ValueError: If the document doesn't contain the content field
        """
        # Results may be plain dicts or document objects; decide once how
        # fields are listed and read.
        is_mapping = isinstance(doc, dict)
        fields = doc.keys() if is_mapping else doc.__fields__

        def read(name: str) -> Any:
            return doc[name] if is_mapping else getattr(doc, name)

        if self.content_field not in fields:
            raise ValueError(
                f"Document does not contain the content field - {self.content_field}."
            )

        lc_doc = Document(page_content=read(self.content_field))

        for name in fields:
            value = read(name)
            # Only scalar values become metadata; tensors/embeddings and the
            # content field itself are left out.
            if (
                isinstance(value, (str, int, float, bool))
                and name != self.content_field
            ):
                lc_doc.metadata[name] = value

        return lc_doc
|