partners[milvus]: add dynamic field (#24544)

add dynamic field feature to langchain_milvus
more unittest, more robustic

plan to deprecate the `metadata_field` in the future, because it's
function is the same as `enable_dynamic_field`, but the latter one is a
more advanced concept in milvus

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
ChengZi 2024-07-25 04:01:58 +08:00 committed by GitHub
parent 20fe4deea0
commit 29a3b3a711
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 291 additions and 86 deletions

View File

@ -130,24 +130,46 @@ class Milvus(VectorStore):
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
auto_id (bool): Whether to enable auto id for primary key. Defaults to False.
If False, you needs to provide text ids (string less than 65535 bytes).
If False, you need to provide text ids (string less than 65535 bytes).
If True, Milvus will generate unique integers as primary keys.
primary_field (str): Name of the primary key field. Defaults to "pk".
text_field (str): Name of the text field. Defaults to "text".
vector_field (str): Name of the vector field. Defaults to "vector".
metadata_field (str): Name of the metadta field. Defaults to None.
enable_dynamic_field (Optional[bool]): Whether to enable
dynamic schema or not in Milvus. Defaults to False.
For more information about dynamic schema, please refer to
https://milvus.io/docs/enable-dynamic-field.md
metadata_field (str): Name of the metadata field. Defaults to None.
When metadata_field is specified,
the document's metadata will store as json.
This argument is about to be deprecated,
because it can be replaced by setting `enable_dynamic_field`=True.
partition_key_field (Optional[str]): Name of the partition key field.
Defaults to None. For more information about partition key, please refer to
https://milvus.io/docs/use-partition-key.md#Use-Partition-Key
partition_names (Optional[list]): List of specific partition names.
Defaults to None. For more information about partition, please refer to
https://milvus.io/docs/manage-partitions.md#Manage-Partitions
replica_number (int): The number of replicas for the collection. Defaults to 1.
For more information about replica, please refer to
https://milvus.io/docs/replica.md#In-Memory-Replica
timeout (Optional[float]): The timeout for Milvus operations. Defaults to None.
An optional duration of time in seconds to allow for the RPCs.
If timeout is not set, the client keeps waiting until the server responds
or an error occurs.
num_shards (Optional[int]): The number of shards for the collection.
Defaults to None. For more information about shards, please refer to
https://milvus.io/docs/glossary.md#Shard
The connection args used for this class comes in the form of a dict,
here are a few of the options:
address (str): The actual address of Milvus
instance. Example address: "localhost:19530"
uri (str): The uri of Milvus instance. Example uri:
"path/to/local/directory/milvus_demo.db" for Milvus Lite.
"http://randomwebsite:19530",
"tcp:foobarsite:19530",
"https://ok.s3.south.com:19530".
or "path/to/local/directory/milvus_demo.db" for Milvus Lite.
host (str): The host of Milvus instance. Default at "localhost",
PyMilvus will fill in the default host if only port is provided.
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
@ -178,6 +200,7 @@ class Milvus(VectorStore):
milvus_store = Milvus(
embedding_function = Embeddings,
collection_name = "LangChainCollection",
connection_args = {"uri": "./milvus_demo.db"},
drop_old = True,
auto_id = True
)
@ -202,6 +225,7 @@ class Milvus(VectorStore):
primary_field: str = "pk",
text_field: str = "text",
vector_field: str = "vector",
enable_dynamic_field: bool = False,
metadata_field: Optional[str] = None,
partition_key_field: Optional[str] = None,
partition_names: Optional[list] = None,
@ -260,6 +284,17 @@ class Milvus(VectorStore):
self._text_field = text_field
# In order for compatibility, the vector field needs to be called "vector"
self._vector_field = vector_field
if metadata_field:
logger.warning(
"DeprecationWarning: `metadata_field` is about to be deprecated, "
"please set `enable_dynamic_field`=True instead."
)
if enable_dynamic_field and metadata_field:
metadata_field = None
logger.warning(
"When `enable_dynamic_field` is True, `metadata_field` is ignored."
)
self.enable_dynamic_field = enable_dynamic_field
self._metadata_field = metadata_field
self._partition_key_field = partition_key_field
self.fields: list[str] = []
@ -389,13 +424,36 @@ class Milvus(VectorStore):
# Determine embedding dim
dim = len(embeddings[0])
fields = []
if self._metadata_field is not None:
# If enable_dynamic_field, we don't need to create fields, and just pass it.
# In the future, when metadata_field is deprecated,
# This logical structure will be simplified like this:
# ```
# if not self.enable_dynamic_field and metadatas:
# for key, value in metadatas[0].items():
# ...
# ```
if self.enable_dynamic_field:
pass
elif self._metadata_field is not None:
fields.append(FieldSchema(self._metadata_field, DataType.JSON))
else:
# Determine metadata schema
if metadatas:
# Create FieldSchema for each entry in metadata.
for key, value in metadatas[0].items():
if key in [
self._vector_field,
self._primary_field,
self._text_field,
]:
logger.error(
(
"Failure to create collection, "
"metadata key: %s is reserved."
),
key,
)
raise ValueError(f"Metadata key {key} is reserved.")
# Infer the corresponding datatype of the metadata
dtype = infer_dtype_bydata(value)
# Datatype isn't compatible
@ -408,7 +466,7 @@ class Milvus(VectorStore):
key,
)
raise ValueError(f"Unrecognized datatype for {key}.")
# Dataype is a string/varchar equivalent
# Datatype is a string/varchar equivalent
elif dtype == DataType.VARCHAR:
fields.append(
FieldSchema(key, DataType.VARCHAR, max_length=65_535)
@ -447,6 +505,7 @@ class Milvus(VectorStore):
fields,
description=self.collection_description,
partition_key_field=self._partition_key_field,
enable_dynamic_field=self.enable_dynamic_field,
)
# Create the collection
@ -617,16 +676,26 @@ class Milvus(VectorStore):
texts = list(texts)
if not self.auto_id:
assert isinstance(
ids, list
), "A list of valid ids are required when auto_id is False."
assert isinstance(ids, list), (
"A list of valid ids are required when auto_id is False. "
"You can set `auto_id` to True in this Milvus instance to generate "
"ids automatically, or specify string-type ids for each text."
)
assert len(set(ids)) == len(
texts
), "Different lengths of texts and unique ids are provided."
assert all(isinstance(x, str) for x in ids), "All ids should be strings."
assert all(
len(x.encode()) <= 65_535 for x in ids
), "Each id should be a string less than 65535 bytes."
else:
if ids is not None:
logger.warning(
"The ids parameter is ignored when auto_id is True. "
"The ids will be generated automatically."
)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
@ -647,34 +716,39 @@ class Milvus(VectorStore):
kwargs["timeout"] = self.timeout
self._init(**kwargs)
# Dict to hold all insert columns
insert_dict: dict[str, list] = {
self._text_field: texts,
self._vector_field: embeddings,
}
insert_list: list[dict] = []
if not self.auto_id:
insert_dict[self._primary_field] = ids # type: ignore[assignment]
assert len(texts) == len(
embeddings
), "Mismatched lengths of texts and embeddings."
if metadatas is not None:
assert len(texts) == len(
metadatas
), "Mismatched lengths of texts and metadatas."
if self._metadata_field is not None:
for d in metadatas: # type: ignore[union-attr]
insert_dict.setdefault(self._metadata_field, []).append(d)
else:
# Collect the metadata into the insert dict.
if metadatas is not None:
for d in metadatas:
for key, value in d.items():
keys = (
[x for x in self.fields if x != self._primary_field]
if self.auto_id
else [x for x in self.fields]
)
if key in keys:
insert_dict.setdefault(key, []).append(value)
for i, text, embedding in zip(range(len(texts)), texts, embeddings):
entity_dict = {}
metadata = metadatas[i] if metadatas else {}
if not self.auto_id:
entity_dict[self._primary_field] = ids[i] # type: ignore[index]
entity_dict[self._text_field] = text
entity_dict[self._vector_field] = embedding
if self._metadata_field and not self.enable_dynamic_field:
entity_dict[self._metadata_field] = metadata
else:
for key, value in metadata.items():
# if not enable_dynamic_field, skip fields not in the collection.
if not self.enable_dynamic_field and key not in self.fields:
continue
# If enable_dynamic_field, all fields are allowed.
entity_dict[key] = value
insert_list.append(entity_dict)
# Total insert count
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)
total_count = len(insert_list)
pks: list[str] = []
@ -682,15 +756,12 @@ class Milvus(VectorStore):
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
# Convert dict to list of lists batch for insertion
insert_list = [
insert_dict[x][i:end] for x in self.fields if x in insert_dict
]
batch_insert_list = insert_list[i:end]
# Insert into the collection.
try:
res: Collection
timeout = self.timeout or timeout
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
res = self.col.insert(batch_insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
logger.error(
@ -699,6 +770,61 @@ class Milvus(VectorStore):
raise e
return pks
def _collection_search(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[float] = None,
**kwargs: Any,
) -> "pymilvus.client.abstract.SearchResult | None": # type: ignore[name-defined] # noqa: F821
"""Perform a search on an embedding and return milvus search results.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (float, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
pymilvus.client.abstract.SearchResult: Milvus search result.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return None
if param is None:
param = self.search_params
# Determine result metadata fields with PK.
if self.enable_dynamic_field:
output_fields = ["*"]
else:
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
timeout = self.timeout or timeout
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
return res
def similarity_search(
self,
query: str,
@ -778,7 +904,7 @@ class Milvus(VectorStore):
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
Args:
query (str): The text being searched.
@ -814,11 +940,11 @@ class Milvus(VectorStore):
timeout: Optional[float] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
"""Perform a search on an embedding and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/search.md
Args:
embedding (List[float]): The embedding vector being searched.
@ -833,32 +959,14 @@ class Milvus(VectorStore):
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields with PK.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
timeout = self.timeout or timeout
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
col_search_res = self._collection_search(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
# Organize results.
if col_search_res is None:
return []
ret = []
for result in res[0]:
data = {x: result.entity.get(x) for x in output_fields}
for result in col_search_res[0]:
data = {x: result.entity.get(x) for x in result.entity.fields}
doc = self._parse_document(data)
pair = (doc, result.score)
ret.append(pair)
@ -947,40 +1055,27 @@ class Milvus(VectorStore):
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
timeout = self.timeout or timeout
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
col_search_res = self._collection_search(
embedding=embedding,
k=fetch_k,
param=param,
limit=fetch_k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
if col_search_res is None:
return []
ids = []
documents = []
scores = []
for result in res[0]:
data = {x: result.entity.get(x) for x in output_fields}
for result in col_search_res[0]:
data = {x: result.entity.get(x) for x in result.entity.fields}
doc = self._parse_document(data)
documents.append(doc)
scores.append(result.score)
ids.append(result.id)
vectors = self.col.query(
vectors = self.col.query( # type: ignore[union-attr]
expr=f"{self._primary_field} in {ids}",
output_fields=[self._primary_field, self._vector_field],
timeout=timeout,
@ -1089,6 +1184,8 @@ class Milvus(VectorStore):
return vector_db
def _parse_document(self, data: dict) -> Document:
if self._vector_field in data:
data.pop(self._vector_field)
return Document(
page_content=data.pop(self._text_field),
metadata=data.pop(self._metadata_field) if self._metadata_field else data,

View File

@ -1,6 +1,7 @@
"""Test Milvus functionality."""
from typing import Any, List, Optional
import pytest
from langchain_core.documents import Document
from langchain_milvus.vectorstores import Milvus
@ -27,6 +28,7 @@ def _milvus_from_texts(
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
drop: bool = True,
**kwargs: Any,
) -> Milvus:
return Milvus.from_texts(
fake_texts,
@ -36,6 +38,7 @@ def _milvus_from_texts(
# connection_args={"uri": "http://127.0.0.1:19530"},
connection_args={"uri": "./milvus_demo.db"},
drop_old=drop,
**kwargs,
)
@ -50,6 +53,15 @@ def test_milvus() -> None:
assert_docs_equal_without_pk(output, [Document(page_content="foo")])
def test_milvus_vector_search() -> None:
"""Test end to end construction and search by vector."""
docsearch = _milvus_from_texts()
output = docsearch.similarity_search_by_vector(
FakeEmbeddings().embed_query("foo"), k=1
)
assert_docs_equal_without_pk(output, [Document(page_content="foo")])
def test_milvus_with_metadata() -> None:
"""Test with metadata"""
docsearch = _milvus_from_texts(metadatas=[{"label": "test"}] * len(fake_texts))
@ -110,6 +122,21 @@ def test_milvus_max_marginal_relevance_search() -> None:
)
def test_milvus_max_marginal_relevance_search_with_dynamic_field() -> None:
"""Test end to end construction and MRR search with enabling dynamic field."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=True)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert_docs_equal_without_pk(
output,
[
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="baz", metadata={"page": 2}),
],
)
def test_milvus_add_extra() -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
@ -123,7 +150,7 @@ def test_milvus_add_extra() -> None:
def test_milvus_no_drop() -> None:
"""Test end to end construction and MRR search."""
"""Test construction without dropping old data."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas)
@ -171,14 +198,95 @@ def test_milvus_upsert_entities() -> None:
assert len(ids) == 2 # type: ignore[arg-type]
def test_milvus_enable_dynamic_field() -> None:
"""Test end to end construction and enable dynamic field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=True)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
# When enable dynamic field, any new field data will be added to the collection.
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
}
def test_milvus_disable_dynamic_field() -> None:
"""Test end to end construction and disable dynamic field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, enable_dynamic_field=False)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
# ["pk", "text", "vector", "id"]
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
"id",
}
# Try to add new fields "id_new", but since dynamic field is disabled,
# all fields in the collection is specified as ["pk", "text", "vector", "id"],
# new field information "id_new" will not be added.
new_metadatas = [{"id": i, "id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
for doc in output:
assert set(doc.metadata.keys()) == {"id", "pk"} # `id_new` is not added.
# When disable dynamic field,
# missing data of the created fields "id", will raise an exception.
with pytest.raises(Exception):
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
def test_milvus_metadata_field() -> None:
"""Test end to end construction and use metadata field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, metadata_field="metadata")
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
docsearch._metadata_field,
}
# if __name__ == "__main__":
# test_milvus()
# test_milvus_vector_search()
# test_milvus_with_metadata()
# test_milvus_with_id()
# test_milvus_with_score()
# test_milvus_max_marginal_relevance_search()
# test_milvus_max_marginal_relevance_search_with_dynamic_field()
# test_milvus_add_extra()
# test_milvus_no_drop()
# test_milvus_get_pks()
# test_milvus_delete_entities()
# test_milvus_upsert_entities()
# test_milvus_enable_dynamic_field()
# test_milvus_disable_dynamic_field()
# test_milvus_metadata_field()