Harrison/mmr support for opensearch (#6349)

Co-authored-by: Mehmet Öner Yalçın <oneryalcin@gmail.com>
This commit is contained in:
Harrison Chase 2023-06-17 12:22:37 -07:00 committed by GitHub
parent 2eea5d4cb4
commit a2bbe3dda4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 60 deletions

View File

@ -129,11 +129,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "db3fa309", "id": "db3fa309",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n",
@ -144,11 +140,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "c160d5bb", "id": "c160d5bb",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(docs[0].page_content)" "print(docs[0].page_content)"
@ -158,11 +150,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "96215c90", "id": "96215c90",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"docsearch = OpenSearchVectorSearch.from_documents(\n", "docsearch = OpenSearchVectorSearch.from_documents(\n",
@ -183,11 +171,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "62a7cea0", "id": "62a7cea0",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(docs[0].page_content)" "print(docs[0].page_content)"
@ -207,11 +191,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "0a8e3c0e", "id": "0a8e3c0e",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"docsearch = OpenSearchVectorSearch.from_documents(\n", "docsearch = OpenSearchVectorSearch.from_documents(\n",
@ -230,11 +210,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "92bc40db", "id": "92bc40db",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(docs[0].page_content)" "print(docs[0].page_content)"
@ -254,11 +230,7 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "6d9f436e", "id": "6d9f436e",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"docsearch = OpenSearchVectorSearch.from_documents(\n", "docsearch = OpenSearchVectorSearch.from_documents(\n",
@ -278,16 +250,34 @@
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "8ca50bce", "id": "8ca50bce",
"metadata": { "metadata": {},
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(docs[0].page_content)" "print(docs[0].page_content)"
] ]
}, },
{
"cell_type": "markdown",
"source": [
"### Maximum marginal relevance search (MMR)\n",
"If youd like to look up for some similar documents, but youd also like to receive diverse results, MMR is method you should consider. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10, lambda_param=0.5)"
],
"metadata": {
"collapsed": false
}
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "73264864", "id": "73264864",

View File

@ -4,10 +4,13 @@ from __future__ import annotations
import uuid import uuid
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
from langchain.docstore.document import Document import numpy as np
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
IMPORT_OPENSEARCH_PY_ERROR = ( IMPORT_OPENSEARCH_PY_ERROR = (
"Could not import OpenSearch. Please install it with `pip install opensearch-py`." "Could not import OpenSearch. Please install it with `pip install opensearch-py`."
@ -76,9 +79,12 @@ def _bulk_ingest_embeddings(
metadatas: Optional[List[dict]] = None, metadatas: Optional[List[dict]] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
text_field: str = "text", text_field: str = "text",
mapping: Dict = {}, mapping: Optional[Dict] = None,
) -> List[str]: ) -> List[str]:
"""Bulk Ingest Embeddings into given index.""" """Bulk Ingest Embeddings into given index."""
if not mapping:
mapping = dict()
bulk = _import_bulk() bulk = _import_bulk()
not_found_error = _import_not_found_error() not_found_error = _import_not_found_error()
requests = [] requests = []
@ -201,10 +207,14 @@ def _approximate_search_query_with_lucene_filter(
def _default_script_query( def _default_script_query(
query_vector: List[float], query_vector: List[float],
space_type: str = "l2", space_type: str = "l2",
pre_filter: Dict = MATCH_ALL_QUERY, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
) -> Dict: ) -> Dict:
"""For Script Scoring Search, this is the default query.""" """For Script Scoring Search, this is the default query."""
if not pre_filter:
pre_filter = MATCH_ALL_QUERY
return { return {
"query": { "query": {
"script_score": { "script_score": {
@ -245,10 +255,14 @@ def __get_painless_scripting_source(
def _default_painless_scripting_query( def _default_painless_scripting_query(
query_vector: List[float], query_vector: List[float],
space_type: str = "l2Squared", space_type: str = "l2Squared",
pre_filter: Dict = MATCH_ALL_QUERY, pre_filter: Optional[Dict] = None,
vector_field: str = "vector_field", vector_field: str = "vector_field",
) -> Dict: ) -> Dict:
"""For Painless Scripting Search, this is the default query.""" """For Painless Scripting Search, this is the default query."""
if not pre_filter:
pre_filter = MATCH_ALL_QUERY
source = __get_painless_scripting_source(space_type, query_vector) source = __get_painless_scripting_source(space_type, query_vector)
return { return {
"query": { "query": {
@ -355,7 +369,7 @@ class OpenSearchVectorSearch(VectorStore):
) -> List[Document]: ) -> List[Document]:
"""Return docs most similar to query. """Return docs most similar to query.
By default supports Approximate Search. By default, supports Approximate Search.
Also supports Script Scoring and Painless Scripting. Also supports Script Scoring and Painless Scripting.
Args: Args:
@ -413,7 +427,7 @@ class OpenSearchVectorSearch(VectorStore):
) -> List[Tuple[Document, float]]: ) -> List[Tuple[Document, float]]:
"""Return docs and it's scores most similar to query. """Return docs and it's scores most similar to query.
By default supports Approximate Search. By default, supports Approximate Search.
Also supports Script Scoring and Painless Scripting. Also supports Script Scoring and Painless Scripting.
Args: Args:
@ -426,10 +440,47 @@ class OpenSearchVectorSearch(VectorStore):
Optional Args: Optional Args:
same as `similarity_search` same as `similarity_search`
""" """
embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
text_field = _get_kwargs_value(kwargs, "text_field", "text") text_field = _get_kwargs_value(kwargs, "text_field", "text")
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata") metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
documents_with_scores = [
(
Document(
page_content=hit["_source"][text_field],
metadata=hit["_source"]
if metadata_field == "*" or metadata_field not in hit["_source"]
else hit["_source"][metadata_field],
),
hit["_score"],
)
for hit in hits
]
return documents_with_scores
def _raw_similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
) -> List[dict]:
"""Return raw opensearch documents (dict) including vectors,
scores most similar to query.
By default, supports Approximate Search.
Also supports Script Scoring and Painless Scripting.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of dict with its scores most similar to the query.
Optional Args:
same as `similarity_search`
"""
embedding = self.embedding_function.embed_query(query)
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
if search_type == "approximate_search": if search_type == "approximate_search":
@ -473,20 +524,59 @@ class OpenSearchVectorSearch(VectorStore):
raise ValueError("Invalid `search_type` provided as an argument") raise ValueError("Invalid `search_type` provided as an argument")
response = self.client.search(index=self.index_name, body=search_query) response = self.client.search(index=self.index_name, body=search_query)
hits = [hit for hit in response["hits"]["hits"][:k]]
documents_with_scores = [ return [hit for hit in response["hits"]["hits"][:k]]
(
Document( def max_marginal_relevance_search(
page_content=hit["_source"][text_field], self,
metadata=hit["_source"] query: str,
if metadata_field == "*" or metadata_field not in hit["_source"] k: int = 4,
else hit["_source"][metadata_field], fetch_k: int = 20,
), lambda_mult: float = 0.5,
hit["_score"], **kwargs: Any,
) -> list[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
# Get embedding of the user query
embedding = self.embedding_function.embed_query(query)
# Do ANN/KNN search to get top fetch_k results where fetch_k >= k
results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
embeddings = [result["_source"][vector_field] for result in results]
# Rerank top k results using MMR, (mmr_selected is a list of indices)
mmr_selected = maximal_marginal_relevance(
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
)
return [
Document(
page_content=results[i]["_source"][text_field],
metadata=results[i]["_source"][metadata_field],
) )
for hit in hits for i in mmr_selected
] ]
return documents_with_scores
@classmethod @classmethod
def from_texts( def from_texts(