[ OpenSearch ] : Add AOSS Support to OpenSearch (#8256)

### Description

This PR includes the following changes:

- Adds AOSS (Amazon OpenSearch Service Serverless) support to
OpenSearch. Please refer to the documentation on how to use it.
- While creating an index, AOSS only supports Approximate Search with
`nmslib` and `faiss` engines. During Search, only Approximate Search and
Script Scoring (on doc values) are supported.
- This PR also adds support to `efficient_filter` which can be used with
`faiss` and `lucene` engines.
- The `lucene_filter` is deprecated. Instead please use the
`efficient_filter` for the lucene engine.


Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
pull/8278/head
Naveen Tatikonda 1 year ago committed by GitHub
parent 7a00f17033
commit 9cbefcc56c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -315,6 +315,101 @@
" metadata_field=\"message_metadata\",\n",
")"
]
},
{
"cell_type": "markdown",
"source": [
"## Using AOSS (Amazon OpenSearch Service Serverless)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# This is just an example to show how to use AOSS with faiss engine and efficient_filter, you need to set proper values.\n",
"\n",
"service = 'aoss' # must set the service as 'aoss'\n",
"region = 'us-east-2'\n",
"credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
"awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
"\n",
"docsearch = OpenSearchVectorSearch.from_documents(\n",
" docs,\n",
" embeddings,\n",
" opensearch_url=\"host url\",\n",
" http_auth=awsauth,\n",
" timeout = 300,\n",
" use_ssl = True,\n",
" verify_certs = True,\n",
" connection_class = RequestsHttpConnection,\n",
" index_name=\"test-index-using-aoss\",\n",
" engine=\"faiss\",\n",
")\n",
"\n",
"docs = docsearch.similarity_search(\n",
" \"What is feature selection\",\n",
" efficient_filter=filter,\n",
" k=200,\n",
")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Using AOS (Amazon OpenSearch Service)"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# This is just an example to show how to use AOS , you need to set proper values.\n",
"\n",
"service = 'es' # must set the service as 'es'\n",
"region = 'us-east-2'\n",
"credentials = boto3.Session(aws_access_key_id='xxxxxx',aws_secret_access_key='xxxxx').get_credentials()\n",
"awsauth = AWS4Auth('xxxxx', 'xxxxxx', region,service, session_token=credentials.token)\n",
"\n",
"docsearch = OpenSearchVectorSearch.from_documents(\n",
" docs,\n",
" embeddings,\n",
" opensearch_url=\"host url\",\n",
" http_auth=awsauth,\n",
" timeout = 300,\n",
" use_ssl = True,\n",
" verify_certs = True,\n",
" connection_class = RequestsHttpConnection,\n",
" index_name=\"test-index\",\n",
")\n",
"\n",
"docs = docsearch.similarity_search(\n",
" \"What is feature selection\",\n",
" k=200,\n",
")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
@ -338,4 +433,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

@ -2,6 +2,7 @@
from __future__ import annotations
import uuid
import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
@ -71,6 +72,26 @@ def _validate_embeddings_and_bulk_size(embeddings_length: int, bulk_size: int) -
)
def _validate_aoss_with_engines(is_aoss: bool, engine: str) -> None:
"""Validate AOSS with the engine."""
if is_aoss and engine != "nmslib" and engine != "faiss":
raise ValueError(
"Amazon OpenSearch Service Serverless only "
"supports `nmslib` or `faiss` engines"
)
def _is_aoss_enabled(http_auth: Any) -> bool:
"""Check if the service is http_auth is set as `aoss`."""
if (
http_auth is not None
and http_auth.service is not None
and http_auth.service == "aoss"
):
return True
return False
def _bulk_ingest_embeddings(
client: Any,
index_name: str,
@ -82,6 +103,7 @@ def _bulk_ingest_embeddings(
text_field: str = "text",
mapping: Optional[Dict] = None,
max_chunk_bytes: Optional[int] = 1 * 1024 * 1024,
is_aoss: bool = False,
) -> List[str]:
"""Bulk Ingest Embeddings into given index."""
if not mapping:
@ -107,12 +129,16 @@ def _bulk_ingest_embeddings(
vector_field: embeddings[i],
text_field: text,
"metadata": metadata,
"_id": _id,
}
if is_aoss:
request["id"] = _id
else:
request["_id"] = _id
requests.append(request)
return_ids.append(_id)
bulk(client, requests, max_chunk_bytes=max_chunk_bytes)
client.indices.refresh(index=index_name)
if not is_aoss:
client.indices.refresh(index=index_name)
return return_ids
@ -192,17 +218,18 @@ def _approximate_search_query_with_boolean_filter(
}
def _approximate_search_query_with_lucene_filter(
def _approximate_search_query_with_efficient_filter(
query_vector: List[float],
lucene_filter: Dict,
efficient_filter: Dict,
k: int = 4,
vector_field: str = "vector_field",
) -> Dict:
"""For Approximate k-NN Search, with Lucene Filter."""
"""For Approximate k-NN Search, with Efficient Filter for Lucene and
Faiss Engines."""
search_query = _default_approximate_search_query(
query_vector, k=k, vector_field=vector_field
)
search_query["query"]["knn"][vector_field]["filter"] = lucene_filter
search_query["query"]["knn"][vector_field]["filter"] = efficient_filter
return search_query
@ -309,11 +336,13 @@ class OpenSearchVectorSearch(VectorStore):
opensearch_url: str,
index_name: str,
embedding_function: Embeddings,
is_aoss: bool,
**kwargs: Any,
):
"""Initialize with necessary components."""
self.embedding_function = embedding_function
self.index_name = index_name
self.is_aoss = is_aoss
self.client = _get_opensearch_client(opensearch_url, **kwargs)
@property
@ -358,6 +387,8 @@ class OpenSearchVectorSearch(VectorStore):
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
_validate_aoss_with_engines(self.is_aoss, engine)
mapping = _default_text_mapping(
dim, engine, space_type, ef_search, ef_construction, m, vector_field
)
@ -373,6 +404,7 @@ class OpenSearchVectorSearch(VectorStore):
text_field=text_field,
mapping=mapping,
max_chunk_bytes=max_chunk_bytes,
is_aoss=self.is_aoss,
)
def similarity_search(
@ -404,14 +436,18 @@ class OpenSearchVectorSearch(VectorStore):
Optional Args for Approximate Search:
search_type: "approximate_search"; default: "approximate_search"
boolean_filter: A Boolean filter consists of a Boolean query that
contains a k-NN query and a filter.
boolean_filter: A Boolean filter is a post filter consists of a Boolean
query that contains a k-NN query and a filter.
subquery_clause: Query clause on the knn vector field; default: "must"
lucene_filter: the Lucene algorithm decides whether to perform an exact
k-NN search with pre-filtering or an approximate search with modified
post-filtering.
post-filtering. (deprecated, use `efficient_filter`)
efficient_filter: the Lucene Engine or Faiss Engine decides whether to
perform an exact k-NN search with pre-filtering or an approximate search
with modified post-filtering.
Optional Args for Script Scoring Search:
search_type: "script_scoring"; default: "approximate_search"
@ -494,15 +530,41 @@ class OpenSearchVectorSearch(VectorStore):
search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search")
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
if (
self.is_aoss
and search_type != "approximate_search"
and search_type != SCRIPT_SCORING_SEARCH
):
raise ValueError(
"Amazon OpenSearch Service Serverless only "
"supports `approximate_search` and `script_scoring`"
)
if search_type == "approximate_search":
boolean_filter = _get_kwargs_value(kwargs, "boolean_filter", {})
subquery_clause = _get_kwargs_value(kwargs, "subquery_clause", "must")
efficient_filter = _get_kwargs_value(kwargs, "efficient_filter", {})
# `lucene_filter` is deprecated, added for Backwards Compatibility
lucene_filter = _get_kwargs_value(kwargs, "lucene_filter", {})
if boolean_filter != {} and lucene_filter != {}:
if boolean_filter != {} and efficient_filter != {}:
raise ValueError(
"Both `boolean_filter` and `lucene_filter` are provided which "
"Both `boolean_filter` and `efficient_filter` are provided which "
"is invalid"
)
if lucene_filter != {} and efficient_filter != {}:
raise ValueError(
"Both `lucene_filter` and `efficient_filter` are provided which "
"is invalid. `lucene_filter` is deprecated"
)
if lucene_filter != {} and boolean_filter != {}:
raise ValueError(
"Both `lucene_filter` and `boolean_filter` are provided which "
"is invalid. `lucene_filter` is deprecated"
)
if boolean_filter != {}:
search_query = _approximate_search_query_with_boolean_filter(
embedding,
@ -511,8 +573,16 @@ class OpenSearchVectorSearch(VectorStore):
vector_field=vector_field,
subquery_clause=subquery_clause,
)
elif efficient_filter != {}:
search_query = _approximate_search_query_with_efficient_filter(
embedding, efficient_filter, k=k, vector_field=vector_field
)
elif lucene_filter != {}:
search_query = _approximate_search_query_with_lucene_filter(
warnings.warn(
"`lucene_filter` is deprecated. Please use the keyword argument"
" `efficient_filter`"
)
search_query = _approximate_search_query_with_efficient_filter(
embedding, lucene_filter, k=k, vector_field=vector_field
)
else:
@ -659,6 +729,7 @@ class OpenSearchVectorSearch(VectorStore):
"ef_construction",
"m",
"max_chunk_bytes",
"is_aoss",
]
embeddings = embedding.embed_documents(texts)
_validate_embeddings_and_bulk_size(len(embeddings), bulk_size)
@ -672,6 +743,15 @@ class OpenSearchVectorSearch(VectorStore):
vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field")
text_field = _get_kwargs_value(kwargs, "text_field", "text")
max_chunk_bytes = _get_kwargs_value(kwargs, "max_chunk_bytes", 1 * 1024 * 1024)
http_auth = _get_kwargs_value(kwargs, "http_auth", None)
is_aoss = _is_aoss_enabled(http_auth=http_auth)
if is_aoss and not is_appx_search:
raise ValueError(
"Amazon OpenSearch Service Serverless only "
"supports `approximate_search`"
)
if is_appx_search:
engine = _get_kwargs_value(kwargs, "engine", "nmslib")
space_type = _get_kwargs_value(kwargs, "space_type", "l2")
@ -679,6 +759,8 @@ class OpenSearchVectorSearch(VectorStore):
ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512)
m = _get_kwargs_value(kwargs, "m", 16)
_validate_aoss_with_engines(is_aoss, engine)
mapping = _default_text_mapping(
dim, engine, space_type, ef_search, ef_construction, m, vector_field
)
@ -697,5 +779,6 @@ class OpenSearchVectorSearch(VectorStore):
text_field=text_field,
mapping=mapping,
max_chunk_bytes=max_chunk_bytes,
is_aoss=is_aoss,
)
return cls(opensearch_url, index_name, embedding, **kwargs)
return cls(opensearch_url, index_name, embedding, is_aoss, **kwargs)

@ -1,6 +1,8 @@
"""Test OpenSearch functionality."""
import boto3
import pytest
from opensearchpy import AWSV4SignerAuth
from langchain.docstore.document import Document
from langchain.vectorstores.opensearch_vector_search import (
@ -213,3 +215,95 @@ def test_opensearch_with_custom_field_name_appx_false() -> None:
)
output = docsearch.similarity_search("add", k=1)
assert output == [Document(page_content="add")]
def test_opensearch_serverless_with_scripting_search_indexing_throws_error() -> None:
"""Test to validate indexing using Serverless without Approximate Search."""
region = "test-region"
service = "aoss"
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)
with pytest.raises(ValueError):
OpenSearchVectorSearch.from_texts(
texts,
FakeEmbeddings(),
opensearch_url=DEFAULT_OPENSEARCH_URL,
is_appx_search=False,
http_auth=auth,
)
def test_opensearch_serverless_with_lucene_engine_throws_error() -> None:
"""Test to validate indexing using lucene engine with Serverless."""
region = "test-region"
service = "aoss"
credentials = boto3.Session().get_credentials()
auth = AWSV4SignerAuth(credentials, region, service)
with pytest.raises(ValueError):
OpenSearchVectorSearch.from_texts(
texts,
FakeEmbeddings(),
opensearch_url=DEFAULT_OPENSEARCH_URL,
engine="lucene",
http_auth=auth,
)
def test_appx_search_with_efficient_and_bool_filter_throws_error() -> None:
"""Test Approximate Search with Efficient and Bool Filter throws Error."""
efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
boolean_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
docsearch = OpenSearchVectorSearch.from_texts(
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
)
with pytest.raises(ValueError):
docsearch.similarity_search(
"foo",
k=3,
efficient_filter=efficient_filter_val,
boolean_filter=boolean_filter_val,
)
def test_appx_search_with_efficient_and_lucene_filter_throws_error() -> None:
"""Test Approximate Search with Efficient and Lucene Filter throws Error."""
efficient_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
docsearch = OpenSearchVectorSearch.from_texts(
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
)
with pytest.raises(ValueError):
docsearch.similarity_search(
"foo",
k=3,
efficient_filter=efficient_filter_val,
lucene_filter=lucene_filter_val,
)
def test_appx_search_with_boolean_and_lucene_filter_throws_error() -> None:
"""Test Approximate Search with Boolean and Lucene Filter throws Error."""
boolean_filter_val = {"bool": {"must": [{"term": {"text": "baz"}}]}}
lucene_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
docsearch = OpenSearchVectorSearch.from_texts(
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="lucene"
)
with pytest.raises(ValueError):
docsearch.similarity_search(
"foo",
k=3,
boolean_filter=boolean_filter_val,
lucene_filter=lucene_filter_val,
)
def test_appx_search_with_faiss_efficient_filter() -> None:
"""Test Approximate Search with Faiss Efficient Filter."""
efficient_filter_val = {"bool": {"must": [{"term": {"text": "bar"}}]}}
docsearch = OpenSearchVectorSearch.from_texts(
texts, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL, engine="faiss"
)
output = docsearch.similarity_search(
"foo", k=3, efficient_filter=efficient_filter_val
)
assert output == [Document(page_content="bar")]

Loading…
Cancel
Save