milvus: mv to external repo (#26920)

Commit a8e1577f85 (parent 35f6393144)
Author: Erick Friis, 2024-09-30 17:38:30 -07:00, committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
30 changed files with 2 additions and 5088 deletions


@@ -1 +0,0 @@
__pycache__


@@ -1,21 +0,0 @@
MIT License
Copyright (c) 2023 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -1,59 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest $(TEST_FILE)
test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/milvus --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_milvus
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_milvus -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'


@@ -1,42 +1,3 @@
# langchain-milvus

This package has moved!
https://github.com/langchain-ai/langchain-milvus/tree/main/libs/milvus

This is a library integration with [Milvus](https://milvus.io/) and [Zilliz Cloud](https://zilliz.com/cloud).
## Installation
```bash
pip install -U langchain-milvus
```
## Milvus vector database
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/milvus/)
```python
from langchain_milvus import Milvus
```
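
A minimal end-to-end sketch, for orientation (the embedding class and the local `uri` are illustrative; any `Embeddings` implementation and any Milvus or Milvus Lite URI work):

```python
from langchain_milvus import Milvus
from langchain_openai import OpenAIEmbeddings  # illustrative; any Embeddings works

vector_store = Milvus(
    embedding_function=OpenAIEmbeddings(),
    # A local file URI runs Milvus Lite; point at a server URI in production.
    connection_args={"uri": "./milvus_demo.db"},
    collection_name="LangChainCollection",
)
vector_store.add_texts(["foo", "bar", "baz"])
docs = vector_store.similarity_search("foo", k=1)
```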
## Milvus hybrid search
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/milvus_hybrid_search/).
```python
from langchain_milvus import MilvusCollectionHybridSearchRetriever
```
## Zilliz Cloud vector database
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/zilliz/).
```python
from langchain_milvus import Zilliz
```
## Zilliz Cloud Pipeline Retriever
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/zilliz_cloud_pipeline/).
```python
from langchain_milvus import ZillizCloudPipelineRetriever
```


@@ -1,12 +0,0 @@
from langchain_milvus.retrievers import (
MilvusCollectionHybridSearchRetriever,
ZillizCloudPipelineRetriever,
)
from langchain_milvus.vectorstores import Milvus, Zilliz
__all__ = [
"Milvus",
"Zilliz",
"ZillizCloudPipelineRetriever",
"MilvusCollectionHybridSearchRetriever",
]


@@ -1,8 +0,0 @@
from langchain_milvus.retrievers.milvus_hybrid_search import (
MilvusCollectionHybridSearchRetriever,
)
from langchain_milvus.retrievers.zilliz_cloud_pipeline_retriever import (
ZillizCloudPipelineRetriever,
)
__all__ = ["ZillizCloudPipelineRetriever", "MilvusCollectionHybridSearchRetriever"]


@@ -1,161 +0,0 @@
from typing import Any, Dict, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.retrievers import BaseRetriever
from pymilvus import AnnSearchRequest, Collection
from pymilvus.client.abstract import BaseRanker, SearchResult # type: ignore
from langchain_milvus.utils.sparse import BaseSparseEmbedding
class MilvusCollectionHybridSearchRetriever(BaseRetriever):
"""Hybrid search retriever
that uses Milvus Collection to retrieve documents based on multiple fields.
For more information, please refer to:
https://milvus.io/docs/release_notes.md#Multi-Embedding---Hybrid-Search
"""
collection: Collection
"""Milvus Collection object."""
rerank: BaseRanker
"""Milvus ranker object. Such as WeightedRanker or RRFRanker."""
anns_fields: List[str]
"""The names of vector fields that are used for ANNS search."""
field_embeddings: List[Union[Embeddings, BaseSparseEmbedding]]
"""The embedding functions of each vector fields,
which can be either Embeddings or BaseSparseEmbedding."""
field_search_params: Optional[List[Dict]] = None
"""The search parameters of each vector fields.
If not specified, the default search parameters will be used."""
field_limits: Optional[List[int]] = None
"""Limit number of results for each ANNS field.
If not specified, the default top_k will be used."""
field_exprs: Optional[List[Optional[str]]] = None
"""The boolean expression for filtering the search results."""
top_k: int = 4
"""Final top-K number of documents to retrieve."""
text_field: str = "text"
"""The text field name,
which will be used as the `page_content` of a `Document` object."""
output_fields: Optional[List[str]] = None
"""Final output fields of the documents.
If not specified, all fields except the vector fields will be used as output fields,
which will be the `metadata` of a `Document` object."""
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
# If some parameters are not specified, set default values
if self.field_search_params is None:
default_search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}
self.field_search_params = [default_search_params] * len(self.anns_fields)
if self.field_limits is None:
self.field_limits = [self.top_k] * len(self.anns_fields)
if self.field_exprs is None:
self.field_exprs = [None] * len(self.anns_fields)
# Check the fields
self._validate_fields_num()
self.output_fields = self._get_output_fields()
self._validate_fields_name()
# Load collection
self.collection.load()
def _validate_fields_num(self) -> None:
assert (
len(self.anns_fields) >= 2
), "At least two fields are required for hybrid search."
lengths = [len(self.anns_fields)]
if self.field_limits is not None:
lengths.append(len(self.field_limits))
if self.field_exprs is not None:
lengths.append(len(self.field_exprs))
if not all(length == lengths[0] for length in lengths):
raise ValueError("All field-related lists must have the same length.")
if len(self.field_search_params) != len(self.anns_fields): # type: ignore[arg-type]
raise ValueError(
"field_search_params must have the same length as anns_fields."
)
def _validate_fields_name(self) -> None:
collection_fields = [x.name for x in self.collection.schema.fields]
for field in self.anns_fields:
assert (
field in collection_fields
), f"{field} is not a valid field in the collection."
assert (
self.text_field in collection_fields
), f"{self.text_field} is not a valid field in the collection."
for field in self.output_fields: # type: ignore[union-attr]
assert (
field in collection_fields
), f"{field} is not a valid field in the collection."
def _get_output_fields(self) -> List[str]:
if self.output_fields:
return self.output_fields
output_fields = [x.name for x in self.collection.schema.fields]
for field in self.anns_fields:
if field in output_fields:
output_fields.remove(field)
if self.text_field not in output_fields:
output_fields.append(self.text_field)
return output_fields
def _build_ann_search_requests(self, query: str) -> List[AnnSearchRequest]:
search_requests = []
for ann_field, embedding, param, limit, expr in zip(
self.anns_fields,
self.field_embeddings,
self.field_search_params, # type: ignore[arg-type]
self.field_limits, # type: ignore[arg-type]
self.field_exprs, # type: ignore[arg-type]
):
request = AnnSearchRequest(
data=[embedding.embed_query(query)],
anns_field=ann_field,
param=param,
limit=limit,
expr=expr,
)
search_requests.append(request)
return search_requests
def _parse_document(self, data: dict) -> Document:
return Document(
page_content=data.pop(self.text_field),
metadata=data,
)
def _process_search_result(
self, search_results: List[SearchResult]
) -> List[Document]:
documents = []
for result in search_results[0]:
data = {x: result.entity.get(x) for x in self.output_fields} # type: ignore[union-attr]
doc = self._parse_document(data)
documents.append(doc)
return documents
def _get_relevant_documents(
self,
query: str,
*,
run_manager: CallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
requests = self._build_ann_search_requests(query)
search_result = self.collection.hybrid_search(
requests, self.rerank, limit=self.top_k, output_fields=self.output_fields
)
documents = self._process_search_result(search_result)
return documents
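
For orientation, a hedged sketch of wiring up this retriever. It assumes a pymilvus connection is already established, that a collection named "my_collection" exists with vector fields "dense_vector" and "sparse_vector" (all placeholder names), and that `dense_emb` and `sparse_emb` are embedding objects you have built (an Embeddings model and, say, a BM25SparseEmbedding):

```python
from pymilvus import Collection, WeightedRanker

from langchain_milvus import MilvusCollectionHybridSearchRetriever

# Placeholder collection/field names; dense_emb and sparse_emb are assumed
# to be pre-built embedding objects.
retriever = MilvusCollectionHybridSearchRetriever(
    collection=Collection("my_collection"),
    rerank=WeightedRanker(0.5, 0.5),  # weight dense and sparse scores equally
    anns_fields=["dense_vector", "sparse_vector"],
    field_embeddings=[dense_emb, sparse_emb],
    top_k=4,
    text_field="text",
)
docs = retriever.invoke("What is Milvus?")
```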


@@ -1,215 +0,0 @@
from typing import Any, Dict, List, Optional
import requests
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
class ZillizCloudPipelineRetriever(BaseRetriever):
"""`Zilliz Cloud Pipeline` retriever.
Parameters:
pipeline_ids: A dictionary of pipeline ids.
Valid keys: "ingestion", "search", "deletion".
token: Zilliz Cloud's token. Defaults to "".
cloud_region: The region of Zilliz Cloud's cluster.
Defaults to 'gcp-us-west1'.
"""
pipeline_ids: Dict
token: str = ""
cloud_region: str = "gcp-us-west1"
def _get_relevant_documents(
self,
query: str,
top_k: int = 10,
offset: int = 0,
output_fields: List = [],
filter: str = "",
*,
run_manager: CallbackManagerForRetrieverRun,
) -> List[Document]:
"""
Get documents relevant to a query.
Args:
query: String to find relevant documents for
top_k: The number of results. Defaults to 10.
offset: The number of records to skip in the search result.
Defaults to 0.
output_fields: The extra fields to present in output.
filter: The Milvus expression to filter search results.
Defaults to "".
run_manager: The callbacks handler to use.
Returns:
List of relevant documents
"""
if "search" in self.pipeline_ids:
search_pipe_id = self.pipeline_ids.get("search")
else:
raise Exception(
"A search pipeline id must be provided in pipeline_ids to "
"get relevant documents."
)
domain = (
f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
)
headers = {
"Authorization": f"Bearer {self.token}",
"Accept": "application/json",
"Content-Type": "application/json",
}
url = f"{domain}/{search_pipe_id}/run"
params = {
"data": {"query_text": query},
"params": {
"limit": top_k,
"offset": offset,
"outputFields": output_fields,
"filter": filter,
},
}
response = requests.post(url, headers=headers, json=params)
if response.status_code != 200:
raise RuntimeError(response.text)
response_dict = response.json()
if response_dict["code"] != 200:
raise RuntimeError(response_dict)
response_data = response_dict["data"]
search_results = response_data["result"]
return [
Document(
page_content=result.pop("text")
if "text" in result
else result.pop("chunk_text"),
metadata=result,
)
for result in search_results
]
def add_texts(
self, texts: List[str], metadata: Optional[Dict[str, Any]] = None
) -> Dict:
"""
Add documents to store.
Only supported by a text ingestion pipeline in Zilliz Cloud.
Args:
texts: A list of text strings.
metadata: A key-value dictionary of metadata to be
inserted as preserved fields required by the ingestion pipeline.
Defaults to None.
"""
if "ingestion" in self.pipeline_ids:
ingest_pipe_id = self.pipeline_ids.get("ingestion")
else:
raise Exception(
"An ingestion pipeline id must be provided in pipeline_ids to"
" add documents."
)
domain = (
f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
)
headers = {
"Authorization": f"Bearer {self.token}",
"Accept": "application/json",
"Content-Type": "application/json",
}
url = f"{domain}/{ingeset_pipe_id}/run"
metadata = {} if metadata is None else metadata
params = {"data": {"text_list": texts}}
params["data"].update(metadata)
response = requests.post(url, headers=headers, json=params)
if response.status_code != 200:
raise Exception(response.text)
response_dict = response.json()
if response_dict["code"] != 200:
raise Exception(response_dict)
response_data = response_dict["data"]
return response_data
def add_doc_url(
self, doc_url: str, metadata: Optional[Dict[str, Any]] = None
) -> Dict:
"""
Add a document from a URL.
Only supported by a document ingestion pipeline in Zilliz Cloud.
Args:
doc_url: A document url.
metadata: A key-value dictionary of metadata to be
inserted as preserved fields required by the ingestion pipeline.
Defaults to None.
"""
if "ingestion" in self.pipeline_ids:
ingest_pipe_id = self.pipeline_ids.get("ingestion")
else:
raise Exception(
"An ingestion pipeline id must be provided in pipeline_ids to "
"add documents."
)
domain = (
f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
)
headers = {
"Authorization": f"Bearer {self.token}",
"Accept": "application/json",
"Content-Type": "application/json",
}
url = f"{domain}/{ingest_pipe_id}/run"
params = {"data": {"doc_url": doc_url}}
metadata = {} if metadata is None else metadata
params["data"].update(metadata)
response = requests.post(url, headers=headers, json=params)
if response.status_code != 200:
raise Exception(response.text)
response_dict = response.json()
if response_dict["code"] != 200:
raise Exception(response_dict)
response_data = response_dict["data"]
return response_data
def delete(self, key: str, value: Any) -> Dict:
"""
Delete documents. Only supported by a deletion pipeline in Zilliz Cloud.
Args:
key: The input name used to run the deletion pipeline.
value: The input value used to run the deletion pipeline.
"""
if "deletion" in self.pipeline_ids:
deletion_pipe_id = self.pipeline_ids.get("deletion")
else:
raise Exception(
"A deletion pipeline id must be provided in pipeline_ids to "
"add documents."
)
domain = (
f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
)
headers = {
"Authorization": f"Bearer {self.token}",
"Accept": "application/json",
"Content-Type": "application/json",
}
url = f"{domain}/{deletion_pipe_id}/run"
params = {"data": {key: value}}
response = requests.post(url, headers=headers, json=params)
if response.status_code != 200:
raise Exception(response.text)
response_dict = response.json()
if response_dict["code"] != 200:
raise Exception(response_dict)
response_data = response_dict["data"]
return response_data
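
A hedged usage sketch of this retriever; the pipeline ids and token below are placeholders for values obtained from the Zilliz Cloud console:

```python
from langchain_milvus import ZillizCloudPipelineRetriever

retriever = ZillizCloudPipelineRetriever(
    pipeline_ids={
        "ingestion": "pipe-ingestion-placeholder",
        "search": "pipe-search-placeholder",
        "deletion": "pipe-deletion-placeholder",
    },
    token="YOUR_ZILLIZ_API_KEY",  # placeholder API key
    cloud_region="gcp-us-west1",
)
retriever.add_texts(["Milvus is a vector database."])
docs = retriever.invoke("vector database")
```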


@@ -1,55 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List
from scipy.sparse import csr_array # type: ignore
class BaseSparseEmbedding(ABC):
"""Interface for Sparse embedding models.
You can inherit from it and implement your custom sparse embedding model.
"""
@abstractmethod
def embed_query(self, query: str) -> Dict[int, float]:
"""Embed query text."""
@abstractmethod
def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
"""Embed search docs."""
class BM25SparseEmbedding(BaseSparseEmbedding):
"""Sparse embedding model based on BM25.
This class uses the BM25 model from the pymilvus model library to implement sparse vector embedding.
It requires pymilvus[model] to be installed:
`pip install pymilvus[model]`
For more information please refer to:
https://milvus.io/docs/embed-with-bm25.md
"""
def __init__(self, corpus: List[str], language: str = "en"):
from pymilvus.model.sparse import BM25EmbeddingFunction # type: ignore
from pymilvus.model.sparse.bm25.tokenizers import ( # type: ignore
build_default_analyzer,
)
self.analyzer = build_default_analyzer(language=language)
self.bm25_ef = BM25EmbeddingFunction(self.analyzer, num_workers=1)
self.bm25_ef.fit(corpus)
def embed_query(self, text: str) -> Dict[int, float]:
return self._sparse_to_dict(self.bm25_ef.encode_queries([text]))
def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
sparse_arrays = self.bm25_ef.encode_documents(texts)
return [self._sparse_to_dict(sparse_array) for sparse_array in sparse_arrays]
def _sparse_to_dict(self, sparse_array: csr_array) -> Dict[int, float]:
row_indices, col_indices = sparse_array.nonzero()
non_zero_values = sparse_array.data
result_dict = {}
for col_index, value in zip(col_indices, non_zero_values):
result_dict[col_index] = value
return result_dict
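
A small sketch of how BM25SparseEmbedding is used, assuming `pymilvus[model]` (and the NLTK data it downloads) is available:

```python
from langchain_milvus.utils.sparse import BM25SparseEmbedding

corpus = [
    "milvus stores dense and sparse vectors",
    "zilliz cloud is a managed milvus service",
]
bm25 = BM25SparseEmbedding(corpus=corpus)  # fits BM25 statistics on the corpus

doc_vectors = bm25.embed_documents(corpus)  # List[Dict[int, float]]
query_vector = bm25.embed_query("milvus")   # Dict[int, float]
```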


@@ -1,7 +0,0 @@
from langchain_milvus.vectorstores.milvus import Milvus
from langchain_milvus.vectorstores.zilliz import Zilliz
__all__ = [
"Milvus",
"Zilliz",
]

File diff suppressed because it is too large.


@@ -1,197 +0,0 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Union
from langchain_core.embeddings import Embeddings
from langchain_milvus.utils.sparse import BaseSparseEmbedding
from langchain_milvus.vectorstores.milvus import Milvus
logger = logging.getLogger(__name__)
class Zilliz(Milvus):
"""`Zilliz` vector store.
You need to have `pymilvus` installed and a
running Zilliz database.
See the following documentation for how to run a Zilliz instance:
https://docs.zilliz.com/docs/create-cluster
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
Args:
embedding_function (Embeddings): Function used to embed the text.
collection_name (str): Which Zilliz collection to use. Defaults to
"LangChainCollection".
connection_args (Optional[dict[str, any]]): The connection args used for
this class come in the form of a dict.
consistency_level (str): The consistency level to use for a collection.
Defaults to "Session".
index_params (Optional[dict]): Which index params to use. Defaults to
HNSW/AUTOINDEX depending on service.
search_params (Optional[dict]): Which search params to use. Defaults to
default of index.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
auto_id (bool): Whether to enable auto id for the primary key. Defaults to False.
If False, you need to provide text ids (strings shorter than 65535 bytes).
If True, Milvus will generate unique integers as primary keys.
The connection args used for this class come in the form of a dict;
here are a few of the options:
address (str): The actual address of Zilliz
instance. Example address: "localhost:19530"
uri (str): The uri of Zilliz instance. Example uri:
"https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
host (str): The host of the Zilliz instance. Defaults to "localhost";
PyMilvus will fill in the default host if only the port is provided.
port (str/int): The port of the Zilliz instance. Defaults to 19530; PyMilvus
will fill in the default port if only the host is provided.
user (str): The user to connect to the Zilliz instance as. If user and
password are provided, the related header is added to every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
token (str): API key for serverless clusters; can be used in place of
user and password.
secure (bool): Defaults to false. If set to true, TLS will be enabled.
client_key_path (str): If using TLS two-way authentication, the
client.key path is required.
client_pem_path (str): If using TLS two-way authentication, the
client.pem path is required.
ca_pem_path (str): If using TLS two-way authentication, the
ca.pem path is required.
server_pem_path (str): If using TLS one-way authentication, the
server.pem path is required.
server_name (str): If using TLS, the common name is required.
Example:
.. code-block:: python
from langchain_milvus import Zilliz
from langchain_openai import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
# Connect to a Zilliz instance
milvus_store = Zilliz(
embedding_function = embedding,
collection_name = "LangChainCollection",
connection_args = {
"uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
"user": "temp",
"password": "temp",
"token": "temp", # API key as replacements for user and password
"secure": True
},
drop_old=True,
)
Raises:
ValueError: If the pymilvus python package is not installed.
"""
def _create_index(self) -> None:
"""Create a index on the collection"""
from pymilvus import Collection, MilvusException
if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default AutoIndex based one
if self.index_params is None:
self.index_params = {
"metric_type": "L2",
"index_type": "AUTOINDEX",
"params": {},
}
try:
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
# If default did not work, most likely Milvus self-hosted
except MilvusException:
# Use HNSW based index
self.index_params = {
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
logger.debug(
"Successfully created an index on collection: %s",
self.collection_name,
)
except MilvusException as e:
logger.error(
"Failed to create an index on collection: %s", self.collection_name
)
raise e
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Union[Embeddings, BaseSparseEmbedding],
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: Optional[Dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: bool = False,
*,
ids: Optional[List[str]] = None,
auto_id: bool = False,
**kwargs: Any,
) -> Zilliz:
"""Create a Zilliz collection, indexes it with HNSW, and insert data.
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
"LangChainCollection".
connection_args (dict[str, Any], optional): Connection args to use. Defaults
to DEFAULT_MILVUS_CONNECTION.
consistency_level (str, optional): Which consistency level to use. Defaults
to "Session".
index_params (Optional[dict], optional): Which index_params to use.
Defaults to None.
search_params (Optional[dict], optional): Which search params to use.
Defaults to None.
drop_old (Optional[bool], optional): Whether to drop the collection with
that name if it exists. Defaults to False.
ids (Optional[List[str]]): List of text ids.
auto_id (bool): Whether to enable auto id for the primary key. Defaults to
False. If False, you need to provide text ids (strings shorter than
65535 bytes). If True, Milvus will generate unique integers as primary keys.
Returns:
Zilliz: Zilliz Vector Store
"""
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
connection_args=connection_args or {},
consistency_level=consistency_level,
index_params=index_params,
search_params=search_params,
drop_old=drop_old,
auto_id=auto_id,
**kwargs,
)
vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
return vector_db
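
To make the classmethod concrete, a hedged `from_texts` sketch; the URI and token are placeholders for your Zilliz Cloud cluster credentials, and `embeddings` stands for any pre-built `Embeddings` object:

```python
from langchain_milvus import Zilliz

vector_store = Zilliz.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=embeddings,  # assumed pre-built Embeddings object
    connection_args={
        "uri": "https://<cluster-id>.api.gcp-us-west1.zillizcloud.com",  # placeholder
        "token": "YOUR_API_KEY",  # placeholder
    },
    drop_old=True,
)
docs = vector_store.similarity_search("foo", k=1)
```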

File diff suppressed because it is too large.


@@ -1,123 +0,0 @@
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-milvus"
version = "0.1.5"
description = "An integration package connecting Milvus and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.ruff]
select = ["E", "F", "I", "T201"]
[tool.mypy]
disallow_untyped_defs = "True"
[[tool.mypy.overrides]]
module = ["pymilvus"]
ignore_missing_imports = "True"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/milvus"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-milvus%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
pymilvus = "^2.4.3"
[[tool.poetry.dependencies.langchain-core]]
version = ">=0.2.38,<0.4"
python = ">=3.9"
[[tool.poetry.dependencies.langchain-core]]
version = ">=0.2.38,<0.3"
python = "<3.9"
[[tool.poetry.dependencies.scipy]]
version = "^1.7"
python = "<3.12"
[[tool.poetry.dependencies.scipy]]
version = "^1.9"
python = ">=3.12"
[tool.coverage.run]
omit = ["tests/*"]
[tool.pytest.ini_options]
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
markers = [
"requires: mark tests as requiring a specific library",
"asyncio: mark tests as requiring asyncio",
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.3.0"
freezegun = "^1.2.2"
pytest-mock = "^3.10.0"
syrupy = "^4.0.2"
pytest-watcher = "^0.3.4"
pytest-asyncio = "^0.21.1"
milvus_model = "^0.2.0"
[[tool.poetry.group.test.dependencies.langchain-core]]
path = "../../core"
develop = true
python = ">=3.9"
[[tool.poetry.group.test.dependencies.langchain-core]]
version = ">=0.2.38,<0.3"
python = "<3.9"
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"
[tool.poetry.group.test_integration.dependencies]
milvus_model = "^0.2.0"
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"
types-requests = "^2"
simsimd = "^5.0.0"
[[tool.poetry.group.typing.dependencies.langchain-core]]
path = "../../core"
develop = true
python = ">=3.9"
[[tool.poetry.group.typing.dependencies.langchain-core]]
version = ">=0.2.38,<0.3"
python = "<3.9"
[tool.poetry.group.dev.dependencies]
[[tool.poetry.group.dev.dependencies.langchain-core]]
path = "../../core"
develop = true
python = ">=3.9"
[[tool.poetry.group.dev.dependencies.langchain-core]]
version = ">=0.2.38,<0.3"
python = "<3.9"


@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)


@@ -1,17 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain or langchain_experimental
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi


@@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass


@@ -1,40 +0,0 @@
from typing import List
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
fake_texts = ["foo", "bar", "baz"]
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Return simple embeddings.
Embeddings encode each text as its index."""
return [[float(1.0)] * 9 + [float(i)] for i in range(len(texts))]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
return self.embed_documents(texts)
def embed_query(self, text: str) -> List[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].
Distance to each text will be that text's index,
as it was passed to embed_documents."""
return [float(1.0)] * 9 + [float(0.0)]
async def aembed_query(self, text: str) -> List[float]:
return self.embed_query(text)
def assert_docs_equal_without_pk(
docs1: List[Document], docs2: List[Document], pk_field: str = "pk"
) -> None:
"""Assert two lists of Documents are equal, ignoring the primary key field."""
assert len(docs1) == len(docs2)
for doc1, doc2 in zip(docs1, docs2):
assert doc1.page_content == doc2.page_content
doc1.metadata.pop(pk_field, None)
doc2.metadata.pop(pk_field, None)
assert doc1.metadata == doc2.metadata


@@ -1,430 +0,0 @@
"""Test Milvus functionality."""
import tempfile
from typing import Any, List, Optional
import pytest
from langchain_core.documents import Document
from langchain_milvus.utils.sparse import BM25SparseEmbedding
from langchain_milvus.vectorstores import Milvus
from tests.integration_tests.utils import (
FakeEmbeddings,
assert_docs_equal_without_pk,
fake_texts,
)
#
# To run this test properly, please start a Milvus server with the following command:
#
# ```shell
# wget https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh
# bash standalone_embed.sh start
# ```
#
# Here is the reference:
# https://milvus.io/docs/install_standalone-docker.md
#
@pytest.fixture
def temp_milvus_db() -> Any:
with tempfile.NamedTemporaryFile(suffix=".db") as temp_file:
yield temp_file.name
def _milvus_from_texts(
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
drop: bool = True,
db_path: str = "./milvus_demo.db",
**kwargs: Any,
) -> Milvus:
return Milvus.from_texts(
fake_texts,
FakeEmbeddings(),
metadatas=metadatas,
ids=ids,
# connection_args={"uri": "http://127.0.0.1:19530"},
connection_args={"uri": db_path},
drop_old=drop,
consistency_level="Strong",
**kwargs,
)
def _get_pks(expr: str, docsearch: Milvus) -> List[Any]:
return docsearch.get_pks(expr) # type: ignore[return-value]
def test_milvus(temp_milvus_db: Any) -> None:
"""Test end to end construction and search."""
docsearch = _milvus_from_texts(db_path=temp_milvus_db)
output = docsearch.similarity_search("foo", k=1)
assert_docs_equal_without_pk(output, [Document(page_content="foo")])
def test_milvus_vector_search(temp_milvus_db: Any) -> None:
"""Test end to end construction and search by vector."""
docsearch = _milvus_from_texts(db_path=temp_milvus_db)
output = docsearch.similarity_search_by_vector(
FakeEmbeddings().embed_query("foo"), k=1
)
assert_docs_equal_without_pk(output, [Document(page_content="foo")])
def test_milvus_with_metadata(temp_milvus_db: Any) -> None:
"""Test with metadata"""
docsearch = _milvus_from_texts(
metadatas=[{"label": "test"}] * len(fake_texts), db_path=temp_milvus_db
)
output = docsearch.similarity_search("foo", k=1)
assert_docs_equal_without_pk(
output, [Document(page_content="foo", metadata={"label": "test"})]
)
def test_milvus_with_id(temp_milvus_db: Any) -> None:
"""Test with ids"""
ids = ["id_" + str(i) for i in range(len(fake_texts))]
docsearch = _milvus_from_texts(ids=ids, db_path=temp_milvus_db)
output = docsearch.similarity_search("foo", k=1)
assert_docs_equal_without_pk(output, [Document(page_content="foo")])
output = docsearch.delete(ids=ids)
assert output.delete_count == len(fake_texts) # type: ignore[attr-defined]
ids = ["dup_id" for _ in fake_texts]
with pytest.raises(AssertionError):
    _milvus_from_texts(ids=ids, db_path=temp_milvus_db)
def test_milvus_with_score(temp_milvus_db: Any) -> None:
"""Test end to end construction and search with scores and IDs."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
output = docsearch.similarity_search_with_score("foo", k=3)
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert_docs_equal_without_pk(
docs,
[
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
],
)
assert scores[0] < scores[1] < scores[2]
def test_milvus_max_marginal_relevance_search(temp_milvus_db: Any) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert_docs_equal_without_pk(
output,
[
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
],
)
def test_milvus_max_marginal_relevance_search_with_dynamic_field(
temp_milvus_db: Any,
) -> None:
"""Test end to end construction and MRR search with enabling dynamic field."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(
metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db
)
output = docsearch.max_marginal_relevance_search("foo", k=2, fetch_k=3)
assert_docs_equal_without_pk(
output,
[
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
],
)
def test_milvus_add_extra(temp_milvus_db: Any) -> None:
"""Test end to end construction and MRR search."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
docsearch.add_texts(texts, metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
def test_milvus_no_drop(temp_milvus_db: Any) -> None:
"""Test construction without dropping old data."""
texts = ["foo", "bar", "baz"]
metadatas = [{"page": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
del docsearch
docsearch = _milvus_from_texts(
metadatas=metadatas, drop=False, db_path=temp_milvus_db
)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
def test_milvus_get_pks(temp_milvus_db: Any) -> None:
"""Test end to end construction and get pks with expr"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
expr = "id in [1,2]"
output = _get_pks(expr, docsearch)
assert len(output) == 2
def test_milvus_delete_entities(temp_milvus_db: Any) -> None:
"""Test end to end construction and delete entities"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
expr = "id in [1,2]"
pks = _get_pks(expr, docsearch)
result = docsearch.delete(pks)
assert result.delete_count == 2 # type: ignore[attr-defined]
def test_milvus_upsert_entities(temp_milvus_db: Any) -> None:
"""Test end to end construction and upsert entities"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
expr = "id in [1,2]"
pks = _get_pks(expr, docsearch)
documents = [
Document(page_content="test_1", metadata={"id": 1}),
Document(page_content="test_2", metadata={"id": 3}),
]
ids = docsearch.upsert(pks, documents)
assert len(ids) == 2 # type: ignore[arg-type]
def test_milvus_enable_dynamic_field(temp_milvus_db: Any) -> None:
"""Test end to end construction and enable dynamic field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(
metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db
)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
# When dynamic field is enabled, any new field data will be added to the collection.
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
}
def test_milvus_disable_dynamic_field(temp_milvus_db: Any) -> None:
"""Test end to end construction and disable dynamic field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(
metadatas=metadatas, enable_dynamic_field=False, db_path=temp_milvus_db
)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
# ["pk", "text", "vector", "id"]
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
"id",
}
# Try to add a new field "id_new"; since dynamic field is disabled and
# the collection's fields are fixed as ["pk", "text", "vector", "id"],
# the new field "id_new" will not be added.
new_metadatas = [{"id": i, "id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
for doc in output:
assert set(doc.metadata.keys()) == {"id", "pk"} # `id_new` is not added.
# When dynamic field is disabled,
# missing data for the created field "id" will raise an exception.
with pytest.raises(Exception):
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
def test_milvus_metadata_field(temp_milvus_db: Any) -> None:
"""Test end to end construction and use metadata field"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i} for i in range(len(texts))]
docsearch = _milvus_from_texts(
metadatas=metadatas, metadata_field="metadata", db_path=temp_milvus_db
)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
new_metadatas = [{"id_new": i} for i in range(len(texts))]
docsearch.add_texts(texts, new_metadatas)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 6
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
docsearch._metadata_field,
}
def test_milvus_enable_dynamic_field_with_partition_key(temp_milvus_db: Any) -> None:
"""
Test end to end construction with dynamic field enabled
and partition_key_field
"""
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i, "namespace": f"name_{i}"} for i in range(len(texts))]
docsearch = _milvus_from_texts(
metadatas=metadatas,
enable_dynamic_field=True,
partition_key_field="namespace",
db_path=temp_milvus_db,
)
# filter on a single namespace
output = docsearch.similarity_search("foo", k=10, expr="namespace == 'name_2'")
assert len(output) == 1
# without namespace filter
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 3
assert set(docsearch.fields) == {
docsearch._primary_field,
docsearch._text_field,
docsearch._vector_field,
docsearch._partition_key_field,
}
def test_milvus_sparse_embeddings() -> None:
texts = [
"In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers "
"a hidden world of clockwork machines and ancient magic, where a rebellion is "
"brewing against the tyrannical ruler of the land.",
"In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by "
"a mysterious organization to transport a valuable artifact across a war-torn "
"continent, but soon finds themselves pursued by assassins and rival factions.",
"In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers "
"she has the ability to enter people's dreams, but soon finds herself trapped "
"in a surreal world of nightmares and illusions, where the boundaries between "
"reality and fantasy blur.",
]
try:
sparse_embedding_func = BM25SparseEmbedding(corpus=texts)
except LookupError:
import nltk # type: ignore[import]
nltk.download("punkt_tab")
sparse_embedding_func = BM25SparseEmbedding(corpus=texts)
with tempfile.NamedTemporaryFile(suffix=".db") as temp_db:
docsearch = Milvus.from_texts(
embedding=sparse_embedding_func,
texts=texts,
connection_args={"uri": temp_db.name},
drop_old=True,
)
output = docsearch.similarity_search("Pilgrim", k=1)
assert "Pilgrim" in output[0].page_content
def test_milvus_array_field(temp_milvus_db: Any) -> None:
"""Manually specify metadata schema, including an array_field.
For more information about array data type and filtering, please refer to
https://milvus.io/docs/array_data_type.md
"""
from pymilvus import DataType
texts = ["foo", "bar", "baz"]
metadatas = [{"id": i, "array_field": [i, i + 1, i + 2]} for i in range(len(texts))]
# Manually specify metadata schema, including an array_field.
# If some fields are not specified, Milvus will automatically infer their schemas.
docsearch = _milvus_from_texts(
metadatas=metadatas,
metadata_schema={
"array_field": {
"dtype": DataType.ARRAY,
"kwargs": {"element_type": DataType.INT64, "max_capacity": 50},
},
# "id": {
# "dtype": DataType.INT64,
# }
},
db_path=temp_milvus_db,
)
output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2")
assert len(output) == 2
output = docsearch.similarity_search(
"foo", k=10, expr="ARRAY_CONTAINS(array_field, 3)"
)
assert len(output) == 2
# If we use enable_dynamic_field,
# there is no need to manually specify metadata schema.
docsearch = _milvus_from_texts(
enable_dynamic_field=True,
metadatas=metadatas,
db_path=temp_milvus_db,
)
output = docsearch.similarity_search("foo", k=10, expr="array_field[0] < 2")
assert len(output) == 2
output = docsearch.similarity_search(
"foo", k=10, expr="ARRAY_CONTAINS(array_field, 3)"
)
assert len(output) == 2
# if __name__ == "__main__":
# test_milvus()
# test_milvus_vector_search()
# test_milvus_with_metadata()
# test_milvus_with_id()
# test_milvus_with_score()
# test_milvus_max_marginal_relevance_search()
# test_milvus_max_marginal_relevance_search_with_dynamic_field()
# test_milvus_add_extra()
# test_milvus_no_drop()
# test_milvus_get_pks()
# test_milvus_delete_entities()
# test_milvus_upsert_entities()
# test_milvus_enable_dynamic_field()
# test_milvus_disable_dynamic_field()
# test_milvus_metadata_field()
# test_milvus_enable_dynamic_field_with_partition_key()
# test_milvus_sparse_embeddings()
# test_milvus_array_field()


@@ -1,12 +0,0 @@
from langchain_milvus import __all__
EXPECTED_ALL = [
"Milvus",
"MilvusCollectionHybridSearchRetriever",
"Zilliz",
"ZillizCloudPipelineRetriever",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)


@@ -1,17 +0,0 @@
import os
from tempfile import TemporaryDirectory
from unittest.mock import Mock
from langchain_milvus.vectorstores import Milvus
def test_initialization() -> None:
"""Test integration milvus initialization."""
embedding = Mock()
with TemporaryDirectory() as tmp_dir:
Milvus(
embedding_function=embedding,
connection_args={
"uri": os.path.join(tmp_dir, "milvus.db"),
},
)