diff --git a/langchain/vectorstores/__init__.py b/langchain/vectorstores/__init__.py index 30743967..55b317cb 100644 --- a/langchain/vectorstores/__init__.py +++ b/langchain/vectorstores/__init__.py @@ -13,6 +13,7 @@ from langchain.vectorstores.pinecone import Pinecone from langchain.vectorstores.qdrant import Qdrant from langchain.vectorstores.supabase import SupabaseVectorStore from langchain.vectorstores.weaviate import Weaviate +from langchain.vectorstores.zilliz import Zilliz __all__ = [ "ElasticVectorSearch", @@ -22,6 +23,7 @@ __all__ = [ "Weaviate", "Qdrant", "Milvus", + "Zilliz", "Chroma", "OpenSearchVectorSearch", "AtlasDB", diff --git a/langchain/vectorstores/milvus.py b/langchain/vectorstores/milvus.py index a6a0b208..ab3b66de 100644 --- a/langchain/vectorstores/milvus.py +++ b/langchain/vectorstores/milvus.py @@ -1,8 +1,9 @@ """Wrapper around the Milvus vector database.""" from __future__ import annotations -import uuid -from typing import Any, Iterable, List, Optional, Tuple +import logging +from typing import Any, Iterable, List, Optional, Tuple, Union +from uuid import uuid4 import numpy as np @@ -11,6 +12,16 @@ from langchain.embeddings.base import Embeddings from langchain.vectorstores.base import VectorStore from langchain.vectorstores.utils import maximal_marginal_relevance +logger = logging.getLogger(__name__) + +DEFAULT_MILVUS_CONNECTION = { + "host": "localhost", + "port": "19530", + "user": "", + "password": "", + "secure": False, +} + class Milvus(VectorStore): """Wrapper around the Milvus vector database.""" @@ -18,153 +29,578 @@ class Milvus(VectorStore): def __init__( self, embedding_function: Embeddings, - connection_args: dict, - collection_name: str, - text_field: str, + collection_name: str = "LangChainCollection", + connection_args: Optional[dict[str, Any]] = None, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: Optional[bool] = False, ): """Initialize wrapper around the milvus vector database. In order to use this you need to have `pymilvus` installed and a - running Milvus instance. + running Milvus/Zilliz Cloud instance. See the following documentation for how to run a Milvus instance: https://milvus.io/docs/install_standalone-docker.md + If looking for a hosted Milvus, take a looka this documentation: + https://zilliz.com/cloud + + IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA. + + The connection args used for this class comes in the form of a dict, + here are a few of the options: + address (str): The actual address of Milvus + instance. Example address: "localhost:19530" + uri (str): The uri of Milvus instance. Example uri: + "http://randomwebsite:19530", + "tcp:foobarsite:19530", + "https://ok.s3.south.com:19530". + host (str): The host of Milvus instance. Default at "localhost", + PyMilvus will fill in the default host if only port is provided. + port (str/int): The port of Milvus instance. Default at 19530, PyMilvus + will fill in the default port if only host is provided. + user (str): Use which user to connect to Milvus instance. If user and + password are provided, we will add related header in every RPC call. + password (str): Required when user is provided. The password + corresponding to the user. + secure (bool): Default is false. If set to true, tls will be enabled. + client_key_path (str): If use tls two-way authentication, need to + write the client.key path. + client_pem_path (str): If use tls two-way authentication, need to + write the client.pem path. + ca_pem_path (str): If use tls two-way authentication, need to write + the ca.pem path. + server_pem_path (str): If use tls one-way authentication, need to + write the server.pem path. + server_name (str): If use tls, need to write the common name. + Args: - embedding_function (Embeddings): Function used to embed the text - connection_args (dict): Arguments for pymilvus connections.connect() - collection_name (str): The name of the collection to search. - text_field (str): The field in Milvus schema where the - original text is stored. + embedding_function (Embeddings): Function used to embed the text. + collection_name (str): Which Milvus collection to use. Defaults to + "LangChainCollection". + connection_args (Optional[dict[str, any]]): The arguments for connection to + Milvus/Zilliz instance. Defaults to DEFAULT_MILVUS_CONNECTION. + consistency_level (str): The consistency level to use for a collection. + Defaults to "Session". + index_params (Optional[dict]): Which index params to use. Defaults to + HNSW/AUTOINDEX depending on service. + search_params (Optional[dict]): Which search params to use. Defaults to + default of index. + drop_old (Optional[bool]): Whether to drop the current collection. Defaults + to False. """ try: - from pymilvus import Collection, DataType, connections + from pymilvus import Collection, utility except ImportError: raise ValueError( "Could not import pymilvus python package. " "Please install it with `pip install pymilvus`." ) - # Connecting to Milvus instance - if not connections.has_connection("default"): - connections.connect(**connection_args) + + # Default search params when one is not provided. + self.default_search_params = { + "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, + "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "HNSW": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}}, + "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, + "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, + "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, + "AUTOINDEX": {"metric_type": "L2", "params": {}}, + } + self.embedding_func = embedding_function self.collection_name = collection_name + self.index_params = index_params + self.search_params = search_params + self.consistency_level = consistency_level + + # In order for a collection to be compatible, pk needs to be auto'id and int + self._primary_field = "pk" + # In order for compatiblility, the text field will need to be called "text" + self._text_field = "text" + # In order for compatbility, the vector field needs to be called "vector" + self._vector_field = "vector" + self.fields: list[str] = [] + # Create the connection to the server + if connection_args is None: + connection_args = DEFAULT_MILVUS_CONNECTION + self.alias = self._create_connection_alias(connection_args) + self.col: Optional[Collection] = None + + # Grab the existing colection if it exists + if utility.has_collection(self.collection_name, using=self.alias): + self.col = Collection( + self.collection_name, + using=self.alias, + ) + # If need to drop old, drop it + if drop_old and isinstance(self.col, Collection): + self.col.drop() + self.col = None + + # Initialize the vector store + self._init() + + def _create_connection_alias(self, connection_args: dict) -> str: + """Create the connection to the Milvus server.""" + from pymilvus import MilvusException, connections + + # Grab the connection arguments that are used for checking existing connection + host: str = connection_args.get("host", None) + port: Union[str, int] = connection_args.get("port", None) + address: str = connection_args.get("address", None) + uri: str = connection_args.get("uri", None) + user = connection_args.get("user", None) + + # Order of use is host/port, uri, address + if host is not None and port is not None: + given_address = str(host) + ":" + str(port) + elif uri is not None: + given_address = uri.split("https://")[1] + elif address is not None: + given_address = address + else: + given_address = None + logger.debug("Missing standard address type for reuse atttempt") + + # User defaults to empty string when getting connection info + if user is not None: + tmp_user = user + else: + tmp_user = "" + + # If a valid address was given, then check if a connection exists + if given_address is not None: + for con in connections.list_connections(): + addr = connections.get_connection_addr(con[0]) + if ( + con[1] + and ("address" in addr) + and (addr["address"] == given_address) + and ("user" in addr) + and (addr["user"] == tmp_user) + ): + logger.debug("Using previous connection: %s", con[0]) + return con[0] + + # Generate a new connection if one doesnt exist + alias = uuid4().hex + try: + connections.connect(alias=alias, **connection_args) + logger.debug("Created new connection using: %s", alias) + return alias + except MilvusException as e: + logger.error("Failed to create new connection using: %s", alias) + raise e + + def _init( + self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None + ) -> None: + if embeddings is not None: + self._create_collection(embeddings, metadatas) + self._extract_fields() + self._create_index() + self._create_search_params() + self._load() + + def _create_collection( + self, embeddings: list, metadatas: Optional[list[dict]] = None + ) -> None: + from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusException, + ) + from pymilvus.orm.types import infer_dtype_bydata - self.text_field = text_field - self.auto_id = False - self.primary_field = None - self.vector_field = None - self.fields = [] - - self.col = Collection(self.collection_name) - schema = self.col.schema - - # Grabbing the fields for the existing collection. - for x in schema.fields: - self.fields.append(x.name) - if x.auto_id: - self.fields.remove(x.name) - if x.is_primary: - self.primary_field = x.name - if x.dtype == DataType.FLOAT_VECTOR or x.dtype == DataType.BINARY_VECTOR: - self.vector_field = x.name + # Determine embedding dim + dim = len(embeddings[0]) + fields = [] + # Determine metadata schema + if metadatas: + # Create FieldSchema for each entry in metadata. + for key, value in metadatas[0].items(): + # Infer the corresponding datatype of the metadata + dtype = infer_dtype_bydata(value) + # Datatype isnt compatible + if dtype == DataType.UNKNOWN or dtype == DataType.NONE: + logger.error( + "Failure to create collection, unrecognized dtype for key: %s", + key, + ) + raise ValueError(f"Unrecognized datatype for {key}.") + # Dataype is a string/varchar equivalent + elif dtype == DataType.VARCHAR: + fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535)) + else: + fields.append(FieldSchema(key, dtype)) - # Default search params when one is not provided. - self.index_params = { - "IVF_FLAT": {"params": {"nprobe": 10}}, - "IVF_SQ8": {"params": {"nprobe": 10}}, - "IVF_PQ": {"params": {"nprobe": 10}}, - "HNSW": {"params": {"ef": 10}}, - "RHNSW_FLAT": {"params": {"ef": 10}}, - "RHNSW_SQ": {"params": {"ef": 10}}, - "RHNSW_PQ": {"params": {"ef": 10}}, - "IVF_HNSW": {"params": {"nprobe": 10, "ef": 10}}, - "ANNOY": {"params": {"search_k": 10}}, - } + # Create the text field + fields.append( + FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535) + ) + # Create the primary key field + fields.append( + FieldSchema( + self._primary_field, DataType.INT64, is_primary=True, auto_id=True + ) + ) + # Create the vector field, supports binary or float vectors + fields.append( + FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim) + ) + + # Create the schema for the collection + schema = CollectionSchema(fields) + + # Create the collection + try: + self.col = Collection( + name=self.collection_name, + schema=schema, + consistency_level=self.consistency_level, + using=self.alias, + ) + except MilvusException as e: + logger.error( + "Failed to create collection: %s error: %s", self.collection_name, e + ) + raise e + + def _extract_fields(self) -> None: + """Grab the existing fields from the Collection""" + from pymilvus import Collection + + if isinstance(self.col, Collection): + schema = self.col.schema + for x in schema.fields: + self.fields.append(x.name) + # Since primary field is auto-id, no need to track it + self.fields.remove(self._primary_field) + + def _get_index(self) -> Optional[dict[str, Any]]: + """Return the vector index information if it exists""" + from pymilvus import Collection + + if isinstance(self.col, Collection): + for x in self.col.indexes: + if x.field_name == self._vector_field: + return x.to_dict() + return None + + def _create_index(self) -> None: + """Create a index on the collection""" + from pymilvus import Collection, MilvusException + + if isinstance(self.col, Collection) and self._get_index() is None: + try: + # If no index params, use a default HNSW based one + if self.index_params is None: + self.index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + + try: + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + + # If default did not work, most likely on Zilliz Cloud + except MilvusException: + # Use AUTOINDEX based index + self.index_params = { + "metric_type": "L2", + "index_type": "AUTOINDEX", + "params": {}, + } + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + logger.debug( + "Successfully created an index on collection: %s", + self.collection_name, + ) + + except MilvusException as e: + logger.error( + "Failed to create an index on collection: %s", self.collection_name + ) + raise e + + def _create_search_params(self) -> None: + """Generate search params based on the current index type""" + from pymilvus import Collection + + if isinstance(self.col, Collection) and self.search_params is None: + index = self._get_index() + if index is not None: + index_type: str = index["index_param"]["index_type"] + metric_type: str = index["index_param"]["metric_type"] + self.search_params = self.default_search_params[index_type] + self.search_params["metric_type"] = metric_type + + def _load(self) -> None: + """Load the collection if available.""" + from pymilvus import Collection + + if isinstance(self.col, Collection) and self._get_index() is not None: + self.col.load() def add_texts( self, texts: Iterable[str], metadatas: Optional[List[dict]] = None, - partition_name: Optional[str] = None, timeout: Optional[int] = None, + batch_size: int = 1000, **kwargs: Any, ) -> List[str]: """Insert text data into Milvus. - When using add_texts() it is assumed that a collecton has already - been made and indexed. If metadata is included, it is assumed that - it is ordered correctly to match the schema provided to the Collection - and that the embedding vector is the first schema field. + Inserting data when the collection has not be made yet will result + in creating a new Collection. The data of the first entity decides + the schema of the new collection, the dim is extracted from the first + embedding and the columns are decided by the first metadata dict. + Metada keys will need to be present for all inserted values. At + the moment there is no None equivalent in Milvus. Args: - texts (Iterable[str]): The text being embedded and inserted. - metadatas (Optional[List[dict]], optional): The metadata that - corresponds to each insert. Defaults to None. - partition_name (str, optional): The partition of the collection - to insert data into. Defaults to None. - timeout: specified timeout. + texts (Iterable[str]): The texts to embed, it is assumed + that they all fit in memory. + metadatas (Optional[List[dict]]): Metadata dicts attached to each of + the texts. Defaults to None. + timeout (Optional[int]): Timeout for each batch insert. Defaults + to None. + batch_size (int, optional): Batch size to use for insertion. + Defaults to 1000. + + Raises: + MilvusException: Failure to add texts Returns: List[str]: The resulting keys for each inserted element. """ - insert_dict: Any = {self.text_field: list(texts)} + from pymilvus import Collection, MilvusException + + texts = list(texts) + try: - insert_dict[self.vector_field] = self.embedding_func.embed_documents( - list(texts) - ) + embeddings = self.embedding_func.embed_documents(texts) except NotImplementedError: - insert_dict[self.vector_field] = [ - self.embedding_func.embed_query(x) for x in texts - ] + embeddings = [self.embedding_func.embed_query(x) for x in texts] + + if len(embeddings) == 0: + logger.debug("Nothing to insert, skipping.") + return [] + + # If the collection hasnt been initialized yet, perform all steps to do so + if not isinstance(self.col, Collection): + self._init(embeddings, metadatas) + + # Dict to hold all insert columns + insert_dict: dict[str, list] = { + self._text_field: texts, + self._vector_field: embeddings, + } + # Collect the metadata into the insert dict. - if len(self.fields) > 2 and metadatas is not None: + if metadatas is not None: for d in metadatas: for key, value in d.items(): if key in self.fields: insert_dict.setdefault(key, []).append(value) - # Convert dict to list of lists for insertion - insert_list = [insert_dict[x] for x in self.fields] - # Insert into the collection. - res = self.col.insert( - insert_list, partition_name=partition_name, timeout=timeout + + # Total insert count + vectors: list = insert_dict[self._vector_field] + total_count = len(vectors) + + pks: list[str] = [] + + assert isinstance(self.col, Collection) + for i in range(0, total_count, batch_size): + # Grab end index + end = min(i + batch_size, total_count) + # Convert dict to list of lists batch for insertion + insert_list = [insert_dict[x][i:end] for x in self.fields] + # Insert into the collection. + try: + res: Collection + res = self.col.insert(insert_list, timeout=timeout, **kwargs) + pks.extend(res.primary_keys) + except MilvusException as e: + logger.error( + "Failed to insert batch starting at entity: %s/%s", i, total_count + ) + raise e + return pks + + def similarity_search( + self, + query: str, + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + query (str): The text to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score( + query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return [doc for doc, _ in res] + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Document]: + """Perform a similarity search against the query string. + + Args: + embedding (List[float]): The embedding vector to search. + k (int, optional): How many results to return. Defaults to 4. + param (dict, optional): The search params for the index type. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Document]: Document results for search. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs ) - # Flush to make sure newly inserted is immediately searchable. - self.col.flush() - return res.primary_keys + return [doc for doc, _ in res] - def _worker_search( + def similarity_search_with_score( self, query: str, k: int = 4, param: Optional[dict] = None, expr: Optional[str] = None, - partition_names: Optional[List[str]] = None, - round_decimal: int = -1, timeout: Optional[int] = None, **kwargs: Any, - ) -> Tuple[List[float], List[Tuple[Document, Any, Any]]]: - # Load the collection into memory for searching. - self.col.load() - # Decide to use default params if not passed in. - if param is None: - index_type = self.col.indexes[0].params["index_type"] - param = self.index_params[index_type] + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + query (str): The text being searched. + k (int, optional): The amount of results ot return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[float], List[Tuple[Document, any, any]]: + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + # Embed the query text. - data = [self.embedding_func.embed_query(query)] + embedding = self.embedding_func.embed_query(query) + # Determine result metadata fields. output_fields = self.fields[:] - output_fields.remove(self.vector_field) + output_fields.remove(self._vector_field) + + res = self.similarity_search_with_score_by_vector( + embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs + ) + return res + + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + param: Optional[dict] = None, + expr: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Perform a search on a query string and return results with score. + + For more information about the search parameters, take a look at the pymilvus + documentation found here: + https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md + + Args: + embedding (List[float]): The embedding vector being searched. + k (int, optional): The amount of results ot return. Defaults to 4. + param (dict): The search params for the specified index. + Defaults to None. + expr (str, optional): Filtering expression. Defaults to None. + timeout (int, optional): How long to wait before timeout error. + Defaults to None. + kwargs: Collection.search() keyword arguments. + + Returns: + List[Tuple[Document, float]]: Result doc and score. + """ + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + # Perform the search. res = self.col.search( - data, - self.vector_field, - param, - k, + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=k, expr=expr, output_fields=output_fields, - partition_names=partition_names, - round_decimal=round_decimal, timeout=timeout, **kwargs, ) @@ -172,258 +608,187 @@ class Milvus(VectorStore): ret = [] for result in res[0]: meta = {x: result.entity.get(x) for x in output_fields} - ret.append( - ( - Document(page_content=meta.pop(self.text_field), metadata=meta), - result.distance, - result.id, - ) - ) + doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + pair = (doc, result.score) + ret.append(pair) - return data[0], ret + return ret - def similarity_search_with_score( + def max_marginal_relevance_search( self, query: str, k: int = 4, + fetch_k: int = 20, param: Optional[dict] = None, expr: Optional[str] = None, - partition_names: Optional[List[str]] = None, - round_decimal: int = -1, timeout: Optional[int] = None, **kwargs: Any, - ) -> List[Tuple[Document, float]]: - """Perform a search on a query string and return results. + ) -> List[Document]: + """Perform a search and return results that are reordered by MMR. Args: query (str): The text being searched. - k (int, optional): The amount of results ot return. Defaults to 4. + k (int, optional): How many results to give. Defaults to 4. + fetch_k (int, optional): Total results to select k from. + Defaults to 20. param (dict, optional): The search params for the specified index. Defaults to None. expr (str, optional): Filtering expression. Defaults to None. - partition_names (List[str], optional): Partitions to search through. + timeout (int, optional): How long to wait before timeout error. Defaults to None. - round_decimal (int, optional): Round the resulting distance. Defaults - to -1. - timeout (int, optional): Amount to wait before timeout error. Defaults - to None. kwargs: Collection.search() keyword arguments. + Returns: - List[float], List[Tuple[Document, any, any]]: search_embedding, - (Document, distance, primary_field) results. + List[Document]: Document results for search. """ - _, result = self._worker_search( - query, k, param, expr, partition_names, round_decimal, timeout, **kwargs + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + embedding = self.embedding_func.embed_query(query) + + return self.max_marginal_relevance_search_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + param=param, + expr=expr, + timeout=timeout, + **kwargs, ) - return [(x, y) for x, y, _ in result] - def max_marginal_relevance_search( + def max_marginal_relevance_search_by_vector( self, - query: str, + embedding: list[float], k: int = 4, fetch_k: int = 20, param: Optional[dict] = None, expr: Optional[str] = None, - partition_names: Optional[List[str]] = None, - round_decimal: int = -1, timeout: Optional[int] = None, **kwargs: Any, ) -> List[Document]: """Perform a search and return results that are reordered by MMR. Args: - query (str): The text being searched. + embedding (str): The embedding vector being searched. k (int, optional): How many results to give. Defaults to 4. fetch_k (int, optional): Total results to select k from. Defaults to 20. param (dict, optional): The search params for the specified index. Defaults to None. expr (str, optional): Filtering expression. Defaults to None. - partition_names (List[str], optional): What partitions to search. + timeout (int, optional): How long to wait before timeout error. Defaults to None. - round_decimal (int, optional): Round the resulting distance. Defaults - to -1. - timeout (int, optional): Amount to wait before timeout error. Defaults - to None. + kwargs: Collection.search() keyword arguments. Returns: List[Document]: Document results for search. """ - data, res = self._worker_search( - query, - fetch_k, - param, - expr, - partition_names, - round_decimal, - timeout, + if self.col is None: + logger.debug("No existing collection to search.") + return [] + + if param is None: + param = self.search_params + + # Determine result metadata fields. + output_fields = self.fields[:] + output_fields.remove(self._vector_field) + + # Perform the search. + res = self.col.search( + data=[embedding], + anns_field=self._vector_field, + param=param, + limit=fetch_k, + expr=expr, + output_fields=output_fields, + timeout=timeout, **kwargs, ) - # Extract result IDs. - ids = [x for _, _, x in res] - # Get the raw vectors from Milvus. + # Organize results. + ids = [] + documents = [] + scores = [] + for result in res[0]: + meta = {x: result.entity.get(x) for x in output_fields} + doc = Document(page_content=meta.pop(self._text_field), metadata=meta) + documents.append(doc) + scores.append(result.score) + ids.append(result.id) + vectors = self.col.query( - expr=f"{self.primary_field} in {ids}", - output_fields=[self.primary_field, self.vector_field], + expr=f"{self._primary_field} in {ids}", + output_fields=[self._primary_field, self._vector_field], + timeout=timeout, ) - # Reorganize the results from query to match result order. - vectors = {x[self.primary_field]: x[self.vector_field] for x in vectors} - search_embedding = data + # Reorganize the results from query to match search order. + vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors} + ordered_result_embeddings = [vectors[x] for x in ids] + # Get the new order of results. new_ordering = maximal_marginal_relevance( - np.array(search_embedding), ordered_result_embeddings, k=k + np.array(embedding), ordered_result_embeddings, k=k ) + # Reorder the values and return. ret = [] for x in new_ordering: + # Function can return -1 index if x == -1: break else: - ret.append(res[x][0]) + ret.append(documents[x]) return ret - def similarity_search( - self, - query: str, - k: int = 4, - param: Optional[dict] = None, - expr: Optional[str] = None, - partition_names: Optional[List[str]] = None, - round_decimal: int = -1, - timeout: Optional[int] = None, - **kwargs: Any, - ) -> List[Document]: - """Perform a similarity search against the query string. - - Args: - query (str): The text to search. - k (int, optional): How many results to return. Defaults to 4. - param (dict, optional): The search params for the index type. - Defaults to None. - expr (str, optional): Filtering expression. Defaults to None. - partition_names (List[str], optional): What partitions to search. - Defaults to None. - round_decimal (int, optional): What decimal point to round to. - Defaults to -1. - timeout (int, optional): How long to wait before timeout error. - Defaults to None. - - Returns: - List[Document]: Document results for search. - """ - _, docs_and_scores = self._worker_search( - query, k, param, expr, partition_names, round_decimal, timeout, **kwargs - ) - return [doc for doc, _, _ in docs_and_scores] - @classmethod def from_texts( cls, texts: List[str], embedding: Embeddings, metadatas: Optional[List[dict]] = None, + collection_name: str = "LangChainCollection", + connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: bool = False, **kwargs: Any, ) -> Milvus: """Create a Milvus collection, indexes it with HNSW, and insert data. Args: - texts (List[str]): Text to insert. - embedding (Embeddings): Embedding function to use. - metadatas (Optional[List[dict]], optional): Dict metatadata. + texts (List[str]): Text data. + embedding (Embeddings): Embedding function. + metadatas (Optional[List[dict]]): Metadata for each text if it exists. + Defaults to None. + collection_name (str, optional): Collection name to use. Defaults to + "LangChainCollection". + connection_args (dict[str, Any], optional): Connection args to use. Defaults + to DEFAULT_MILVUS_CONNECTION. + consistency_level (str, optional): Which consistency level to use. Defaults + to "Session". + index_params (Optional[dict], optional): Which index_params to use. Defaults + to None. + search_params (Optional[dict], optional): Which search params to use. Defaults to None. + drop_old (Optional[bool], optional): Whether to drop the collection with + that name if it exists. Defaults to False. Returns: - VectorStore: The Milvus vector store. + Milvus: Milvus Vector Store """ - try: - from pymilvus import ( - Collection, - CollectionSchema, - DataType, - FieldSchema, - connections, - ) - from pymilvus.orm.types import infer_dtype_bydata - except ImportError: - raise ValueError( - "Could not import pymilvus python package. " - "Please install it with `pip install pymilvus`." - ) - # Connect to Milvus instance - if not connections.has_connection("default"): - connections.connect(**kwargs.get("connection_args", {"port": 19530})) - # Determine embedding dim - embeddings = embedding.embed_query(texts[0]) - dim = len(embeddings) - # Generate unique names - primary_field = "c" + str(uuid.uuid4().hex) - vector_field = "c" + str(uuid.uuid4().hex) - text_field = "c" + str(uuid.uuid4().hex) - collection_name = "c" + str(uuid.uuid4().hex) - fields = [] - # Determine metadata schema - if metadatas: - # Check if all metadata keys line up - key = metadatas[0].keys() - for x in metadatas: - if key != x.keys(): - raise ValueError( - "Mismatched metadata. " - "Make sure all metadata has the same keys and datatype." - ) - # Create FieldSchema for each entry in singular metadata. - for key, value in metadatas[0].items(): - # Infer the corresponding datatype of the metadata - dtype = infer_dtype_bydata(value) - if dtype == DataType.UNKNOWN: - raise ValueError(f"Unrecognized datatype for {key}.") - elif dtype == DataType.VARCHAR: - # Find out max length text based metadata - max_length = 0 - for subvalues in metadatas: - max_length = max(max_length, len(subvalues[key])) - fields.append( - FieldSchema(key, DataType.VARCHAR, max_length=max_length + 1) - ) - else: - fields.append(FieldSchema(key, dtype)) - - # Find out max length of texts - max_length = 0 - for y in texts: - max_length = max(max_length, len(y)) - # Create the text field - fields.append( - FieldSchema(text_field, DataType.VARCHAR, max_length=max_length + 1) - ) - # Create the primary key field - fields.append( - FieldSchema(primary_field, DataType.INT64, is_primary=True, auto_id=True) - ) - # Create the vector field - fields.append(FieldSchema(vector_field, DataType.FLOAT_VECTOR, dim=dim)) - # Create the schema for the collection - schema = CollectionSchema(fields) - # Create the collection - collection = Collection(collection_name, schema) - # Index parameters for the collection - index = { - "index_type": "HNSW", - "metric_type": "L2", - "params": {"M": 8, "efConstruction": 64}, - } - # Create the index - collection.create_index(vector_field, index) - # Create the VectorStore - milvus = cls( - embedding, - kwargs.get("connection_args", {"port": 19530}), - collection_name, - text_field, + vector_db = cls( + embedding_function=embedding, + collection_name=collection_name, + connection_args=connection_args, + consistency_level=consistency_level, + index_params=index_params, + search_params=search_params, + drop_old=drop_old, + **kwargs, ) - # Add the texts. - milvus.add_texts(texts, metadatas) - - return milvus + vector_db.add_texts(texts=texts, metadatas=metadatas) + return vector_db diff --git a/langchain/vectorstores/zilliz.py b/langchain/vectorstores/zilliz.py new file mode 100644 index 00000000..13d165d6 --- /dev/null +++ b/langchain/vectorstores/zilliz.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import logging +from typing import Any, List, Optional + +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.milvus import Milvus + +logger = logging.getLogger(__name__) + + +class Zilliz(Milvus): + def _create_index(self) -> None: + """Create a index on the collection""" + from pymilvus import Collection, MilvusException + + if isinstance(self.col, Collection) and self._get_index() is None: + try: + # If no index params, use a default AutoIndex based one + if self.index_params is None: + self.index_params = { + "metric_type": "L2", + "index_type": "AUTOINDEX", + "params": {}, + } + + try: + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + + # If default did not work, most likely Milvus self-hosted + except MilvusException: + # Use HNSW based index + self.index_params = { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 8, "efConstruction": 64}, + } + self.col.create_index( + self._vector_field, + index_params=self.index_params, + using=self.alias, + ) + logger.debug( + "Successfully created an index on collection: %s", + self.collection_name, + ) + + except MilvusException as e: + logger.error( + "Failed to create an index on collection: %s", self.collection_name + ) + raise e + + @classmethod + def from_texts( + cls, + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = "LangChainCollection", + connection_args: dict[str, Any] = {}, + consistency_level: str = "Session", + index_params: Optional[dict] = None, + search_params: Optional[dict] = None, + drop_old: bool = False, + **kwargs: Any, + ) -> Zilliz: + """Create a Zilliz collection, indexes it with HNSW, and insert data. + + Args: + texts (List[str]): Text data. + embedding (Embeddings): Embedding function. + metadatas (Optional[List[dict]]): Metadata for each text if it exists. + Defaults to None. + collection_name (str, optional): Collection name to use. Defaults to + "LangChainCollection". + connection_args (dict[str, Any], optional): Connection args to use. Defaults + to DEFAULT_MILVUS_CONNECTION. + consistency_level (str, optional): Which consistency level to use. Defaults + to "Session". + index_params (Optional[dict], optional): Which index_params to use. + Defaults to None. + search_params (Optional[dict], optional): Which search params to use. + Defaults to None. + drop_old (Optional[bool], optional): Whether to drop the collection with + that name if it exists. Defaults to False. + + Returns: + Zilliz: Zilliz Vector Store + """ + vector_db = cls( + embedding_function=embedding, + collection_name=collection_name, + connection_args=connection_args, + consistency_level=consistency_level, + index_params=index_params, + search_params=search_params, + drop_old=drop_old, + **kwargs, + ) + vector_db.add_texts(texts=texts, metadatas=metadatas) + return vector_db diff --git a/tests/integration_tests/vectorstores/test_milvus.py b/tests/integration_tests/vectorstores/test_milvus.py index 063427e7..38db31d6 100644 --- a/tests/integration_tests/vectorstores/test_milvus.py +++ b/tests/integration_tests/vectorstores/test_milvus.py @@ -9,12 +9,15 @@ from tests.integration_tests.vectorstores.fake_embeddings import ( ) -def _milvus_from_texts(metadatas: Optional[List[dict]] = None) -> Milvus: +def _milvus_from_texts( + metadatas: Optional[List[dict]] = None, drop: bool = True +) -> Milvus: return Milvus.from_texts( fake_texts, FakeEmbeddings(), metadatas=metadatas, connection_args={"host": "127.0.0.1", "port": "19530"}, + drop_old=drop, ) @@ -51,3 +54,36 @@ def test_milvus_max_marginal_relevance_search() -> None: Document(page_content="foo", metadata={"page": 0}), Document(page_content="baz", metadata={"page": 2}), ] + + +def test_milvus_add_extra() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas) + + docsearch.add_texts(texts, metadatas) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +def test_milvus_no_drop() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _milvus_from_texts(metadatas=metadatas) + del docsearch + + docsearch = _milvus_from_texts(metadatas=metadatas, drop=False) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +# if __name__ == "__main__": +# test_milvus() +# test_milvus_with_score() +# test_milvus_max_marginal_relevance_search() +# test_milvus_add_extra() +# test_milvus_no_drop() diff --git a/tests/integration_tests/vectorstores/test_zilliz.py b/tests/integration_tests/vectorstores/test_zilliz.py new file mode 100644 index 00000000..5080e222 --- /dev/null +++ b/tests/integration_tests/vectorstores/test_zilliz.py @@ -0,0 +1,94 @@ +"""Test Zilliz functionality.""" +from typing import List, Optional + +from langchain.docstore.document import Document +from langchain.vectorstores import Zilliz +from tests.integration_tests.vectorstores.fake_embeddings import ( + FakeEmbeddings, + fake_texts, +) + + +def _zilliz_from_texts( + metadatas: Optional[List[dict]] = None, drop: bool = True +) -> Zilliz: + return Zilliz.from_texts( + fake_texts, + FakeEmbeddings(), + metadatas=metadatas, + connection_args={ + "uri": "", + "user": "", + "password": "", + "secure": True, + }, + drop_old=drop, + ) + + +def test_zilliz() -> None: + """Test end to end construction and search.""" + docsearch = _zilliz_from_texts() + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_zilliz_with_score() -> None: + """Test end to end construction and search with scores and IDs.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _zilliz_from_texts(metadatas=metadatas) + output = docsearch.similarity_search_with_score("foo", k=3) + docs = [o[0] for o in output] + scores = [o[1] for o in output] + assert docs == [ + Document(page_content="foo", metadata={"page": 0}), + Document(page_content="bar", metadata={"page": 1}), + Document(page_content="baz", metadata={"page": 2}), + ] + assert scores[0] < scores[1] < scores[2] + + +def test_zilliz_max_marginal_relevance_search() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _zilliz_from_texts(metadatas=metadatas) + output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3) + assert output == [ + Document(page_content="foo", metadata={"page": 0}), + Document(page_content="baz", metadata={"page": 2}), + ] + + +def test_zilliz_add_extra() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _zilliz_from_texts(metadatas=metadatas) + + docsearch.add_texts(texts, metadatas) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +def test_zilliz_no_drop() -> None: + """Test end to end construction and MRR search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _zilliz_from_texts(metadatas=metadatas) + del docsearch + + docsearch = _zilliz_from_texts(metadatas=metadatas, drop=False) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +# if __name__ == "__main__": +# test_zilliz() +# test_zilliz_with_score() +# test_zilliz_max_marginal_relevance_search() +# test_zilliz_add_extra() +# test_zilliz_no_drop()