diff --git a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py index 6de74de297..bb63c2f806 100644 --- a/libs/langchain/langchain/vectorstores/opensearch_vector_search.py +++ b/libs/langchain/langchain/vectorstores/opensearch_vector_search.py @@ -374,6 +374,7 @@ class OpenSearchVectorSearch(VectorStore): """ embeddings = self.embedding_function.embed_documents(list(texts)) _validate_embeddings_and_bulk_size(len(embeddings), bulk_size) + index_name = _get_kwargs_value(kwargs, "index_name", self.index_name) text_field = _get_kwargs_value(kwargs, "text_field", "text") dim = len(embeddings[0]) engine = _get_kwargs_value(kwargs, "engine", "nmslib") @@ -392,7 +393,7 @@ class OpenSearchVectorSearch(VectorStore): return _bulk_ingest_embeddings( self.client, - self.index_name, + index_name, embeddings, texts, metadatas=metadatas, @@ -526,6 +527,7 @@ class OpenSearchVectorSearch(VectorStore): embedding = self.embedding_function.embed_query(query) search_type = _get_kwargs_value(kwargs, "search_type", "approximate_search") vector_field = _get_kwargs_value(kwargs, "vector_field", "vector_field") + index_name = _get_kwargs_value(kwargs, "index_name", self.index_name) if ( self.is_aoss @@ -601,7 +603,7 @@ class OpenSearchVectorSearch(VectorStore): else: raise ValueError("Invalid `search_type` provided as an argument") - response = self.client.search(index=self.index_name, body=search_query) + response = self.client.search(index=index_name, body=search_query) return [hit for hit in response["hits"]["hits"]] @@ -663,6 +665,7 @@ class OpenSearchVectorSearch(VectorStore): embedding: Embeddings, metadatas: Optional[List[dict]] = None, bulk_size: int = 500, + ids: Optional[List[str]] = None, **kwargs: Any, ) -> OpenSearchVectorSearch: """Construct OpenSearchVectorSearch wrapper from raw documents. @@ -772,6 +775,7 @@ class OpenSearchVectorSearch(VectorStore): embeddings, texts, metadatas=metadatas, + ids=ids, vector_field=vector_field, text_field=text_field, mapping=mapping,