mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
[ OpenSearch ] : Add AOSS Support to OpenSearch (#8256)
### Description This PR includes the following changes: - Adds AOSS (Amazon OpenSearch Service Serverless) support to OpenSearch. Please refer to the documentation on how to use it. - While creating an index, AOSS only supports Approximate Search with `nmslib` and `faiss` engines. During Search, only Approximate Search and Script Scoring (on doc values) are supported. - This PR also adds support to `efficient_filter` which can be used with `faiss` and `lucene` engines. - The `lucene_filter` is deprecated. Instead please use the `efficient_filter` for the lucene engine. Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
This commit is contained in:
parent
7a00f17033
commit
9cbefcc56c
@ -315,6 +315,101 @@
|
||||
" metadata_field=\"message_metadata\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Using AOSS (Amazon OpenSearch Service Serverless)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is just an example to show how to use AOSS with faiss engine and efficient_filter, you need to set proper values.\n",
|
||||
"\n",
|
||||
"service = 'aoss' # must set the service as 'aoss'\n",
|
||||
"region = 'us-east-2'\n",
|
||||
"credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
|
||||
"awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
|
||||
"\n",
|
||||
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" opensearch_url=\"host url\",\n",
|
||||
" http_auth=awsauth,\n",
|
||||
" timeout = 300,\n",
|
||||
" use_ssl = True,\n",
|
||||
" verify_certs = True,\n",
|
||||
" connection_class = RequestsHttpConnection,\n",
|
||||
" index_name=\"test-index-using-aoss\",\n",
|
||||
" engine=\"faiss\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"docs = docsearch.similarity_search(\n",
|
||||
" \"What is feature selection\",\n",
|
||||
" efficient_filter=filter,\n",
|
||||
" k=200,\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## Using AOS (Amazon OpenSearch Service)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is just an example to show how to use AOS , you need to set proper values.\n",
|
||||
"\n",
|
||||
"service = 'es' # must set the service as 'es'\n",
|
||||
"region = 'us-east-2'\n",
|
||||
"credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
|
||||
"awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
|
||||
"\n",
|
||||
"docsearch = OpenSearchVectorSearch.from_documents(\n",
|
||||
" docs,\n",
|
||||
" embeddings,\n",
|
||||
" opensearch_url=\"host url\",\n",
|
||||
" http_auth=awsauth,\n",
|
||||
" timeout = 300,\n",
|
||||
" use_ssl = True,\n",
|
||||
" verify_certs = True,\n",
|
||||
" connection_class = RequestsHttpConnection,\n",
|
||||
" index_name=\"test-index\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"docs = docsearch.similarity_search(\n",
|
||||
" \"What is feature selection\",\n",
|
||||
" k=200,\n",
|
||||
")"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -71,6 +72,26 @@ def _validate_embeddings_and_bulk_size(embeddings_length: int, bulk_size: int) -
|
||||
)
|
||||
|
||||
|
||||
def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None:
|
||||
"""Validate AOSS with the engine."""
|
||||
if is_aoss and engine != "nmslib" and engine != "faiss":
|
||||
raise ValueError(
|
||||
"Amazon OpenSearch Service Serverless only "
|
||||
"supports `nmslib` or `faiss` engines"
|
||||
)
|
||||
|
||||
|
||||
def _is_aoss_enabled(http_auth: Any) -> bool:
|
||||
"""Check if the service is http_auth is set as `aoss`."""
|
||||
if (
|
||||
http_auth is not None
|
||||
and http_auth.service is not None
|
||||
and http_auth.service == "aoss"
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _bulk_ingest_embeddings(
|
||||
client: Any,
|
||||
index_name: str,
|
||||
@ -82,6 +103,7 @@ def _bulk_ingest_embeddings(
|
||||
text_field: str = "text",
|
||||
mapping: Optional[Dict] = None,
|
||||
max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
|
||||
is_aoss: bool = False,
|
||||
) -> List[str]:
|
||||
"""Bulk Ingest Embeddings into given index."""
|
||||
if not mapping:
|
||||
@ -107,11 +129,15 @@ def _bulk_ingest_embeddings(
|
||||
vector_field: embeddings[i],
|
||||
text_field: text,
|
||||
"metadata": metadata,
|
||||
"_id": _id,
|
||||
}
|
||||
if is_aoss:
|
||||
request["id"] = _id
|
||||
else:
|
||||
request["_id"] = _id
|
||||
requests.append(request)
|
||||
return_ids.append(_id)
|
||||
bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
|
||||
if not is_aoss:
|
||||
client.indices.refresh(index=index_name)
|
||||
return return_ids
|
||||
|
||||
@ -192,17 +218,18 @@ def _approximate_search_query_with_boolean_filter(
|
||||
}
|
||||
|
||||
|
||||
def _approximate_search_query_with_lucene_filter(
|
||||
def _approximate_search_query_with_efficient_filter(
|
||||
query_vector: List[float],
|
||||
lucene_filter: Dict,
|
||||
efficient_filter: Dict,
|
||||
k: int = 4,
|
||||
vector_field: str = "vector_field",
|
||||
) -> Dict:
|
||||
"""For Approximate k-NN Search, with Lucene Filter."""
|
||||
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
|
||||
Faiss Engines."""
|
||||
search_query = _default_approximate_search_query(
|
||||
query_vector, k=k, vector_field=vector_field
|
||||
)
|
||||
search_query["query"]["knn"][vector_field]["filter"] = lucene_filter
|
||||
search_query["query"]["knn"][vector_field]["filter"] = efficient_filter
|
||||
return search_query
|
||||
|
||||
|
||||
@ -309,11 +336,13 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
opensearch_url: str,
|
||||
index_name: str,
|
||||
embedding_function: Embeddings,
|
||||
is_aoss: bool,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding_function = embedding_function
|
||||
self.index_name = index_name
|
||||
self.is_aoss = is_aoss
|
||||
self.client = _get_opensearch_client(opensearch_url, **kwargs)
|
||||
|
||||
@property
|
||||
@ -358,6 +387,8 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
||||
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
|
||||
|
||||
_validate_aoss_with_engines(self.is_aoss, engine)
|
||||
|
||||
mapping = _default_text_mapping(
|
||||
dim, engine, space_type, ef_search, ef_construction, m, vector_field
|
||||
)
|
||||
@ -373,6 +404,7 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
text_field=text_field,
|
||||
mapping=mapping,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
is_aoss=self.is_aoss,
|
||||
)
|
||||
|
||||
def similarity_search(
|
||||
@ -404,14 +436,18 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
Optional Args for Approximate Search:
|
||||
search_type: "approximate_search"; default: "approximate_search"
|
||||
|
||||
boolean_filter: A Boolean filter consists of a Boolean query that
|
||||
contains a k-NN query and a filter.
|
||||
boolean_filter: A Boolean filter is a post filter consists of a Boolean
|
||||
query that contains a k-NN query and a filter.
|
||||
|
||||
subquery_clause: Query clause on the knn vector field; default: "must"
|
||||
|
||||
lucene_filter: the Lucene algorithm decides whether to perform an exact
|
||||
k-NN search with pre-filtering or an approximate search with modified
|
||||
post-filtering.
|
||||
post-filtering. (deprecated, use `efficient_filter`)
|
||||
|
||||
efficient_filter: the Lucene Engine or Faiss Engine decides whether to
|
||||
perform an exact k-NN search with pre-filtering or an approximate search
|
||||
with modified post-filtering.
|
||||
|
||||
Optional Args for Script Scoring Search:
|
||||
search_type: "script_scoring"; default: "approximate_search"
|
||||
@ -494,15 +530,41 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
|
||||
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
||||
|
||||
if (
|
||||
self.is_aoss
|
||||
and search_type != "approximate_search"
|
||||
and search_type != SCRIPT_SCORING_SEARCH
|
||||
):
|
||||
raise ValueError(
|
||||
"Amazon OpenSearch Service Serverless only "
|
||||
"supports `approximate_search` and `script_scoring`"
|
||||
)
|
||||
|
||||
if search_type == "approximate_search":
|
||||
boolean_filter = _get_kwargs_value(kwargs, "boolean_filter", {})
|
||||
subquery_clause = _get_kwargs_value(kwargs, "subquery_clause", "must")
|
||||
efficient_filter = _get_kwargs_value(kwargs, "efficient_filter", {})
|
||||
# `lucene_filter` is deprecated, added for Backwards Compatibility
|
||||
lucene_filter = _get_kwargs_value(kwargs, "lucene_filter", {})
|
||||
if boolean_filter != {} and lucene_filter != {}:
|
||||
|
||||
if boolean_filter != {} and efficient_filter != {}:
|
||||
raise ValueError(
|
||||
"Both `boolean_filter` and `lucene_filter` are provided which "
|
||||
"Both `boolean_filter` and `efficient_filter` are provided which "
|
||||
"is invalid"
|
||||
)
|
||||
|
||||
if lucene_filter != {} and efficient_filter != {}:
|
||||
raise ValueError(
|
||||
"Both `lucene_filter` and `efficient_filter` are provided which "
|
||||
"is invalid. `lucene_filter` is deprecated"
|
||||
)
|
||||
|
||||
if lucene_filter != {} and boolean_filter != {}:
|
||||
raise ValueError(
|
||||
"Both `lucene_filter` and `boolean_filter` are provided which "
|
||||
"is invalid. `lucene_filter` is deprecated"
|
||||
)
|
||||
|
||||
if boolean_filter != {}:
|
||||
search_query = _approximate_search_query_with_boolean_filter(
|
||||
embedding,
|
||||
@ -511,8 +573,16 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
vector_field=vector_field,
|
||||
subquery_clause=subquery_clause,
|
||||
)
|
||||
elif efficient_filter != {}:
|
||||
search_query = _approximate_search_query_with_efficient_filter(
|
||||
embedding, efficient_filter, k=k, vector_field=vector_field
|
||||
)
|
||||
elif lucene_filter != {}:
|
||||
search_query = _approximate_search_query_with_lucene_filter(
|
||||
warnings.warn(
|
||||
"`lucene_filter` is deprecated. Please use the keyword argument"
|
||||
" `efficient_filter`"
|
||||
)
|
||||
search_query = _approximate_search_query_with_efficient_filter(
|
||||
embedding, lucene_filter, k=k, vector_field=vector_field
|
||||
)
|
||||
else:
|
||||
@ -659,6 +729,7 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
"ef_construction",
|
||||
"m",
|
||||
"max_chunk_bytes",
|
||||
"is_aoss",
|
||||
]
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
|
||||
@ -672,6 +743,15 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
|
||||
text_field = _get_kwargs_value(kwargs, "text_field", "text")
|
||||
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
|
||||
http_auth = _get_kwargs_value(kwargs, "http_auth", None)
|
||||
is_aoss = _is_aoss_enabled(http_auth=http_auth)
|
||||
|
||||
if is_aoss and not is_appx_search:
|
||||
raise ValueError(
|
||||
"Amazon OpenSearch Service Serverless only "
|
||||
"supports `approximate_search`"
|
||||
)
|
||||
|
||||
if is_appx_search:
|
||||
engine = _get_kwargs_value(kwargs, "engine", "nmslib")
|
||||
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
|
||||
@ -679,6 +759,8 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
|
||||
m = _get_kwargs_value(kwargs, "m", 16)
|
||||
|
||||
_validate_aoss_with_engines(is_aoss, engine)
|
||||
|
||||
mapping = _default_text_mapping(
|
||||
dim, engine, space_type, ef_search, ef_construction, m, vector_field
|
||||
)
|
||||
@ -697,5 +779,6 @@ class OpenSearchVectorSearch(VectorStore):
|
||||
text_field=text_field,
|
||||
mapping=mapping,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
is_aoss=is_aoss,
|
||||
)
|
||||
return cls(opensearch_url, index_name, embedding, **kwargs)
|
||||
return cls(opensearch_url, index_name, embedding, is_aoss, **kwargs)
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Test OpenSearch functionality."""
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from opensearchpy import AWSV4SignerAuth
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.opensearch_vector_search import (
|
||||
@ -213,3 +215,95 @@ def test_opensearch_with_custom_field_name_appx_false() -> None:
|
||||
)
|
||||
output = docsearch.similarity_search("add", k=1)
|
||||
assert output == [Document(page_content="add")]
|
||||
|
||||
|
||||
def test_opensearch_serverless_with_scripting_search_indexing_throws_error() -> None:
|
||||
"""Test to validate indexing using Serverless without Approximate Search."""
|
||||
region = "test-region"
|
||||
service = "aoss"
|
||||
credentials = boto3.Session().get_credentials()
|
||||
auth = AWSV4SignerAuth(credentials, region, service)
|
||||
with pytest.raises(ValueError):
|
||||
OpenSearchVectorSearch.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
opensearch_url=DEFAULT_OPENSEARCH_URL,
|
||||
is_appx_search=False,
|
||||
http_auth=auth,
|
||||
)
|
||||
|
||||
|
||||
def test_opensearch_serverless_with_lucene_engine_throws_error() -> None:
|
||||
"""Test to validate indexing using lucene engine with Serverless."""
|
||||
region = "test-region"
|
||||
service = "aoss"
|
||||
credentials = boto3.Session().get_credentials()
|
||||
auth = AWSV4SignerAuth(credentials, region, service)
|
||||
with pytest.raises(ValueError):
|
||||
OpenSearchVectorSearch.from_texts(
|
||||
texts,
|
||||
FakeEmbeddings(),
|
||||
opensearch_url=DEFAULT_OPENSEARCH_URL,
|
||||
engine="lucene",
|
||||
http_auth=auth,
|
||||
)
|
||||
|
||||
|
||||
def test_appx_search_with_efficient_and_bool_filter_throws_error() -> None:
|
||||
"""Test Approximate Search with Efficient and Bool Filter throws Error."""
|
||||
efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
|
||||
boolean_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
docsearch.similarity_search(
|
||||
"foo",
|
||||
k=3,
|
||||
efficient_filter=efficient_filter_val,
|
||||
boolean_filter=boolean_filter_val,
|
||||
)
|
||||
|
||||
|
||||
def test_appx_search_with_efficient_and_lucene_filter_throws_error() -> None:
|
||||
"""Test Approximate Search with Efficient and Lucene Filter throws Error."""
|
||||
efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
|
||||
lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
docsearch.similarity_search(
|
||||
"foo",
|
||||
k=3,
|
||||
efficient_filter=efficient_filter_val,
|
||||
lucene_filter=lucene_filter_val,
|
||||
)
|
||||
|
||||
|
||||
def test_appx_search_with_boolean_and_lucene_filter_throws_error() -> None:
|
||||
"""Test Approximate Search with Boolean and Lucene Filter throws Error."""
|
||||
boolean_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
|
||||
lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
docsearch.similarity_search(
|
||||
"foo",
|
||||
k=3,
|
||||
boolean_filter=boolean_filter_val,
|
||||
lucene_filter=lucene_filter_val,
|
||||
)
|
||||
|
||||
|
||||
def test_appx_search_with_faiss_efficient_filter() -> None:
|
||||
"""Test Approximate Search with Faiss Efficient Filter."""
|
||||
efficient_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
|
||||
docsearch = OpenSearchVectorSearch.from_texts(
|
||||
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="faiss"
|
||||
)
|
||||
output = docsearch.similarity_search(
|
||||
"foo", k=3, efficient_filter=efficient_filter_val
|
||||
)
|
||||
assert output == [Document(page_content="bar")]
|
||||
|
Loading…
Reference in New Issue
Block a user