From ab749fa1bb4eb8358662d1d34f90d992ec7d7448 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Wed, 26 Apr 2023 22:08:03 -0700 Subject: [PATCH] Harrison/opensearch logic (#3631) Co-authored-by: engineer-matsuo <95115586+engineer-matsuo@users.noreply.github.com> --- .../vectorstores/opensearch_vector_search.py | 42 +++++++++++++++++-- .../vectorstores/test_opensearch.py | 23 ++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/langchain/vectorstores/opensearch_vector_search.py b/langchain/vectorstores/opensearch_vector_search.py index de41863d..5f432b2d 100644 --- a/langchain/vectorstores/opensearch_vector_search.py +++ b/langchain/vectorstores/opensearch_vector_search.py @@ -35,6 +35,15 @@ def _import_bulk() -> Any: return bulk +def _import_not_found_error() -> Any: + """Import not found error if available, otherwise raise error.""" + try: + from opensearchpy.exceptions import NotFoundError + except ImportError: + raise ValueError(IMPORT_OPENSEARCH_PY_ERROR) + return NotFoundError + + def _get_opensearch_client(opensearch_url: str, **kwargs: Any) -> Any: """Get OpenSearch client from the opensearch_url, otherwise raise error.""" try: @@ -67,11 +76,20 @@ def _bulk_ingest_embeddings( metadatas: Optional[List[dict]] = None, vector_field: str = "vector_field", text_field: str = "text", + mapping: Dict = {}, ) -> List[str]: """Bulk Ingest Embeddings into given index.""" bulk = _import_bulk() + not_found_error = _import_not_found_error() requests = [] ids = [] + mapping = mapping + + try: + client.indices.get(index=index_name) + except not_found_error: + client.indices.create(index=index_name, body=mapping) + for i, text in enumerate(texts): metadata = metadatas[i] if metadatas else {} _id = str(uuid.uuid4()) @@ -311,8 +329,19 @@ class OpenSearchVectorSearch(VectorStore): """ embeddings = self.embedding_function.embed_documents(list(texts)) _validate_embeddings_and_bulk_size(len(embeddings), bulk_size) - vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") text_field = _get_kwargs_value(kwargs, "text_field", "text") + dim = len(embeddings[0]) + engine = _get_kwargs_value(kwargs, "engine", "nmslib") + space_type = _get_kwargs_value(kwargs, "space_type", "l2") + ef_search = _get_kwargs_value(kwargs, "ef_search", 512) + ef_construction = _get_kwargs_value(kwargs, "ef_construction", 512) + m = _get_kwargs_value(kwargs, "m", 16) + vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") + + mapping = _default_text_mapping( + dim, engine, space_type, ef_search, ef_construction, m, vector_field + ) + return _bulk_ingest_embeddings( self.client, self.index_name, @@ -321,6 +350,7 @@ class OpenSearchVectorSearch(VectorStore): metadatas, vector_field, text_field, + mapping, ) def similarity_search( @@ -532,8 +562,14 @@ class OpenSearchVectorSearch(VectorStore): [kwargs.pop(key, None) for key in keys_list] client = _get_opensearch_client(opensearch_url, **kwargs) - client.indices.create(index=index_name, body=mapping) _bulk_ingest_embeddings( - client, index_name, embeddings, texts, metadatas, vector_field, text_field + client, + index_name, + embeddings, + texts, + metadatas, + vector_field, + text_field, + mapping, ) return cls(opensearch_url, index_name, embedding, **kwargs) diff --git a/tests/integration_tests/vectorstores/test_opensearch.py b/tests/integration_tests/vectorstores/test_opensearch.py index 36fb3208..88bdec09 100644 --- a/tests/integration_tests/vectorstores/test_opensearch.py +++ b/tests/integration_tests/vectorstores/test_opensearch.py @@ -174,3 +174,26 @@ def test_appx_search_with_lucene_filter() -> None: ) output = docsearch.similarity_search("foo", k=3, lucene_filter=lucene_filter_val) assert output == [Document(page_content="bar")] + + +def test_opensearch_with_custom_field_name_appx_true() -> None: + """Test Approximate Search with custom field name appx true.""" + text_input = ["test", "add", "text", "method"] + docsearch = OpenSearchVectorSearch.from_texts( + text_input, + FakeEmbeddings(), + opensearch_url=DEFAULT_OPENSEARCH_URL, + is_appx_search=True, + ) + output = docsearch.similarity_search("add", k=1) + assert output == [Document(page_content="add")] + + +def test_opensearch_with_custom_field_name_appx_false() -> None: + """Test Approximate Search with custom field name appx true.""" + text_input = ["test", "add", "text", "method"] + docsearch = OpenSearchVectorSearch.from_texts( + text_input, FakeEmbeddings(), opensearch_url=DEFAULT_OPENSEARCH_URL + ) + output = docsearch.similarity_search("add", k=1) + assert output == [Document(page_content="add")]