From a2bbe3dda4f03d02fdd9f87d413169cba9a4d131 Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sat, 17 Jun 2023 12:22:37 -0700
Subject: [PATCH] Harrison/mmr support for opensearch (#6349)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Mehmet Öner Yalçın
---
 .../integrations/opensearch.ipynb             |  70 ++++------
 .../vectorstores/opensearch_vector_search.py  | 130 +++++++++++++++---
 2 files changed, 140 insertions(+), 60 deletions(-)

diff --git a/docs/extras/modules/data_connection/vectorstores/integrations/opensearch.ipynb b/docs/extras/modules/data_connection/vectorstores/integrations/opensearch.ipynb
index 654d9453fd..ee9fa2760e 100644
--- a/docs/extras/modules/data_connection/vectorstores/integrations/opensearch.ipynb
+++ b/docs/extras/modules/data_connection/vectorstores/integrations/opensearch.ipynb
@@ -129,11 +129,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "db3fa309",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "query = \"What did the president say about Ketanji Brown Jackson\"\n",
@@ -144,11 +140,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "c160d5bb",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "print(docs[0].page_content)"
@@ -158,11 +150,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "96215c90",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -183,11 +171,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "62a7cea0",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "print(docs[0].page_content)"
@@ -207,11 +191,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "0a8e3c0e",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -230,11 +210,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "92bc40db",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "print(docs[0].page_content)"
@@ -254,11 +230,7 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "6d9f436e",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "docsearch = OpenSearchVectorSearch.from_documents(\n",
@@ -278,16 +250,34 @@
    "cell_type": "code",
    "execution_count": null,
    "id": "8ca50bce",
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "print(docs[0].page_content)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### Maximum marginal relevance search (MMR)\n",
+    "If you’d like to look up similar documents while also receiving diverse results, MMR is the method to consider. Maximal marginal relevance optimizes for similarity to the query AND diversity among the selected documents."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
+    "docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10, lambda_mult=0.5)"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
   {
    "cell_type": "markdown",
    "id": "73264864",
diff --git a/langchain/vectorstores/opensearch_vector_search.py b/langchain/vectorstores/opensearch_vector_search.py
index d33a39429a..dd7b36245d 100644
--- a/langchain/vectorstores/opensearch_vector_search.py
+++ b/langchain/vectorstores/opensearch_vector_search.py
@@ -4,10 +4,13 @@ from __future__ import annotations
 import uuid
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
-from langchain.docstore.document import Document
+import numpy as np
+
 from langchain.embeddings.base import Embeddings
+from langchain.schema import Document
 from langchain.utils import get_from_dict_or_env
 from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.utils import maximal_marginal_relevance
 
 IMPORT_OPENSEARCH_PY_ERROR = (
     "Could not import OpenSearch. Please install it with `pip install opensearch-py`."
@@ -76,9 +79,12 @@ def _bulk_ingest_embeddings(
     metadatas: Optional[List[dict]] = None,
     vector_field: str = "vector_field",
     text_field: str = "text",
-    mapping: Dict = {},
+    mapping: Optional[Dict] = None,
 ) -> List[str]:
     """Bulk Ingest Embeddings into given index."""
+    if not mapping:
+        mapping = dict()
+
     bulk = _import_bulk()
     not_found_error = _import_not_found_error()
     requests = []
@@ -201,10 +207,14 @@ def _approximate_search_query_with_lucene_filter(
 def _default_script_query(
     query_vector: List[float],
     space_type: str = "l2",
-    pre_filter: Dict = MATCH_ALL_QUERY,
+    pre_filter: Optional[Dict] = None,
     vector_field: str = "vector_field",
 ) -> Dict:
     """For Script Scoring Search, this is the default query."""
+
+    if not pre_filter:
+        pre_filter = MATCH_ALL_QUERY
+
     return {
         "query": {
             "script_score": {
@@ -245,10 +255,14 @@ def __get_painless_scripting_source(
 def _default_painless_scripting_query(
     query_vector: List[float],
     space_type: str = "l2Squared",
-    pre_filter: Dict = MATCH_ALL_QUERY,
+    pre_filter: Optional[Dict] = None,
     vector_field: str = "vector_field",
 ) -> Dict:
     """For Painless Scripting Search, this is the default query."""
+
+    if not pre_filter:
+        pre_filter = MATCH_ALL_QUERY
+
     source = __get_painless_scripting_source(space_type, query_vector)
     return {
         "query": {
@@ -355,7 +369,7 @@ class OpenSearchVectorSearch(VectorStore):
     ) -> List[Document]:
         """Return docs most similar to query.
 
-        By default supports Approximate Search.
+        By default, supports Approximate Search.
         Also supports Script Scoring and Painless Scripting.
 
         Args:
@@ -413,7 +427,7 @@ class OpenSearchVectorSearch(VectorStore):
     ) -> List[Tuple[Document, float]]:
         """Return docs and it's scores most similar to query.
 
-        By default supports Approximate Search.
+        By default, supports Approximate Search.
         Also supports Script Scoring and Painless Scripting.

         Args:
@@ -426,10 +440,47 @@ class OpenSearchVectorSearch(VectorStore):
         Optional Args:
             same as `similarity_search`
         """
-        embedding = self.embedding_function.embed_query(query)
-        search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
+
         text_field = _get_kwargs_value(kwargs, "text_field", "text")
         metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
+
+        hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
+
+        documents_with_scores = [
+            (
+                Document(
+                    page_content=hit["_source"][text_field],
+                    metadata=hit["_source"]
+                    if metadata_field == "*" or metadata_field not in hit["_source"]
+                    else hit["_source"][metadata_field],
+                ),
+                hit["_score"],
+            )
+            for hit in hits
+        ]
+        return documents_with_scores
+
+    def _raw_similarity_search_with_score(
+        self, query: str, k: int = 4, **kwargs: Any
+    ) -> List[dict]:
+        """Return raw opensearch documents (dict) including vectors,
+        scores most similar to query.
+
+        By default, supports Approximate Search.
+        Also supports Script Scoring and Painless Scripting.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+
+        Returns:
+            List of dict with its scores most similar to the query.
+
+        Optional Args:
+            same as `similarity_search`
+        """
+        embedding = self.embedding_function.embed_query(query)
+        search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
         vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
 
         if search_type == "approximate_search":
@@ -473,20 +524,59 @@ class OpenSearchVectorSearch(VectorStore):
             raise ValueError("Invalid `search_type` provided as an argument")
 
         response = self.client.search(index=self.index_name, body=search_query)
-        hits = [hit for hit in response["hits"]["hits"][:k]]
-        documents_with_scores = [
-            (
-                Document(
-                    page_content=hit["_source"][text_field],
-                    metadata=hit["_source"]
-                    if metadata_field == "*" or metadata_field not in hit["_source"]
-                    else hit["_source"][metadata_field],
-                ),
-                hit["_score"],
+
+        return [hit for hit in response["hits"]["hits"][:k]]
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        lambda_mult: float = 0.5,
+        **kwargs: Any,
+    ) -> list[Document]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                Defaults to 20.
+            lambda_mult: Number between 0 and 1 that determines the degree
+                of diversity among the results with 0 corresponding
+                to maximum diversity and 1 to minimum diversity.
+                Defaults to 0.5.
+        Returns:
+            List of Documents selected by maximal marginal relevance.
+        """
+
+        vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
+        text_field = _get_kwargs_value(kwargs, "text_field", "text")
+        metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
+
+        # Get embedding of the user query
+        embedding = self.embedding_function.embed_query(query)
+
+        # Do ANN/KNN search to get top fetch_k results where fetch_k >= k
+        results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
+
+        embeddings = [result["_source"][vector_field] for result in results]
+
+        # Rerank top k results using MMR, (mmr_selected is a list of indices)
+        mmr_selected = maximal_marginal_relevance(
+            np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
+        )
+
+        return [
+            Document(
+                page_content=results[i]["_source"][text_field],
+                metadata=results[i]["_source"][metadata_field],
             )
-            for hit in hits
+            for i in mmr_selected
         ]
-        return documents_with_scores
 
     @classmethod
     def from_texts(
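Note for readers of this patch: the new max_marginal_relevance_search method only fetches fetch_k raw hits (with their stored vectors) and then defers the re-ranking to langchain.vectorstores.utils.maximal_marginal_relevance. The snippet below is a simplified, self-contained sketch of what that MMR selection step does, included here purely to make the k / fetch_k / lambda_mult trade-off concrete. It is not code from this patch or from the library: the helper names cosine_similarity and mmr_select, and the random demo vectors, are illustrative assumptions.

from typing import List

import numpy as np


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    a_norm = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_norm = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_norm @ b_norm.T


def mmr_select(
    query_embedding: np.ndarray,       # shape (dim,), e.g. embed_query(query)
    candidate_embeddings: np.ndarray,  # shape (fetch_k, dim), the ANN hit vectors
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Pick up to k candidate indices by maximal marginal relevance."""
    if len(candidate_embeddings) == 0:
        return []
    # Similarity of every candidate to the query.
    query_sim = cosine_similarity(query_embedding.reshape(1, -1), candidate_embeddings)[0]
    # Seed the result set with the single most query-similar candidate.
    selected = [int(np.argmax(query_sim))]
    while len(selected) < min(k, len(candidate_embeddings)):
        # Similarity of every candidate to the documents already selected.
        selected_sim = cosine_similarity(candidate_embeddings, candidate_embeddings[selected])
        best_idx, best_score = -1, -np.inf
        for i in range(len(candidate_embeddings)):
            if i in selected:
                continue
            # High lambda_mult favours relevance to the query,
            # low lambda_mult favours diversity among the picked documents.
            score = lambda_mult * query_sim[i] - (1 - lambda_mult) * selected_sim[i].max()
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    query = rng.normal(size=8)             # stands in for the query embedding
    candidates = rng.normal(size=(10, 8))  # stands in for the fetch_k ANN hit vectors
    print(mmr_select(query, candidates, k=4, lambda_mult=0.5))

In this reading, lambda_mult=1.0 reduces the re-ranking to plain similarity ordering (minimum diversity), while lambda_mult=0.0 maximizes diversity among the k returned documents, which matches the docstring added above.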