mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/mmr support for opensearch (#6349)
Co-authored-by: Mehmet Öner Yalçın <oneryalcin@gmail.com>
This commit is contained in:
parent
2eea5d4cb4
commit
a2bbe3dda4
@ -129,11 +129,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "db3fa309",
|
"id": "db3fa309",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
@ -144,11 +140,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "c160d5bb",
|
"id": "c160d5bb",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
@ -158,11 +150,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "96215c90",
|
"id": "96215c90",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
||||||
@ -183,11 +171,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "62a7cea0",
|
"id": "62a7cea0",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
@ -207,11 +191,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "0a8e3c0e",
|
"id": "0a8e3c0e",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
||||||
@ -230,11 +210,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "92bc40db",
|
"id": "92bc40db",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
@ -254,11 +230,7 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "6d9f436e",
|
"id": "6d9f436e",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
||||||
@ -278,16 +250,34 @@
|
|||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"id": "8ca50bce",
|
"id": "8ca50bce",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"print(docs[0].page_content)"
|
"print(docs[0].page_content)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Maximum marginal relevance search (MMR)\n",
|
||||||
|
"If you’d like to look up for some similar documents, but you’d also like to receive diverse results, MMR is method you should consider. Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents."
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||||
|
"docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10, lambda_param=0.5)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "73264864",
|
"id": "73264864",
|
||||||
|
@ -4,10 +4,13 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
import numpy as np
|
||||||
|
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import Document
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
IMPORT_OPENSEARCH_PY_ERROR = (
|
IMPORT_OPENSEARCH_PY_ERROR = (
|
||||||
"Could not import OpenSearch. Please install it with `pip install opensearch-py`."
|
"Could not import OpenSearch. Please install it with `pip install opensearch-py`."
|
||||||
@ -76,9 +79,12 @@ def _bulk_ingest_embeddings(
|
|||||||
metadatas: Optional[List[dict]] = None,
|
metadatas: Optional[List[dict]] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
text_field: str = "text",
|
text_field: str = "text",
|
||||||
mapping: Dict = {},
|
mapping: Optional[Dict] = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Bulk Ingest Embeddings into given index."""
|
"""Bulk Ingest Embeddings into given index."""
|
||||||
|
if not mapping:
|
||||||
|
mapping = dict()
|
||||||
|
|
||||||
bulk = _import_bulk()
|
bulk = _import_bulk()
|
||||||
not_found_error = _import_not_found_error()
|
not_found_error = _import_not_found_error()
|
||||||
requests = []
|
requests = []
|
||||||
@ -201,10 +207,14 @@ def _approximate_search_query_with_lucene_filter(
|
|||||||
def _default_script_query(
|
def _default_script_query(
|
||||||
query_vector: List[float],
|
query_vector: List[float],
|
||||||
space_type: str = "l2",
|
space_type: str = "l2",
|
||||||
pre_filter: Dict = MATCH_ALL_QUERY,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""For Script Scoring Search, this is the default query."""
|
"""For Script Scoring Search, this is the default query."""
|
||||||
|
|
||||||
|
if not pre_filter:
|
||||||
|
pre_filter = MATCH_ALL_QUERY
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"query": {
|
"query": {
|
||||||
"script_score": {
|
"script_score": {
|
||||||
@ -245,10 +255,14 @@ def __get_painless_scripting_source(
|
|||||||
def _default_painless_scripting_query(
|
def _default_painless_scripting_query(
|
||||||
query_vector: List[float],
|
query_vector: List[float],
|
||||||
space_type: str = "l2Squared",
|
space_type: str = "l2Squared",
|
||||||
pre_filter: Dict = MATCH_ALL_QUERY,
|
pre_filter: Optional[Dict] = None,
|
||||||
vector_field: str = "vector_field",
|
vector_field: str = "vector_field",
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""For Painless Scripting Search, this is the default query."""
|
"""For Painless Scripting Search, this is the default query."""
|
||||||
|
|
||||||
|
if not pre_filter:
|
||||||
|
pre_filter = MATCH_ALL_QUERY
|
||||||
|
|
||||||
source = __get_painless_scripting_source(space_type, query_vector)
|
source = __get_painless_scripting_source(space_type, query_vector)
|
||||||
return {
|
return {
|
||||||
"query": {
|
"query": {
|
||||||
@ -355,7 +369,7 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
) -> List[Document]:
|
) -> List[Document]:
|
||||||
"""Return docs most similar to query.
|
"""Return docs most similar to query.
|
||||||
|
|
||||||
By default supports Approximate Search.
|
By default, supports Approximate Search.
|
||||||
Also supports Script Scoring and Painless Scripting.
|
Also supports Script Scoring and Painless Scripting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -413,7 +427,7 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
) -> List[Tuple[Document, float]]:
|
) -> List[Tuple[Document, float]]:
|
||||||
"""Return docs and it's scores most similar to query.
|
"""Return docs and it's scores most similar to query.
|
||||||
|
|
||||||
By default supports Approximate Search.
|
By default, supports Approximate Search.
|
||||||
Also supports Script Scoring and Painless Scripting.
|
Also supports Script Scoring and Painless Scripting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -426,10 +440,47 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
Optional Args:
|
Optional Args:
|
||||||
same as `similarity_search`
|
same as `similarity_search`
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function.embed_query(query)
|
|
||||||
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
|
||||||
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||||
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
|
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
|
||||||
|
|
||||||
|
hits = self._raw_similarity_search_with_score(query=query, k=k, **kwargs)
|
||||||
|
|
||||||
|
documents_with_scores = [
|
||||||
|
(
|
||||||
|
Document(
|
||||||
|
page_content=hit["_source"][text_field],
|
||||||
|
metadata=hit["_source"]
|
||||||
|
if metadata_field == "*" or metadata_field not in hit["_source"]
|
||||||
|
else hit["_source"][metadata_field],
|
||||||
|
),
|
||||||
|
hit["_score"],
|
||||||
|
)
|
||||||
|
for hit in hits
|
||||||
|
]
|
||||||
|
return documents_with_scores
|
||||||
|
|
||||||
|
def _raw_similarity_search_with_score(
|
||||||
|
self, query: str, k: int = 4, **kwargs: Any
|
||||||
|
) -> List[dict]:
|
||||||
|
"""Return raw opensearch documents (dict) including vectors,
|
||||||
|
scores most similar to query.
|
||||||
|
|
||||||
|
By default, supports Approximate Search.
|
||||||
|
Also supports Script Scoring and Painless Scripting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dict with its scores most similar to the query.
|
||||||
|
|
||||||
|
Optional Args:
|
||||||
|
same as `similarity_search`
|
||||||
|
"""
|
||||||
|
embedding = self.embedding_function.embed_query(query)
|
||||||
|
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
||||||
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
||||||
|
|
||||||
if search_type == "approximate_search":
|
if search_type == "approximate_search":
|
||||||
@ -473,20 +524,59 @@ class OpenSearchVectorSearch(VectorStore):
|
|||||||
raise ValueError("Invalid `search_type` provided as an argument")
|
raise ValueError("Invalid `search_type` provided as an argument")
|
||||||
|
|
||||||
response = self.client.search(index=self.index_name, body=search_query)
|
response = self.client.search(index=self.index_name, body=search_query)
|
||||||
hits = [hit for hit in response["hits"]["hits"][:k]]
|
|
||||||
documents_with_scores = [
|
return [hit for hit in response["hits"]["hits"][:k]]
|
||||||
(
|
|
||||||
Document(
|
def max_marginal_relevance_search(
|
||||||
page_content=hit["_source"][text_field],
|
self,
|
||||||
metadata=hit["_source"]
|
query: str,
|
||||||
if metadata_field == "*" or metadata_field not in hit["_source"]
|
k: int = 4,
|
||||||
else hit["_source"][metadata_field],
|
fetch_k: int = 20,
|
||||||
),
|
lambda_mult: float = 0.5,
|
||||||
hit["_score"],
|
**kwargs: Any,
|
||||||
|
) -> list[Document]:
|
||||||
|
"""Return docs selected using the maximal marginal relevance.
|
||||||
|
|
||||||
|
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||||
|
among selected documents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Text to look up documents similar to.
|
||||||
|
k: Number of Documents to return. Defaults to 4.
|
||||||
|
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||||
|
Defaults to 20.
|
||||||
|
lambda_mult: Number between 0 and 1 that determines the degree
|
||||||
|
of diversity among the results with 0 corresponding
|
||||||
|
to maximum diversity and 1 to minimum diversity.
|
||||||
|
Defaults to 0.5.
|
||||||
|
Returns:
|
||||||
|
List of Documents selected by maximal marginal relevance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
||||||
|
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||||
|
metadata_field = _get_kwargs_value(kwargs, "metadata_field", "metadata")
|
||||||
|
|
||||||
|
# Get embedding of the user query
|
||||||
|
embedding = self.embedding_function.embed_query(query)
|
||||||
|
|
||||||
|
# Do ANN/KNN search to get top fetch_k results where fetch_k >= k
|
||||||
|
results = self._raw_similarity_search_with_score(query, fetch_k, **kwargs)
|
||||||
|
|
||||||
|
embeddings = [result["_source"][vector_field] for result in results]
|
||||||
|
|
||||||
|
# Rerank top k results using MMR, (mmr_selected is a list of indices)
|
||||||
|
mmr_selected = maximal_marginal_relevance(
|
||||||
|
np.array(embedding), embeddings, k=k, lambda_mult=lambda_mult
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
Document(
|
||||||
|
page_content=results[i]["_source"][text_field],
|
||||||
|
metadata=results[i]["_source"][metadata_field],
|
||||||
)
|
)
|
||||||
for hit in hits
|
for i in mmr_selected
|
||||||
]
|
]
|
||||||
return documents_with_scores
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_texts(
|
def from_texts(
|
||||||
|
Loading…
Reference in New Issue
Block a user