mirror of
https://github.com/hwchase17/langchain
synced 2024-11-13 19:10:52 +00:00
milvus: mv to external repo (#26920)
This commit is contained in:
parent
35f6393144
commit
a8e1577f85
1
libs/partners/milvus/.gitignore
vendored
1
libs/partners/milvus/.gitignore
vendored
@ -1 +0,0 @@
|
||||
__pycache__
|
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,59 +0,0 @@
|
||||
# Every target in this file is a command, not a file on disk, so all of them
# are declared phony. Keep this list in sync with the rules below.
.PHONY: all help test tests integration_test integration_tests test_watch \
        lint lint_diff lint_package lint_tests format format_diff \
        spell_check spell_fix check_imports

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
# `?=` lets callers override it: `make test TEST_FILE=tests/unit_tests/test_x.py`.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

test_watch:
	poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)


######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
# Only the files changed relative to master; may legitimately expand to nothing,
# which is why every recipe line below is guarded by an emptiness check.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/milvus --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_milvus
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
	# The subshell is required: `a || b && c` groups as `(a || b) && c` in sh,
	# which would run mypy with no targets whenever PYTHON_FILES is empty.
	[ "$(PYTHON_FILES)" = "" ] || (mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE))

format format_diff:
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES)
	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

# The package sources are listed as prerequisites only so that `$^` expands to
# them; the target itself never produces a file.
check_imports: $(shell find langchain_milvus -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports                - check imports'
	@echo 'format                       - run code formatters'
	@echo 'format_diff                  - run code formatters on files changed vs master'
	@echo 'lint                         - run linters'
	@echo 'lint_diff                    - run linters on files changed vs master'
	@echo 'spell_check                  - run codespell'
	@echo 'spell_fix                    - run codespell and write fixes'
	@echo 'test                         - run unit tests'
	@echo 'tests                        - run unit tests'
	@echo 'integration_tests            - run integration tests'
	@echo 'test_watch                   - re-run tests on file changes'
	@echo 'test TEST_FILE=<test_file>   - run all tests in file'
|
@ -1,42 +1,3 @@
|
||||
# langchain-milvus
|
||||
This package has moved!
|
||||
|
||||
This is a library integration with [Milvus](https://milvus.io/) and [Zilliz Cloud](https://zilliz.com/cloud).
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -U langchain-milvus
|
||||
```
|
||||
|
||||
## Milvus vector database
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/milvus/)
|
||||
|
||||
```python
|
||||
from langchain_milvus import Milvus
|
||||
```
|
||||
|
||||
## Milvus hybrid search
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/milvus_hybrid_search/).
|
||||
|
||||
```python
|
||||
from langchain_milvus import MilvusCollectionHybridSearchRetriever
|
||||
```
|
||||
|
||||
|
||||
## Zilliz Cloud vector database
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/vectorstores/zilliz/).
|
||||
|
||||
```python
|
||||
from langchain_milvus import Zilliz
|
||||
```
|
||||
|
||||
## Zilliz Cloud Pipeline Retriever
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/retrievers/zilliz_cloud_pipeline/).
|
||||
|
||||
```python
|
||||
from langchain_milvus import ZillizCloudPipelineRetriever
|
||||
```
|
||||
https://github.com/langchain-ai/langchain-milvus/tree/main/libs/milvus
|
||||
|
@ -1,12 +0,0 @@
|
||||
from langchain_milvus.retrievers import (
|
||||
MilvusCollectionHybridSearchRetriever,
|
||||
ZillizCloudPipelineRetriever,
|
||||
)
|
||||
from langchain_milvus.vectorstores import Milvus, Zilliz
|
||||
|
||||
__all__ = [
|
||||
"Milvus",
|
||||
"Zilliz",
|
||||
"ZillizCloudPipelineRetriever",
|
||||
"MilvusCollectionHybridSearchRetriever",
|
||||
]
|
@ -1,8 +0,0 @@
|
||||
from langchain_milvus.retrievers.milvus_hybrid_search import (
|
||||
MilvusCollectionHybridSearchRetriever,
|
||||
)
|
||||
from langchain_milvus.retrievers.zilliz_cloud_pipeline_retriever import (
|
||||
ZillizCloudPipelineRetriever,
|
||||
)
|
||||
|
||||
__all__ = ["ZillizCloudPipelineRetriever", "MilvusCollectionHybridSearchRetriever"]
|
@ -1,161 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from pymilvus import AnnSearchRequest, Collection
|
||||
from pymilvus.client.abstract import BaseRanker, SearchResult # type: ignore
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
|
||||
|
||||
class MilvusCollectionHybridSearchRetriever(BaseRetriever):
    """Hybrid search retriever
    that uses Milvus Collection to retrieve documents based on multiple fields.

    One ``AnnSearchRequest`` is built per entry of ``anns_fields`` (using the
    embedding, search params, limit and filter expression at the same index),
    the requests are sent together through ``Collection.hybrid_search`` and the
    hits are merged by ``rerank``.

    For more information, please refer to:
    https://milvus.io/docs/release_notes.md#Multi-Embedding---Hybrid-Search
    """

    collection: Collection
    """Milvus Collection object."""
    rerank: BaseRanker
    """Milvus ranker object. Such as WeightedRanker or RRFRanker."""
    anns_fields: List[str]
    """The names of vector fields that are used for ANNS search."""
    field_embeddings: List[Union[Embeddings, BaseSparseEmbedding]]
    """The embedding functions of each vector fields,
    which can be either Embeddings or BaseSparseEmbedding."""
    field_search_params: Optional[List[Dict]] = None
    """The search parameters of each vector fields.
    If not specified, the default search parameters will be used."""
    field_limits: Optional[List[int]] = None
    """Limit number of results for each ANNS field.
    If not specified, the default top_k will be used."""
    field_exprs: Optional[List[Optional[str]]] = None
    """The boolean expression for filtering the search results."""
    top_k: int = 4
    """Final top-K number of documents to retrieve."""
    text_field: str = "text"
    """The text field name,
    which will be used as the `page_content` of a `Document` object."""
    output_fields: Optional[List[str]] = None
    """Final output fields of the documents.
    If not specified, all fields except the vector fields will be used as output fields,
    which will be the `metadata` of a `Document` object."""

    def __init__(self, **kwargs: Any):
        """Fill in per-field defaults, validate the configuration and load
        the collection.

        Raises:
            AssertionError: If fewer than two ANNS fields are given, or a
                configured field name is missing from the collection schema.
            ValueError: If the per-field lists do not all have the same
                length as ``anns_fields``.
        """
        super().__init__(**kwargs)

        field_count = len(self.anns_fields)

        # If some parameters are not specified, set default values
        if self.field_search_params is None:
            # A fresh dict per field: repeating one dict with `[d] * n` would
            # make every field share (and mutate) the same object.
            self.field_search_params = [
                {"metric_type": "L2", "params": {"nprobe": 10}}
                for _ in range(field_count)
            ]
        if self.field_limits is None:
            self.field_limits = [self.top_k] * field_count
        if self.field_exprs is None:
            self.field_exprs = [None] * field_count

        # Check the fields
        self._validate_fields_num()
        self.output_fields = self._get_output_fields()
        self._validate_fields_name()

        # Load collection
        self.collection.load()

    def _validate_fields_num(self) -> None:
        """Check that every per-field list lines up with ``anns_fields``.

        ``_build_ann_search_requests`` zips these lists together, so any
        length mismatch would otherwise silently drop search requests.
        """
        assert (
            len(self.anns_fields) >= 2
        ), "At least two fields are required for hybrid search."

        expected = len(self.anns_fields)
        # field_embeddings is mandatory, the other two may still be None when
        # this method is called on its own.
        lengths = [len(self.field_embeddings)]
        if self.field_limits is not None:
            lengths.append(len(self.field_limits))
        if self.field_exprs is not None:
            lengths.append(len(self.field_exprs))

        if any(length != expected for length in lengths):
            raise ValueError("All field-related lists must have the same length.")

        if len(self.field_search_params) != expected:  # type: ignore[arg-type]
            raise ValueError(
                "field_search_params must have the same length as anns_fields."
            )

    def _validate_fields_name(self) -> None:
        """Check that every configured field name exists in the collection schema."""
        collection_fields = [x.name for x in self.collection.schema.fields]
        for field in self.anns_fields:
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."
        assert (
            self.text_field in collection_fields
        ), f"{self.text_field} is not a valid field in the collection."
        for field in self.output_fields:  # type: ignore[union-attr]
            assert (
                field in collection_fields
            ), f"{field} is not a valid field in the collection."

    def _get_output_fields(self) -> List[str]:
        """Return the fields to request from Milvus.

        An explicit, non-empty ``output_fields`` is returned unchanged.
        Otherwise every schema field except the ANNS vector fields is used,
        and ``text_field`` is appended if it is not already included.
        """
        if self.output_fields:
            return self.output_fields
        output_fields = [x.name for x in self.collection.schema.fields]
        for field in self.anns_fields:
            if field in output_fields:
                output_fields.remove(field)
        if self.text_field not in output_fields:
            output_fields.append(self.text_field)
        return output_fields

    def _build_ann_search_requests(self, query: str) -> List[AnnSearchRequest]:
        """Embed ``query`` once per ANNS field and wrap each embedding in a request."""
        search_requests = []
        for ann_field, embedding, param, limit, expr in zip(
            self.anns_fields,
            self.field_embeddings,
            self.field_search_params,  # type: ignore[arg-type]
            self.field_limits,  # type: ignore[arg-type]
            self.field_exprs,  # type: ignore[arg-type]
        ):
            request = AnnSearchRequest(
                data=[embedding.embed_query(query)],
                anns_field=ann_field,
                param=param,
                limit=limit,
                expr=expr,
            )
            search_requests.append(request)
        return search_requests

    def _parse_document(self, data: dict) -> Document:
        """Turn one hit's field dict into a ``Document``.

        ``text_field`` is popped out of ``data`` (mutating it) to become the
        page content; whatever remains becomes the metadata.
        """
        return Document(
            page_content=data.pop(self.text_field),
            metadata=data,
        )

    def _process_search_result(
        self, search_results: List[SearchResult]
    ) -> List[Document]:
        """Convert the hits of the first result set into documents."""
        documents = []
        # Only one query is issued per call, so only the first entry is read.
        for result in search_results[0]:
            data = {x: result.entity.get(x) for x in self.output_fields}  # type: ignore[union-attr]
            doc = self._parse_document(data)
            documents.append(doc)
        return documents

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> List[Document]:
        """Run the hybrid search for ``query`` and return up to ``top_k`` documents."""
        requests = self._build_ann_search_requests(query)
        search_result = self.collection.hybrid_search(
            requests, self.rerank, limit=self.top_k, output_fields=self.output_fields
        )
        documents = self._process_search_result(search_result)
        return documents
|
@ -1,215 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
|
||||
|
||||
class ZillizCloudPipelineRetriever(BaseRetriever):
    """`Zilliz Cloud Pipeline` retriever.

    Every public operation posts a JSON payload to the ``/run`` endpoint of
    one of the configured pipelines and returns the ``data`` member of the
    JSON response.

    Parameters:
        pipeline_ids: A dictionary of pipeline ids.
            Valid keys: "ingestion", "search", "deletion".
        token: Zilliz Cloud's token. Defaults to "".
        cloud_region: The region of Zilliz Cloud's cluster.
            Defaults to 'gcp-us-west1'.
    """

    pipeline_ids: Dict
    token: str = ""
    cloud_region: str = "gcp-us-west1"

    def _run_pipeline(
        self, pipeline_key: str, payload: Dict[str, Any], missing_message: str
    ) -> Any:
        """Post ``payload`` to the pipeline registered under ``pipeline_key``.

        Args:
            pipeline_key: Key into ``pipeline_ids``
                ("ingestion", "search" or "deletion").
            payload: JSON-serialisable request body.
            missing_message: Error text used when ``pipeline_key`` is absent.

        Returns:
            The ``data`` member of the JSON response.

        Raises:
            ValueError: If ``pipeline_key`` is not present in ``pipeline_ids``.
            RuntimeError: If the HTTP status or the ``code`` member of the
                response body is not 200.
        """
        if pipeline_key not in self.pipeline_ids:
            raise ValueError(missing_message)
        pipeline_id = self.pipeline_ids.get(pipeline_key)

        domain = (
            f"https://controller.api.{self.cloud_region}.zillizcloud.com/v1/pipelines"
        )
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        url = f"{domain}/{pipeline_id}/run"

        response = requests.post(url, headers=headers, json=payload)
        if response.status_code != 200:
            raise RuntimeError(response.text)
        response_dict = response.json()
        # The service reports application-level failures in the body as well.
        if response_dict["code"] != 200:
            raise RuntimeError(response_dict)
        return response_dict["data"]

    def _get_relevant_documents(
        self,
        query: str,
        top_k: int = 10,
        offset: int = 0,
        output_fields: Optional[List] = None,
        filter: str = "",
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            top_k: The number of results. Defaults to 10.
            offset: The number of records to skip in the search result.
                Defaults to 0.
            output_fields: The extra fields to present in output.
                Defaults to None, which is sent as an empty list.
            filter: The Milvus expression to filter search results.
                Defaults to "".
            run_manager: The callbacks handler to use.

        Returns:
            List of relevant documents

        Raises:
            ValueError: If no "search" pipeline id is configured.
            RuntimeError: If the pipeline call fails.
        """
        params = {
            "data": {"query_text": query},
            "params": {
                "limit": top_k,
                "offset": offset,
                "outputFields": [] if output_fields is None else output_fields,
                "filter": filter,
            },
        }
        response_data = self._run_pipeline(
            "search",
            params,
            "A search pipeline id must be provided in pipeline_ids to "
            "get relevant documents.",
        )
        search_results = response_data["result"]
        return [
            Document(
                # The page content is read from "text" when present and from
                # "chunk_text" otherwise; the remaining keys become metadata.
                page_content=result.pop("text")
                if "text" in result
                else result.pop("chunk_text"),
                metadata=result,
            )
            for result in search_results
        ]

    def add_texts(
        self, texts: List[str], metadata: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Add documents to store.
        Only supported by a text ingestion pipeline in Zilliz Cloud.

        Args:
            texts: A list of text strings.
            metadata: A key-value dictionary of metadata will
                be inserted as preserved fields required by ingestion pipeline.
                Defaults to None.

        Returns:
            The ``data`` member of the pipeline response.

        Raises:
            ValueError: If no "ingestion" pipeline id is configured.
            RuntimeError: If the pipeline call fails.
        """
        data: Dict[str, Any] = {"text_list": texts}
        data.update({} if metadata is None else metadata)
        return self._run_pipeline(
            "ingestion",
            {"data": data},
            "An ingestion pipeline id must be provided in pipeline_ids to"
            " add documents.",
        )

    def add_doc_url(
        self, doc_url: str, metadata: Optional[Dict[str, Any]] = None
    ) -> Dict:
        """
        Add a document from url.
        Only supported by a document ingestion pipeline in Zilliz Cloud.

        Args:
            doc_url: A document url.
            metadata: A key-value dictionary of metadata will
                be inserted as preserved fields required by ingestion pipeline.
                Defaults to None.

        Returns:
            The ``data`` member of the pipeline response.

        Raises:
            ValueError: If no "ingestion" pipeline id is configured.
            RuntimeError: If the pipeline call fails.
        """
        data: Dict[str, Any] = {"doc_url": doc_url}
        data.update({} if metadata is None else metadata)
        return self._run_pipeline(
            "ingestion",
            {"data": data},
            "An ingestion pipeline id must be provided in pipeline_ids to "
            "add documents.",
        )

    def delete(self, key: str, value: Any) -> Dict:
        """
        Delete documents. Only supported by a deletion pipeline in Zilliz Cloud.

        Args:
            key: input name to run the deletion pipeline
            value: input value to run deletion pipeline

        Returns:
            The ``data`` member of the pipeline response.

        Raises:
            ValueError: If no "deletion" pipeline id is configured.
            RuntimeError: If the pipeline call fails.
        """
        return self._run_pipeline(
            "deletion",
            {"data": {key: value}},
            "A deletion pipeline id must be provided in pipeline_ids to "
            "delete documents.",
        )
|
@ -1,55 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
from scipy.sparse import csr_array # type: ignore
|
||||
|
||||
|
||||
class BaseSparseEmbedding(ABC):
    """Abstract contract that every sparse embedding model must satisfy.

    Subclass it and provide both methods to plug in a custom sparse
    embedding model. A sparse embedding is expressed as a mapping from
    dimension index to weight.
    """

    @abstractmethod
    def embed_query(self, query: str) -> Dict[int, float]:
        """Return the sparse embedding of a single query string."""
        ...

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return one sparse embedding per text, in input order."""
        ...
|
||||
|
||||
|
||||
class BM25SparseEmbedding(BaseSparseEmbedding):
    """Sparse embedding model based on BM25.

    This class uses the BM25 model in Milvus model to implement sparse vector embedding.
    This model requires pymilvus[model] to be installed.
    `pip install pymilvus[model]`
    For more information please refer to:
    https://milvus.io/docs/embed-with-bm25.md
    """

    def __init__(self, corpus: List[str], language: str = "en"):
        """Build the default analyzer for ``language`` and fit BM25 on ``corpus``.

        Args:
            corpus: Texts the BM25 embedding function is fitted on.
            language: Language passed to ``build_default_analyzer``.
                Defaults to "en".
        """
        # Imported lazily so the module can be imported without the optional
        # pymilvus[model] extra installed.
        from pymilvus.model.sparse import BM25EmbeddingFunction  # type: ignore
        from pymilvus.model.sparse.bm25.tokenizers import (  # type: ignore
            build_default_analyzer,
        )

        self.analyzer = build_default_analyzer(language=language)
        self.bm25_ef = BM25EmbeddingFunction(self.analyzer, num_workers=1)
        self.bm25_ef.fit(corpus)

    def embed_query(self, text: str) -> Dict[int, float]:
        """Encode ``text`` as a query and return its sparse embedding."""
        return self._sparse_to_dict(self.bm25_ef.encode_queries([text]))

    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Encode ``texts`` as documents and return one sparse embedding each."""
        sparse_arrays = self.bm25_ef.encode_documents(texts)
        return [self._sparse_to_dict(sparse_array) for sparse_array in sparse_arrays]

    def _sparse_to_dict(self, sparse_array: csr_array) -> Dict[int, float]:
        """Convert a sparse array into a ``{column index: value}`` mapping.

        Column indices and values are read from the same COO view so that each
        value stays paired with its own column. Pairing ``nonzero()`` (which
        filters out explicitly stored zeros) with ``.data`` (which keeps them)
        shifts the values whenever the array stores an explicit zero.
        Reading ``.col`` also avoids unpacking ``nonzero()`` into exactly two
        index arrays, which assumes a 2-D input.

        Explicitly stored zeros are skipped. If a column occurs more than once
        the last occurrence wins.
        """
        coo = sparse_array.tocoo()
        return {
            col_index: value
            for col_index, value in zip(coo.col, coo.data)
            if value != 0
        }
|
@ -1,7 +0,0 @@
|
||||
from langchain_milvus.vectorstores.milvus import Milvus
|
||||
from langchain_milvus.vectorstores.zilliz import Zilliz
|
||||
|
||||
__all__ = [
|
||||
"Milvus",
|
||||
"Zilliz",
|
||||
]
|
File diff suppressed because it is too large
Load Diff
@ -1,197 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_milvus.utils.sparse import BaseSparseEmbedding
|
||||
from langchain_milvus.vectorstores.milvus import Milvus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Zilliz(Milvus):
    """`Zilliz` vector store.

    You need to have `pymilvus` installed and a
    running Zilliz database.

    See the following documentation for how to run a Zilliz instance:
    https://docs.zilliz.com/docs/create-cluster


    IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.

    Args:
        embedding_function (Embeddings): Function used to embed the text.
        collection_name (str): Which Zilliz collection to use. Defaults to
            "LangChainCollection".
        connection_args (Optional[dict[str, any]]): The connection args used for
            this class comes in the form of a dict.
        consistency_level (str): The consistency level to use for a collection.
            Defaults to "Session".
        index_params (Optional[dict]): Which index params to use. Defaults to
            HNSW/AUTOINDEX depending on service.
        search_params (Optional[dict]): Which search params to use. Defaults to
            default of index.
        drop_old (Optional[bool]): Whether to drop the current collection. Defaults
            to False.
        auto_id (bool): Whether to enable auto id for primary key. Defaults to False.
            If False, you need to provide text ids (string less than 65535 bytes).
            If True, Milvus will generate unique integers as primary keys.

    The connection args used for this class comes in the form of a dict,
    here are a few of the options:
        address (str): The actual address of Zilliz
            instance. Example address: "localhost:19530"
        uri (str): The uri of Zilliz instance. Example uri:
            "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
        host (str): The host of Zilliz instance. Default at "localhost",
            PyMilvus will fill in the default host if only port is provided.
        port (str/int): The port of Zilliz instance. Default at 19530, PyMilvus
            will fill in the default port if only host is provided.
        user (str): Use which user to connect to Zilliz instance. If user and
            password are provided, we will add related header in every RPC call.
        password (str): Required when user is provided. The password
            corresponding to the user.
        token (str): API key, for serverless clusters which can be used as
            replacements for user and password.
        secure (bool): Default is false. If set to true, tls will be enabled.
        client_key_path (str): If use tls two-way authentication, need to
            write the client.key path.
        client_pem_path (str): If use tls two-way authentication, need to
            write the client.pem path.
        ca_pem_path (str): If use tls two-way authentication, need to write
            the ca.pem path.
        server_pem_path (str): If use tls one-way authentication, need to
            write the server.pem path.
        server_name (str): If use tls, need to write the common name.

    Example:
        .. code-block:: python

            from langchain_milvus import Zilliz

            # `embedding` is any Embeddings implementation.
            # Connect to a Zilliz instance
            zilliz_store = Zilliz(
                embedding_function=embedding,
                collection_name="LangChainCollection",
                connection_args={
                    "uri": "https://in03-ba4234asae.api.gcp-us-west1.zillizcloud.com",
                    "user": "temp",
                    "password": "temp",
                    "token": "temp",  # API key as replacements for user and password
                    "secure": True,
                },
                drop_old=True,
            )

    Raises:
        ValueError: If the pymilvus python package is not installed.
    """

    def _create_index(self) -> None:
        """Create an index on the collection.

        Nothing happens unless ``self.col`` is a ``Collection`` that has no
        index yet. When ``index_params`` is unset an AUTOINDEX/L2 default is
        used. If index creation raises ``MilvusException``, ``index_params``
        is replaced by an HNSW/L2 configuration and creation is retried once.

        Raises:
            MilvusException: If the HNSW retry fails as well.
        """
        from pymilvus import Collection, MilvusException

        if isinstance(self.col, Collection) and self._get_index() is None:
            try:
                # If no index params, use a default AutoIndex based one
                if self.index_params is None:
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "AUTOINDEX",
                        "params": {},
                    }

                try:
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )

                # If default did not work, most likely Milvus self-hosted
                except MilvusException:
                    # Use HNSW based index
                    self.index_params = {
                        "metric_type": "L2",
                        "index_type": "HNSW",
                        "params": {"M": 8, "efConstruction": 64},
                    }
                    self.col.create_index(
                        self._vector_field,
                        index_params=self.index_params,
                        using=self.alias,
                    )
                logger.debug(
                    "Successfully created an index on collection: %s",
                    self.collection_name,
                )

            except MilvusException:
                logger.error(
                    "Failed to create an index on collection: %s", self.collection_name
                )
                # Bare raise keeps the original traceback intact.
                raise

    @classmethod
    def from_texts(
        cls,
        texts: List[str],
        embedding: Union[Embeddings, BaseSparseEmbedding],
        metadatas: Optional[List[dict]] = None,
        collection_name: str = "LangChainCollection",
        connection_args: Optional[Dict[str, Any]] = None,
        consistency_level: str = "Session",
        index_params: Optional[dict] = None,
        search_params: Optional[dict] = None,
        drop_old: bool = False,
        *,
        ids: Optional[List[str]] = None,
        auto_id: bool = False,
        **kwargs: Any,
    ) -> Zilliz:
        """Create a Zilliz vector store and insert the given texts into it.

        Args:
            texts (List[str]): Text data.
            embedding (Union[Embeddings, BaseSparseEmbedding]): Embedding function.
            metadatas (Optional[List[dict]]): Metadata for each text if it exists.
                Defaults to None.
            collection_name (str, optional): Collection name to use. Defaults to
                "LangChainCollection".
            connection_args (dict[str, Any], optional): Connection args to use.
                Defaults to None, in which case an empty dict is passed to the
                constructor.
            consistency_level (str, optional): Which consistency level to use. Defaults
                to "Session".
            index_params (Optional[dict], optional): Which index_params to use.
                Defaults to None.
            search_params (Optional[dict], optional): Which search params to use.
                Defaults to None.
            drop_old (Optional[bool], optional): Whether to drop the collection with
                that name if it exists. Defaults to False.
            ids (Optional[List[str]]): List of text ids.
            auto_id (bool): Whether to enable auto id for primary key. Defaults to
                False. If False, you need to provide text ids (string less than 65535
                bytes). If True, Milvus will generate unique integers as primary keys.
            **kwargs: Forwarded unchanged to the class constructor.

        Returns:
            Zilliz: Zilliz Vector Store
        """
        vector_db = cls(
            embedding_function=embedding,
            collection_name=collection_name,
            connection_args=connection_args or {},
            consistency_level=consistency_level,
            index_params=index_params,
            search_params=search_params,
            drop_old=drop_old,
            auto_id=auto_id,
            **kwargs,
        )
        vector_db.add_texts(texts=texts, metadatas=metadatas, ids=ids)
        return vector_db
|
2297
libs/partners/milvus/poetry.lock
generated
2297
libs/partners/milvus/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,123 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry]
|
||||
name = "langchain-milvus"
|
||||
version = "0.1.5"
|
||||
description = "An integration package connecting Milvus and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.ruff]
|
||||
select = ["E", "F", "I", "T201"]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["pymilvus"]
|
||||
ignore_missing_imports = "True"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/milvus"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22langchain-milvus%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
pymilvus = "^2.4.3"
|
||||
|
||||
[[tool.poetry.dependencies.langchain-core]]
|
||||
version = ">=0.2.38,<0.4"
|
||||
python = ">=3.9"
|
||||
|
||||
[[tool.poetry.dependencies.langchain-core]]
|
||||
version = ">=0.2.38,<0.3"
|
||||
python = "<3.9"
|
||||
|
||||
[[tool.poetry.dependencies.scipy]]
|
||||
version = "^1.7"
|
||||
python = "<3.12"
|
||||
|
||||
[[tool.poetry.dependencies.scipy]]
|
||||
version = "^1.9"
|
||||
python = ">=3.12"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
milvus_model = "^0.2.0"
|
||||
|
||||
[[tool.poetry.group.test.dependencies.langchain-core]]
|
||||
path = "../../core"
|
||||
develop = true
|
||||
python = ">=3.9"
|
||||
|
||||
[[tool.poetry.group.test.dependencies.langchain-core]]
|
||||
version = ">=0.2.38,<0.3"
|
||||
python = "<3.9"
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
milvus_model = "^0.2.0"
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
types-requests = "^2"
|
||||
simsimd = "^5.0.0"
|
||||
|
||||
[[tool.poetry.group.typing.dependencies.langchain-core]]
|
||||
path = "../../core"
|
||||
develop = true
|
||||
python = ">=3.9"
|
||||
|
||||
[[tool.poetry.group.typing.dependencies.langchain-core]]
|
||||
version = ">=0.2.38,<0.3"
|
||||
python = "<3.9"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
[[tool.poetry.group.dev.dependencies.langchain-core]]
|
||||
path = "../../core"
|
||||
develop = true
|
||||
python = ">=3.9"
|
||||
|
||||
[[tool.poetry.group.dev.dependencies.langchain-core]]
|
||||
version = ">=0.2.38,<0.3"
|
||||
python = "<3.9"
|
@ -1,17 +0,0 @@
|
||||
"""Smoke-import the Python source files named on the command line.

Every file is executed as a module. For each file that raises, the file name
and the traceback are printed; the process exits with status 1 if any file
failed and 0 otherwise.
"""
import importlib.util
import sys
import traceback
from importlib.machinery import SourceFileLoader
from types import ModuleType


def load_source_file(path: str, module_name: str = "x") -> ModuleType:
    """Execute the source file at ``path`` in a brand-new module and return it.

    ``SourceFileLoader.load_module()`` is deprecated, and when called repeatedly
    with the same module name it re-executes into the module already present in
    ``sys.modules``, so names defined by one file leak into the next. A fresh
    module is created for every call instead.

    Args:
        path: Path of the Python source file to execute.
        module_name: Name under which the module is registered in
            ``sys.modules`` while it executes. Defaults to "x".

    Returns:
        The executed module.

    Raises:
        ImportError: If no import spec can be built for ``path``.
        Exception: Whatever the file itself raises while executing.
    """
    loader = SourceFileLoader(module_name, path)
    spec = importlib.util.spec_from_loader(module_name, loader)
    if spec is None:
        raise ImportError(f"Cannot build an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing: code that looks itself up through
    # sys.modules at import time expects the entry to exist.
    sys.modules[module_name] = module
    try:
        loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        sys.modules.pop(module_name, None)
        raise
    return module


if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            load_source_file(file)
        except Exception:
            has_failure = True
            print(file)  # noqa: T201
            traceback.print_exc()
            print()  # noqa: T201

    sys.exit(1 if has_failure else 0)
|
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
#
# Fail when any file under the current directory matches a forbidden import:
# a line starting with `from langchain.` or `from langchain_experimental.`.
# `git grep` prints the matching lines; the script exits 1 if either pattern
# matched anywhere and 0 otherwise.

set -eu

# Number of forbidden patterns that produced at least one match.
errors=0

# Each pattern is searched separately so that matches for both are printed.
# A non-matching `git grep` is the `if` condition, so `set -e` does not abort.
for forbidden in '^from langchain\.' '^from langchain_experimental\.'; do
    if git --no-pager grep "$forbidden" .; then
        errors=$((errors+1))
    fi
done

# Translate the counter into the exit status.
if [ "$errors" -gt 0 ]; then
    exit 1
fi
exit 0
|
@ -1,7 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
def test_placeholder() -> None:
    """No-op test carrying the ``compile`` marker.

    Used for compiling integration tests without running any real tests.
    """
|
@ -1,40 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
fake_texts = ["foo", "bar", "baz"]
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
    """Deterministic 10-dimensional embeddings used as a test double."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text as nine ``1.0`` components followed by its index.

        The text content is ignored; only the position within ``texts``
        determines the vector.
        """
        vectors: List[List[float]] = []
        for position, _ in enumerate(texts):
            vectors.append([1.0] * 9 + [float(position)])
        return vectors

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant; delegates to ``embed_documents``."""
        return self.embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the same constant vector for every query.

        The vector equals the one ``embed_documents`` produces for the text at
        index 0: nine ``1.0`` components followed by ``0.0``.
        """
        constant_part = [1.0] * 9
        return constant_part + [0.0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant; delegates to ``embed_query``."""
        return self.embed_query(text)
|
||||
|
||||
|
||||
def assert_docs_equal_without_pk(
    docs1: List[Document], docs2: List[Document], pk_field: str = "pk"
) -> None:
    """Assert two lists of Documents are equal, ignoring the primary key field.

    Documents are compared pairwise, in order, on ``page_content`` and on
    ``metadata`` with ``pk_field`` left out. The documents passed in are not
    modified.

    Args:
        docs1: First list of documents.
        docs2: Second list of documents.
        pk_field: Metadata key excluded from the comparison. Defaults to "pk".

    Raises:
        AssertionError: If the lists differ in length, or a pair differs in
            page content or in its remaining metadata.
    """
    assert len(docs1) == len(docs2)
    for doc1, doc2 in zip(docs1, docs2):
        assert doc1.page_content == doc2.page_content
        # Compare filtered copies instead of popping the key, so the callers'
        # documents still carry their primary key after the check.
        metadata1 = {k: v for k, v in doc1.metadata.items() if k != pk_field}
        metadata2 = {k: v for k, v in doc2.metadata.items() if k != pk_field}
        assert metadata1 == metadata2
|
@ -1,430 +0,0 @@
|
||||
"""Test Milvus functionality."""
|
||||
|
||||
import tempfile
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_milvus.utils.sparse import BM25SparseEmbedding
|
||||
from langchain_milvus.vectorstores import Milvus
|
||||
from tests.integration_tests.utils import (
|
||||
FakeEmbeddings,
|
||||
assert_docs_equal_without_pk,
|
||||
fake_texts,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# To run this test properly, please start a Milvus server with the following command:
|
||||
#
|
||||
# ```shell
|
||||
# wget https://raw.githubusercontent.com/milvus-io/milvus/master/scripts/standalone_embed.sh
|
||||
# bash standalone_embed.sh start
|
||||
# ```
|
||||
#
|
||||
# Here is the reference:
|
||||
# https://milvus.io/docs/install_standalone-docker.md
|
||||
#
|
||||
@pytest.fixture
def temp_milvus_db() -> Any:
    """Yield the path of a temporary ``.db`` file for the duration of a test.

    The file comes from ``tempfile.NamedTemporaryFile`` and stays open until
    the test using the fixture has finished.
    """
    temp_file = tempfile.NamedTemporaryFile(suffix=".db")
    with temp_file:
        yield temp_file.name
|
||||
|
||||
|
||||
def _milvus_from_texts(
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    drop: bool = True,
    db_path: str = "./milvus_demo.db",
    **kwargs: Any,
) -> Milvus:
    """Build a ``Milvus`` store from ``fake_texts`` using ``FakeEmbeddings``.

    Args:
        metadatas: Optional metadata dicts, forwarded to ``Milvus.from_texts``.
        ids: Optional ids, forwarded to ``Milvus.from_texts``.
        drop: Forwarded as ``drop_old``.
        db_path: Value used as the connection ``uri``.
        **kwargs: Extra keyword arguments forwarded to ``Milvus.from_texts``.

    Returns:
        The store returned by ``Milvus.from_texts``, created with
        ``consistency_level="Strong"``.
    """
    # The uri may also be a server address such as "http://127.0.0.1:19530".
    connection_args = {"uri": db_path}
    return Milvus.from_texts(
        fake_texts,
        FakeEmbeddings(),
        metadatas=metadatas,
        ids=ids,
        connection_args=connection_args,
        drop_old=drop,
        consistency_level="Strong",
        **kwargs,
    )
|
||||
|
||||
|
||||
def _get_pks(expr: str, docsearch: Milvus) -> List[Any]:
    """Return whatever ``docsearch.get_pks`` yields for the expression ``expr``."""
    pks = docsearch.get_pks(expr)
    return pks  # type: ignore[return-value]
|
||||
|
||||
|
||||
def test_milvus(temp_milvus_db: Any) -> None:
    """Build a store from the fake texts and check the top hit for "foo"."""
    store = _milvus_from_texts(db_path=temp_milvus_db)
    top_hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo")]
    assert_docs_equal_without_pk(top_hits, expected)
|
||||
|
||||
|
||||
def test_milvus_vector_search(temp_milvus_db: Any) -> None:
    """Build a store and search it with an explicit query vector."""
    store = _milvus_from_texts(db_path=temp_milvus_db)
    query_vector = FakeEmbeddings().embed_query("foo")
    top_hits = store.similarity_search_by_vector(query_vector, k=1)
    expected = [Document(page_content="foo")]
    assert_docs_equal_without_pk(top_hits, expected)
|
||||
|
||||
|
||||
def test_milvus_with_metadata(temp_milvus_db: Any) -> None:
    """Check that metadata supplied at construction comes back on the top hit."""
    shared_metadata = [{"label": "test"}] * len(fake_texts)
    store = _milvus_from_texts(metadatas=shared_metadata, db_path=temp_milvus_db)
    top_hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", metadata={"label": "test"})]
    assert_docs_equal_without_pk(top_hits, expected)
|
||||
|
||||
|
||||
def test_milvus_with_id(temp_milvus_db: Any) -> None:
    """Test construction, search and deletion with caller-supplied ids.

    Also checks that building a store with duplicate ids raises an
    ``AssertionError``.
    """
    ids = ["id_" + str(i) for i in range(len(fake_texts))]
    docsearch = _milvus_from_texts(ids=ids, db_path=temp_milvus_db)
    output = docsearch.similarity_search("foo", k=1)
    assert_docs_equal_without_pk(output, [Document(page_content="foo")])

    output = docsearch.delete(ids=ids)
    assert output.delete_count == len(fake_texts)  # type: ignore[attr-defined]

    # ``pytest.raises`` fails the test when nothing is raised; the previous
    # try/except only checked the exception type and passed silently when
    # duplicate ids were accepted.
    duplicate_ids = ["dup_id" for _ in fake_texts]
    with pytest.raises(AssertionError):
        _milvus_from_texts(ids=duplicate_ids, db_path=temp_milvus_db)
|
||||
|
||||
|
||||
def test_milvus_with_score(temp_milvus_db: Any) -> None:
    """Search with scores and check both the document order and score order."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    scored_hits = store.similarity_search_with_score("foo", k=3)
    docs = [hit[0] for hit in scored_hits]
    scores = [hit[1] for hit in scored_hits]

    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert_docs_equal_without_pk(docs, expected)
    # Scores must be strictly increasing in result order.
    assert scores[0] < scores[1] < scores[2]
|
||||
|
||||
|
||||
def test_milvus_max_marginal_relevance_search(temp_milvus_db: Any) -> None:
    """Test end to end construction and MMR (max marginal relevance) search."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    mmr_hits = store.max_marginal_relevance_search("foo", k=2, fetch_k=3)
    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
    ]
    assert_docs_equal_without_pk(mmr_hits, expected)
|
||||
|
||||
|
||||
def test_milvus_max_marginal_relevance_search_with_dynamic_field(
    temp_milvus_db: Any,
) -> None:
    """Test construction and MMR search with ``enable_dynamic_field=True``."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _milvus_from_texts(
        metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db
    )

    mmr_hits = store.max_marginal_relevance_search("foo", k=2, fetch_k=3)
    expected = [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
    ]
    assert_docs_equal_without_pk(mmr_hits, expected)
|
||||
|
||||
|
||||
def test_milvus_add_extra(temp_milvus_db: Any) -> None:
    """Add the texts a second time and check a search returns six entries."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    # Three entries come from construction, three more from this call.
    store.add_texts(texts, metadatas)

    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6
|
||||
|
||||
|
||||
def test_milvus_no_drop(temp_milvus_db: Any) -> None:
    """Rebuild the store with ``drop=False`` and check earlier data is kept."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": page} for page, _ in enumerate(texts)]

    first_store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)
    del first_store

    # Second construction on the same database, this time without dropping.
    second_store = _milvus_from_texts(
        metadatas=metadatas, drop=False, db_path=temp_milvus_db
    )

    hits = second_store.similarity_search("foo", k=10)
    assert len(hits) == 6
|
||||
|
||||
|
||||
def test_milvus_get_pks(temp_milvus_db: Any) -> None:
    """Look up primary keys with a filter expression and count the matches."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    matching_pks = _get_pks("id in [1,2]", store)
    assert len(matching_pks) == 2
|
||||
|
||||
|
||||
def test_milvus_delete_entities(temp_milvus_db: Any) -> None:
    """Delete the entities matched by an expression and check the count."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    pks = _get_pks("id in [1,2]", store)
    result = store.delete(pks)
    assert result.delete_count == 2  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_milvus_upsert_entities(temp_milvus_db: Any) -> None:
    """Upsert two documents over the matched primary keys and count the ids."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(metadatas=metadatas, db_path=temp_milvus_db)

    pks = _get_pks("id in [1,2]", store)
    replacements = [
        Document(page_content="test_1", metadata={"id": 1}),
        Document(page_content="test_2", metadata={"id": 3}),
    ]
    ids = store.upsert(pks, replacements)
    assert len(ids) == 2  # type: ignore[arg-type]
|
||||
|
||||
|
||||
def test_milvus_enable_dynamic_field(temp_milvus_db: Any) -> None:
    """Test construction and later inserts with ``enable_dynamic_field=True``."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(
        metadatas=metadatas, enable_dynamic_field=True, db_path=temp_milvus_db
    )
    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 3

    # With the dynamic field enabled, metadata using a key that was not
    # present at construction time can still be inserted.
    extra_metadatas = [{"id_new": index} for index, _ in enumerate(texts)]
    store.add_texts(texts, extra_metadatas)

    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6

    # Only the three built-in fields are reported as collection fields.
    expected_fields = {
        store._primary_field,
        store._text_field,
        store._vector_field,
    }
    assert set(store.fields) == expected_fields
|
||||
|
||||
|
||||
def test_milvus_disable_dynamic_field(temp_milvus_db: Any) -> None:
    """Test construction and later inserts with ``enable_dynamic_field=False``."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(
        metadatas=metadatas, enable_dynamic_field=False, db_path=temp_milvus_db
    )
    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 3

    # The collection fields are ["pk", "text", "vector", "id"].
    expected_fields = {
        store._primary_field,
        store._text_field,
        store._vector_field,
        "id",
    }
    assert set(store.fields) == expected_fields

    # "id_new" is not one of the collection's fields, so with the dynamic
    # field disabled it is not stored alongside "id".
    widened_metadatas = [{"id": i, "id_new": i} for i in range(len(texts))]
    store.add_texts(texts, widened_metadatas)
    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6
    for doc in hits:
        assert set(doc.metadata.keys()) == {"id", "pk"}  # `id_new` is not added.

    # Leaving out the existing "id" field is expected to raise.
    incomplete_metadatas = [{"id_new": i} for i in range(len(texts))]
    with pytest.raises(Exception):
        store.add_texts(texts, incomplete_metadatas)
|
||||
|
||||
|
||||
def test_milvus_metadata_field(temp_milvus_db: Any) -> None:
    """Test construction and later inserts with ``metadata_field`` set."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": index} for index, _ in enumerate(texts)]
    store = _milvus_from_texts(
        metadatas=metadatas, metadata_field="metadata", db_path=temp_milvus_db
    )
    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 3

    # Insert a second batch whose metadata uses a different key.
    extra_metadatas = [{"id_new": index} for index, _ in enumerate(texts)]
    store.add_texts(texts, extra_metadatas)

    hits = store.similarity_search("foo", k=10)
    assert len(hits) == 6

    # The metadata field appears next to the three built-in fields.
    expected_fields = {
        store._primary_field,
        store._text_field,
        store._vector_field,
        store._metadata_field,
    }
    assert set(store.fields) == expected_fields
|
||||
|
||||
|
||||
def test_milvus_enable_dynamic_field_with_partition_key(temp_milvus_db: Any) -> None:
    """Test the dynamic field combined with ``partition_key_field``.

    Searches once filtered to a single namespace and once unfiltered.
    """
    texts = ["foo", "bar", "baz"]
    metadatas = [
        {"id": index, "namespace": f"name_{index}"}
        for index, _ in enumerate(texts)
    ]

    store = _milvus_from_texts(
        metadatas=metadatas,
        enable_dynamic_field=True,
        partition_key_field="namespace",
        db_path=temp_milvus_db,
    )

    # Restricting to one namespace leaves exactly one document.
    filtered_hits = store.similarity_search(
        "foo", k=10, expr="namespace == 'name_2'"
    )
    assert len(filtered_hits) == 1

    # Without a namespace filter all three documents are returned.
    all_hits = store.similarity_search("foo", k=10)
    assert len(all_hits) == 3

    expected_fields = {
        store._primary_field,
        store._text_field,
        store._vector_field,
        store._partition_key_field,
    }
    assert set(store.fields) == expected_fields
|
||||
|
||||
|
||||
def test_milvus_sparse_embeddings() -> None:
    """Build a store with ``BM25SparseEmbedding`` and search it by keyword."""
    texts = [
        "In 'The Clockwork Kingdom' by Augusta Wynter, a brilliant inventor discovers "
        "a hidden world of clockwork machines and ancient magic, where a rebellion is "
        "brewing against the tyrannical ruler of the land.",
        "In 'The Phantom Pilgrim' by Rowan Welles, a charismatic smuggler is hired by "
        "a mysterious organization to transport a valuable artifact across a war-torn "
        "continent, but soon finds themselves pursued by assassins and rival factions.",
        "In 'The Dreamwalker's Journey' by Lyra Snow, a young dreamwalker discovers "
        "she has the ability to enter people's dreams, but soon finds herself trapped "
        "in a surreal world of nightmares and illusions, where the boundaries between "
        "reality and fantasy blur.",
    ]
    try:
        sparse_embedding_func = BM25SparseEmbedding(corpus=texts)
    except LookupError:
        # Download the "punkt_tab" nltk resource and retry once.
        import nltk  # type: ignore[import]

        nltk.download("punkt_tab")
        sparse_embedding_func = BM25SparseEmbedding(corpus=texts)

    with tempfile.NamedTemporaryFile(suffix=".db") as temp_db:
        connection_args = {"uri": temp_db.name}
        store = Milvus.from_texts(
            embedding=sparse_embedding_func,
            texts=texts,
            connection_args=connection_args,
            drop_old=True,
        )

        top_hits = store.similarity_search("Pilgrim", k=1)
        assert "Pilgrim" in top_hits[0].page_content
|
||||
|
||||
|
||||
def test_milvus_array_field(temp_milvus_db: Any) -> None:
    """Filter on an array-valued metadata field.

    The store is built twice: first with an explicit ``metadata_schema`` entry
    for the array field, then with ``enable_dynamic_field=True`` and no schema.
    For more information about the array data type and filtering, see
    https://milvus.io/docs/array_data_type.md
    """
    from pymilvus import DataType

    texts = ["foo", "bar", "baz"]
    metadatas = [{"id": i, "array_field": [i, i + 1, i + 2]} for i in range(len(texts))]
    first_element_filter = "array_field[0] < 2"
    contains_filter = "ARRAY_CONTAINS(array_field, 3)"

    # Manually specify the schema of the array field only. "id" is left out on
    # purpose: if some fields are not specified, Milvus will automatically
    # infer their schemas.
    array_schema = {
        "array_field": {
            "dtype": DataType.ARRAY,
            "kwargs": {"element_type": DataType.INT64, "max_capacity": 50},
        },
    }
    store = _milvus_from_texts(
        metadatas=metadatas,
        metadata_schema=array_schema,
        db_path=temp_milvus_db,
    )
    hits = store.similarity_search("foo", k=10, expr=first_element_filter)
    assert len(hits) == 2
    hits = store.similarity_search("foo", k=10, expr=contains_filter)
    assert len(hits) == 2

    # If we use enable_dynamic_field,
    # there is no need to manually specify metadata schema.
    store = _milvus_from_texts(
        enable_dynamic_field=True,
        metadatas=metadatas,
        db_path=temp_milvus_db,
    )
    hits = store.similarity_search("foo", k=10, expr=first_element_filter)
    assert len(hits) == 2
    hits = store.similarity_search("foo", k=10, expr=contains_filter)
    assert len(hits) == 2
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_milvus()
|
||||
# test_milvus_vector_search()
|
||||
# test_milvus_with_metadata()
|
||||
# test_milvus_with_id()
|
||||
# test_milvus_with_score()
|
||||
# test_milvus_max_marginal_relevance_search()
|
||||
# test_milvus_max_marginal_relevance_search_with_dynamic_field()
|
||||
# test_milvus_add_extra()
|
||||
# test_milvus_no_drop()
|
||||
# test_milvus_get_pks()
|
||||
# test_milvus_delete_entities()
|
||||
# test_milvus_upsert_entities()
|
||||
# test_milvus_enable_dynamic_field()
|
||||
# test_milvus_disable_dynamic_field()
|
||||
# test_milvus_metadata_field()
|
||||
# test_milvus_enable_dynamic_field_with_partition_key()
|
||||
# test_milvus_sparse_embeddings()
|
||||
# test_milvus_array_field()
|
@ -1,12 +0,0 @@
|
||||
from langchain_milvus import __all__
|
||||
|
||||
# Names that ``langchain_milvus.__all__`` is expected to contain.
EXPECTED_ALL = [
    "Milvus",
    "MilvusCollectionHybridSearchRetriever",
    "Zilliz",
    "ZillizCloudPipelineRetriever",
]


def test_all_imports() -> None:
    """Check that ``__all__`` matches ``EXPECTED_ALL``, ignoring order."""
    expected = sorted(EXPECTED_ALL)
    exported = sorted(__all__)
    assert expected == exported
|
@ -1,17 +0,0 @@
|
||||
import os
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest.mock import Mock
|
||||
|
||||
from langchain_milvus.vectorstores import Milvus
|
||||
|
||||
|
||||
def test_initialization() -> None:
    """Construct a ``Milvus`` store whose uri is a file in a temporary directory."""
    embedding = Mock()
    with TemporaryDirectory() as tmp_dir:
        db_uri = os.path.join(tmp_dir, "milvus.db")
        Milvus(embedding_function=embedding, connection_args={"uri": db_uri})
|
Loading…
Reference in New Issue
Block a user