Harrison/optional ids opensearch (#6684)

Co-authored-by: taekimsmar <66041442+taekimsmar@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-06-24 09:19:57 -07:00 committed by GitHub
parent 2518e6c95b
commit c289cc891a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -77,6 +77,7 @@ def _bulk_ingest_embeddings(
embeddings: List[List[float]],
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
vector_field: str = "vector_field",
text_field: str = "text",
mapping: Optional[Dict] = None,
@@ -88,7 +89,7 @@ def _bulk_ingest_embeddings(
bulk = _import_bulk()
not_found_error = _import_not_found_error()
requests = []
ids = []
return_ids = []
mapping = mapping
try:
@@ -98,7 +99,7 @@ def _bulk_ingest_embeddings(
for i, text in enumerate(texts):
metadata = metadatas[i] if metadatas else {}
_id = str(uuid.uuid4())
_id = ids[i] if ids else str(uuid.uuid4())
request = {
"_op_type": "index",
"_index": index_name,
@@ -108,10 +109,10 @@ def _bulk_ingest_embeddings(
"_id": _id,
}
requests.append(request)
ids.append(_id)
return_ids.append(_id)
bulk(client, requests)
client.indices.refresh(index=index_name)
return ids
return return_ids
def _default_scripting_text_mapping(
@@ -318,6 +319,7 @@ class OpenSearchVectorSearch(VectorStore):
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
bulk_size: int = 500,
**kwargs: Any,
) -> List[str]:
@@ -326,6 +328,7 @@ class OpenSearchVectorSearch(VectorStore):
Args:
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate with the texts.
bulk_size: Bulk API request count; Default: 500
Returns:
@@ -358,10 +361,11 @@ class OpenSearchVectorSearch(VectorStore):
self.index_name,
embeddings,
texts,
metadatas,
vector_field,
text_field,
mapping,
metadatas=metadatas,
ids=ids,
vector_field=vector_field,
text_field=text_field,
mapping=mapping,
)
def similarity_search(
@@ -679,9 +683,9 @@ class OpenSearchVectorSearch(VectorStore):
index_name,
embeddings,
texts,
metadatas,
vector_field,
text_field,
mapping,
metadatas=metadatas,
vector_field=vector_field,
text_field=text_field,
mapping=mapping,
)
return cls(opensearch_url, index_name, embedding, **kwargs)