From da0f750a0bf4eacf00abe7325b5c76ad609c7773 Mon Sep 17 00:00:00 2001 From: YISH Date: Wed, 3 Jan 2024 08:12:00 +0800 Subject: [PATCH] Milvus allows to store metadata as json field (#14636) Because Milvus doesn't support nullable fields, but document metadata is very rich, so it makes more sense to store it as json. https://github.com/milvus-io/pymilvus/issues/1705#issuecomment-1731112372 --------- Co-authored-by: Harrison Chase --- .../vectorstores/milvus.py | 77 ++++++++++++------- 1 file changed, 50 insertions(+), 27 deletions(-) diff --git a/libs/community/langchain_community/vectorstores/milvus.py b/libs/community/langchain_community/vectorstores/milvus.py index 213e86c7b7..0d297f87de 100644 --- a/libs/community/langchain_community/vectorstores/milvus.py +++ b/libs/community/langchain_community/vectorstores/milvus.py @@ -53,6 +53,9 @@ class Milvus(VectorStore): primary_field (str): Name of the primary key field. Defaults to "pk". text_field (str): Name of the text field. Defaults to "text". vector_field (str): Name of the vector field. Defaults to "vector". + metadata_field (str): Name of the metadta field. Defaults to None. + When metadata_field is specified, + the document's metadata will store as json. The connection args used for this class comes in the form of a dict, here are a few of the options: @@ -112,6 +115,7 @@ class Milvus(VectorStore): primary_field: str = "pk", text_field: str = "text", vector_field: str = "vector", + metadata_field: Optional[str] = None, ): """Initialize the Milvus vector store.""" try: @@ -148,6 +152,7 @@ class Milvus(VectorStore): self._text_field = text_field # In order for compatibility, the vector field needs to be called "vector" self._vector_field = vector_field + self._metadata_field = metadata_field self.fields: list[str] = [] # Create the connection to the server if connection_args is None: @@ -250,24 +255,32 @@ class Milvus(VectorStore): # Determine embedding dim dim = len(embeddings[0]) fields = [] - # Determine metadata schema - if metadatas: - # Create FieldSchema for each entry in metadata. - for key, value in metadatas[0].items(): - # Infer the corresponding datatype of the metadata - dtype = infer_dtype_bydata(value) - # Datatype isn't compatible - if dtype == DataType.UNKNOWN or dtype == DataType.NONE: - logger.error( - "Failure to create collection, unrecognized dtype for key: %s", - key, - ) - raise ValueError(f"Unrecognized datatype for {key}.") - # Dataype is a string/varchar equivalent - elif dtype == DataType.VARCHAR: - fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) - else: - fields.append(FieldSchema(key, dtype)) + if self._metadata_field is not None: + fields.append(FieldSchema(self._metadata_field, DataType.JSON)) + else: + # Determine metadata schema + if metadatas: + # Create FieldSchema for each entry in metadata. + for key, value in metadatas[0].items(): + # Infer the corresponding datatype of the metadata + dtype = infer_dtype_bydata(value) + # Datatype isn't compatible + if dtype == DataType.UNKNOWN or dtype == DataType.NONE: + logger.error( + ( + "Failure to create collection, " + "unrecognized dtype for key: %s" + ), + key, + ) + raise ValueError(f"Unrecognized datatype for {key}.") + # Dataype is a string/varchar equivalent + elif dtype == DataType.VARCHAR: + fields.append( + FieldSchema(key, DataType.VARCHAR, max_length=65_535) + ) + else: + fields.append(FieldSchema(key, dtype)) # Create the text field fields.append( @@ -442,12 +455,16 @@ class Milvus(VectorStore): self._vector_field: embeddings, } - # Collect the metadata into the insert dict. - if metadatas is not None: + if self._metadata_field is not None: for d in metadatas: - for key, value in d.items(): - if key in self.fields: - insert_dict.setdefault(key, []).append(value) + insert_dict.setdefault(self._metadata_field, []).append(d) + else: + # Collect the metadata into the insert dict. + if metadatas is not None: + for d in metadatas: + for key, value in d.items(): + if key in self.fields: + insert_dict.setdefault(key, []).append(value) # Total insert count vectors: list = insert_dict[self._vector_field] @@ -630,8 +647,8 @@ class Milvus(VectorStore): # Organize results. ret = [] for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + data = {x: result.entity.get(x) for x in output_fields} + doc = self._parse_document(data) pair = (doc, result.score) ret.append(pair) @@ -746,8 +763,8 @@ class Milvus(VectorStore): documents = [] scores = [] for result in res[0]: - meta = {x: result.entity.get(x) for x in output_fields} - doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + data = {x: result.entity.get(x) for x in output_fields} + doc = self._parse_document(data) documents.append(doc) scores.append(result.score) ids.append(result.id) @@ -826,3 +843,9 @@ class Milvus(VectorStore): ) vector_db.add_texts(texts=texts, metadatas=metadatas) return vector_db + + def _parse_document(self, data: dict) -> Document: + return Document( + page_content=data.pop(self._text_field), + metadata=data.pop(self._metadata_field) if self._metadata_field else data, + )