langchain/libs/partners/milvus/langchain_milvus/vectorstores/zilliz.py
Ohad Eytan b5d670498f
partners/milvus: allow creating a vectorstore with sparse embeddings (#25284)
# Description
Milvus (and `pymilvus`) recently added the option to use [sparse
vectors](https://milvus.io/docs/sparse_vector.md#Sparse-Vector) with
appropriate search methods (e.g., `SPARSE_INVERTED_INDEX`) and
embeddings (e.g., `BM25`, `SPLADE`).

This PR allows creating a vector store using LangChain's `Milvus` class,
setting the matching vector field type to `DataType.SPARSE_FLOAT_VECTOR`
and the default index type to `SPARSE_INVERTED_INDEX`.

It only extends functionality and is backward compatible.

## Note
I am also interested in extending the Milvus class further to support
multi-vector search (aka hybrid search), and would be happy to discuss
that. See
[here](https://github.com/langchain-ai/langchain/discussions/19955),
[here](https://github.com/langchain-ai/langchain/pull/20375), and
[here](https://github.com/langchain-ai/langchain/discussions/22886)
for similar needs.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2024-08-30 02:30:23 +00:00

198 lines
8.2 KiB
Python

from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings
from langchain_milvus.utils.sparse import BaseSparseEmbedding
from langchain_milvus.vectorstores.milvus import Milvus
# Module-level logger named after this module, so its records slot into the
# standard ``logging`` hierarchy and can be configured by the application.
logger: logging.Logger = logging.getLogger(__name__)
class Zilliz(Milvus):
    """`Zilliz` vector store.

    You need to have `pymilvus` installed and a
    running Zilliz database.

    See the following documentation for how to run a Zilliz instance:
    https://docs.zilliz.com/docs/create-cluster

    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.

    Args:
        embedding_function (Union[Embeddings, BaseSparseEmbedding]): Dense or
            sparse embedding function used to embed the text.
        collection_name (str): Which Zilliz collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, Any]]): The connection args used for
            this class comes in the form of a dict.
        consistency_level (str): The consistency level to use for a collection.
            Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. If None, an
            AUTOINDEX index is tried first (L2 metric for dense embeddings, IP
            for sparse ones). If the server rejects it, HNSW (dense) or
            SPARSE_INVERTED_INDEX (sparse) is used instead. Explicitly supplied
            params are used as-is and are never replaced by a fallback.
        search_params (Optional[dict]): Which search params to use. Defaults to
            default of index.
        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
            to False.
        auto_id (bool): Whether to enable auto id for primary key. Defaults to False.
            If False, you need to provide text ids (string less than 65535 bytes).
            If True, Milvus will generate unique integers as primary keys.

    The connection args used for this class comes in the form of a dict,
    here are a few of the options:
        address (str): The actual address of Zilliz
            instance. Example address: "localhost:19530"
        uri (str): The uri of Zilliz instance. Example uri:
            "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
        host (str): The host of Zilliz instance. Default at "localhost",
            PyMilvus will fill in the default host if only port is provided.
        port (str/int): The port of Zilliz instance. Default at 19530, PyMilvus
            will fill in the default port if only host is provided.
        user (str): Use which user to connect to Zilliz instance. If user and
            password are provided, we will add related header in every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.
        token (str): API key, for serverless clusters which can be used as
            replacements for user and password.
        secure (bool): Default is false. If set to true, tls will be enabled.
        client_key_path (str): If use tls two-way authentication, need to
            write the client.key path.
        client_pem_path (str): If use tls two-way authentication, need to
            write the client.pem path.
        ca_pem_path (str): If use tls two-way authentication, need to write
            the ca.pem path.
        server_pem_path (str): If use tls one-way authentication, need to
            write the server.pem path.
        server_name (str): If use tls, need to write the common name.

    Example:
        .. code-block:: python

            from langchain_milvus import Zilliz
            from langchain_openai import OpenAIEmbeddings

            embedding = OpenAIEmbeddings()
            # Connect to a Zilliz instance
            zilliz_store = Zilliz(
                embedding_function=embedding,
                collection_name="LangChainCollection",
                connection_args={
                    "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
                    "user": "temp",
                    "password": "temp",
                    "token": "temp",  # API key as replacements for user and password
                    "secure": True,
                },
                drop_old=True,
            )

    Raises:
        ValueError: If the pymilvus python package is not installed.
    """

    def _create_index(self) -> None:
        """Create an index on the vector field if the collection has none yet.

        Nothing happens when ``self.col`` is not a ``Collection`` or an index
        already exists.

        If ``self.index_params`` is set, it is used verbatim. A server-side
        rejection is logged and re-raised; it is NOT silently replaced by
        another index type or metric.

        If ``self.index_params`` is None, a default is chosen and stored back
        on ``self.index_params``:

        * dense embeddings: L2 ``AUTOINDEX``, falling back to L2 ``HNSW``;
        * sparse embeddings (``BaseSparseEmbedding``): IP ``AUTOINDEX``,
          falling back to IP ``SPARSE_INVERTED_INDEX``. Sparse vector fields
          cannot use L2 or HNSW.

        The fallback is tried only if the server rejects ``AUTOINDEX``, which
        most likely means a self-hosted Milvus.

        Raises:
            MilvusException: If the index could not be created.
        """
        from pymilvus import Collection, MilvusException

        col = self.col
        # ``and``/``or`` short-circuit: ``_get_index`` is only consulted when
        # there is a real Collection, exactly as before.
        if not isinstance(col, Collection) or self._get_index() is not None:
            return

        def create(index_params: Dict[str, Any]) -> None:
            # Store the params before attempting creation so that
            # ``self.index_params`` reflects the configuration actually tried.
            self.index_params = index_params
            col.create_index(
                self._vector_field,
                index_params=index_params,
                using=self.alias,
            )

        try:
            if self.index_params is not None:
                # Explicit user choice: no fallback substitution.
                create(self.index_params)
            else:
                # NOTE(review): assumes the parent ``Milvus`` keeps the embedding
                # function on ``embedding_func``; getattr keeps this safe if not.
                is_sparse = isinstance(
                    getattr(self, "embedding_func", None), BaseSparseEmbedding
                )
                if is_sparse:
                    autoindex: Dict[str, Any] = {
                        "metric_type": "IP",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }
                    fallback: Dict[str, Any] = {
                        "metric_type": "IP",
                        "index_type": "SPARSE_INVERTED_INDEX",
                        "params": {"drop_ratio_build": 0.2},
                    }
                else:
                    autoindex = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }
                    fallback = {
                        "metric_type": "L2",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }
                try:
                    create(autoindex)
                # If default did not work, most likely Milvus self-hosted
                except MilvusException:
                    create(fallback)

            logger.debug(
                "Successfully created an index on collection: %s",
                self.collection_name,
            )
        except MilvusException:
            logger.error(
                "Failed to create an index on collection: %s", self.collection_name
            )
            raise

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Union[Embeddings, BaseSparseEmbedding],
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        *,
        ids: Optional[List[str]] = None,
        auto_id: bool = False,
        **kwargs: Any,
    ) -> Zilliz:
        """Create a Zilliz collection, index it, and insert the texts.

        Args:
            texts (List[str]): Text data.
            embedding (Union[Embeddings, BaseSparseEmbedding]): Dense or sparse
                embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            collection_name (str, optional): Collection name to use. Defaults to
                "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to use.
                If None, an empty dict is passed to the constructor.
            consistency_level (str, optional): Which consistency level to use. Defaults
                to "Session".
            index_params (Optional[dict], optional): Which index_params to use.
                Defaults to None, in which case a default index is chosen when
                the index is created.
            search_params (Optional[dict], optional): Which search params to use.
                Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the collection with
                that name if it exists. Defaults to False.
            ids (Optional[List[str]]): List of text ids.
            auto_id (bool): Whether to enable auto id for primary key. Defaults to
                False. If False, you need to provide text ids (string less than 65535
                bytes). If True, Milvus will generate unique integers as primary keys.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            Zilliz: Zilliz Vector Store
        """
        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args or {},
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            auto_id=auto_id,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return vector_db