mirror of https://github.com/hwchase17/langchain
FEAT: Merge TileDB vecstore (#12811)
commit
658a3a8607
@ -0,0 +1,178 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25bce5eb-8599-40fe-947e-4932cfae8184",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# TileDB\n",
|
||||
"\n",
|
||||
"> [TileDB](https://github.com/TileDB-Inc/TileDB) is a powerful engine for indexing and querying dense and sparse multi-dimensional arrays.\n",
|
||||
"\n",
|
||||
"> TileDB offers ANN search capabilities using the [TileDB-Vector-Search](https://github.com/TileDB-Inc/TileDB-Vector-Search) module. It provides serverless execution of ANN queries and storage of vector indexes both on local disk and cloud object stores (i.e. AWS S3).\n",
|
||||
"\n",
|
||||
"More details in:\n",
|
||||
"- [Why TileDB as a Vector Database](https://tiledb.com/blog/why-tiledb-as-a-vector-database)\n",
|
||||
"- [TileDB 101: Vector Search](https://tiledb.com/blog/tiledb-101-vector-search)\n",
|
||||
"\n",
|
||||
"This notebook shows how to use the `TileDB` vector database."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f45f46f2-7229-4859-9797-30bbead1b8e0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install tiledb-vector-search"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2f65caa9-8383-409a-bccb-6e91fc8d5e8f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Example"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c96d4fe0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.document_loaders import TextLoader\n",
|
||||
"from langchain.embeddings import HuggingFaceEmbeddings\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.vectorstores import TileDB\n",
|
||||
"\n",
|
||||
"raw_documents = TextLoader(\"../../modules/state_of_the_union.txt\").load()\n",
|
||||
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
|
||||
"documents = text_splitter.split_documents(raw_documents)\n",
|
||||
"embeddings = HuggingFaceEmbeddings()\n",
|
||||
"db = TileDB.from_documents(\n",
|
||||
" documents, embeddings, index_uri=\"/tmp/tiledb_index\", index_type=\"FLAT\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b0a6797c-2bb0-45db-a636-5d2437f7a4c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
|
||||
"docs = db.similarity_search(query)\n",
|
||||
"docs[0].page_content"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "c4c4e06d-6def-44ce-ac9a-4c01673c29a2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Similarity search by vector"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1eb72610-d451-4158-880c-9f0d45fa5909",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embedding_vector = embeddings.embed_query(query)\n",
|
||||
"docs = db.similarity_search_by_vector(embedding_vector)\n",
|
||||
"docs[0].page_content"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d33588d4-67c2-4bd3-b251-76ae783cbafb",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Similarity search with score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1a41e382-0336-4e6d-b2ef-44cc77db2696",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"docs_and_scores = db.similarity_search_with_score(query)\n",
|
||||
"docs_and_scores[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "57f930f2-41a0-4795-ad9e-44a33c8f88ec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Maximal Marginal Relevance Search (MMR)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4790e437-3207-45cb-b121-d857ab5aabd8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In addition to using similarity search in the retriever object, you can also use `mmr` as retriever."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "495754b1-5cdb-4af6-9733-f68700bb7232",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"retriever = db.as_retriever(search_type=\"mmr\")\n",
|
||||
"retriever.get_relevant_documents(query)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e213d957-e439-4bd6-90f2-8909323f5f09",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or use `max_marginal_relevance_search` directly:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "99d928d0-3b79-4588-925e-32230e12af47",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"db.max_marginal_relevance_search(query, k=2, fetch_k=10)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,789 @@
|
||||
"""Wrapper around TileDB vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import random
|
||||
import sys
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.schema.embeddings import Embeddings
|
||||
from langchain.schema.vectorstore import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
INDEX_METRICS = frozenset(["euclidean"])
|
||||
DEFAULT_METRIC = "euclidean"
|
||||
DOCUMENTS_ARRAY_NAME = "documents"
|
||||
VECTOR_INDEX_NAME = "vectors"
|
||||
MAX_UINT64 = np.iinfo(np.dtype("uint64")).max
|
||||
MAX_FLOAT_32 = np.finfo(np.dtype("float32")).max
|
||||
MAX_FLOAT = sys.float_info.max
|
||||
|
||||
|
||||
def dependable_tiledb_import() -> Any:
|
||||
"""Import tiledb-vector-search if available, otherwise raise error."""
|
||||
try:
|
||||
import tiledb as tiledb
|
||||
import tiledb.vector_search as tiledb_vs
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import tiledb-vector-search python package. "
|
||||
"Please install it with `conda install -c tiledb tiledb-vector-search` "
|
||||
"or `pip install tiledb-vector-search`"
|
||||
)
|
||||
return tiledb_vs, tiledb
|
||||
|
||||
|
||||
def get_vector_index_uri_from_group(group: Any) -> str:
|
||||
return group[VECTOR_INDEX_NAME].uri
|
||||
|
||||
|
||||
def get_documents_array_uri_from_group(group: Any) -> str:
|
||||
return group[DOCUMENTS_ARRAY_NAME].uri
|
||||
|
||||
|
||||
def get_vector_index_uri(uri: str) -> str:
|
||||
return f"{uri}/{VECTOR_INDEX_NAME}"
|
||||
|
||||
|
||||
def get_documents_array_uri(uri: str) -> str:
|
||||
return f"{uri}/{DOCUMENTS_ARRAY_NAME}"
|
||||
|
||||
|
||||
class TileDB(VectorStore):
|
||||
"""Wrapper around TileDB vector database.
|
||||
|
||||
To use, you should have the ``tiledb-vector-search`` python package installed.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import TileDB
|
||||
embeddings = OpenAIEmbeddings()
|
||||
db = TileDB(embeddings, index_uri, metric)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding: Embeddings,
|
||||
index_uri: str,
|
||||
metric: str,
|
||||
*,
|
||||
vector_index_uri: str = "",
|
||||
docs_array_uri: str = "",
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
timestamp: Any = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding = embedding
|
||||
self.embedding_function = embedding.embed_query
|
||||
self.index_uri = index_uri
|
||||
self.metric = metric
|
||||
self.config = config
|
||||
|
||||
tiledb_vs, tiledb = dependable_tiledb_import()
|
||||
with tiledb.scope_ctx(ctx_or_config=config):
|
||||
index_group = tiledb.Group(self.index_uri, "r")
|
||||
self.vector_index_uri = (
|
||||
vector_index_uri
|
||||
if vector_index_uri != ""
|
||||
else get_vector_index_uri_from_group(index_group)
|
||||
)
|
||||
self.docs_array_uri = (
|
||||
docs_array_uri
|
||||
if docs_array_uri != ""
|
||||
else get_documents_array_uri_from_group(index_group)
|
||||
)
|
||||
index_group.close()
|
||||
group = tiledb.Group(self.vector_index_uri, "r")
|
||||
self.index_type = group.meta.get("index_type")
|
||||
group.close()
|
||||
self.timestamp = timestamp
|
||||
if self.index_type == "FLAT":
|
||||
self.vector_index = tiledb_vs.flat_index.FlatIndex(
|
||||
uri=self.vector_index_uri,
|
||||
config=self.config,
|
||||
timestamp=self.timestamp,
|
||||
**kwargs,
|
||||
)
|
||||
elif self.index_type == "IVF_FLAT":
|
||||
self.vector_index = tiledb_vs.ivf_flat_index.IVFFlatIndex(
|
||||
uri=self.vector_index_uri,
|
||||
config=self.config,
|
||||
timestamp=self.timestamp,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def embeddings(self) -> Optional[Embeddings]:
|
||||
return self.embedding
|
||||
|
||||
def process_index_results(
|
||||
self,
|
||||
ids: List[int],
|
||||
scores: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
score_threshold: float = MAX_FLOAT,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Turns TileDB results into a list of documents and scores.
|
||||
|
||||
Args:
|
||||
ids: List of indices of the documents in the index.
|
||||
scores: List of distances of the documents in the index.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
score_threshold: Optional, a floating point value to filter the
|
||||
resulting set of retrieved docs
|
||||
Returns:
|
||||
List of Documents and scores.
|
||||
"""
|
||||
tiledb_vs, tiledb = dependable_tiledb_import()
|
||||
docs = []
|
||||
docs_array = tiledb.open(
|
||||
self.docs_array_uri, "r", timestamp=self.timestamp, config=self.config
|
||||
)
|
||||
for idx, score in zip(ids, scores):
|
||||
if idx == 0 and score == 0:
|
||||
continue
|
||||
if idx == MAX_UINT64 and score == MAX_FLOAT_32:
|
||||
continue
|
||||
doc = docs_array[idx]
|
||||
if doc is None or len(doc["text"]) == 0:
|
||||
raise ValueError(f"Could not find document for id {idx}, got {doc}")
|
||||
pickled_metadata = doc.get("metadata")
|
||||
result_doc = Document(page_content=str(doc["text"][0]))
|
||||
if pickled_metadata is not None:
|
||||
metadata = pickle.loads(
|
||||
np.array(pickled_metadata.tolist()).astype(np.uint8).tobytes()
|
||||
)
|
||||
result_doc.metadata = metadata
|
||||
if filter is not None:
|
||||
filter = {
|
||||
key: [value] if not isinstance(value, list) else value
|
||||
for key, value in filter.items()
|
||||
}
|
||||
if all(
|
||||
result_doc.metadata.get(key) in value
|
||||
for key, value in filter.items()
|
||||
):
|
||||
docs.append((result_doc, score))
|
||||
else:
|
||||
docs.append((result_doc, score))
|
||||
docs_array.close()
|
||||
docs = [(doc, score) for doc, score in docs if score <= score_threshold]
|
||||
return docs[:k]
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, Any]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
**kwargs: kwargs to be passed to similarity search. Can include:
|
||||
nprobe: Optional, number of partitions to check if using IVF_FLAT index
|
||||
score_threshold: Optional, a floating point value to filter the
|
||||
resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text and distance
|
||||
in float for each. Lower score represents more similarity.
|
||||
"""
|
||||
if "score_threshold" in kwargs:
|
||||
score_threshold = kwargs.pop("score_threshold")
|
||||
else:
|
||||
score_threshold = MAX_FLOAT
|
||||
d, i = self.vector_index.query(
|
||||
np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
|
||||
k=k if filter is None else fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return self.process_index_results(
|
||||
ids=i[0], scores=d[0], filter=filter, k=k, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of documents most similar to the query text with
|
||||
Distance as float. Lower score represents more similarity.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to embedding vector.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the embedding.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
fetch_k: int = 20,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
fetch_k: (Optional[int]) Number of Documents to fetch before filtering.
|
||||
Defaults to 20.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query.
|
||||
"""
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query, k=k, filter=filter, fetch_k=fetch_k, **kwargs
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
*,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores selected using the maximal marginal
|
||||
relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents and similarity scores selected by maximal marginal
|
||||
relevance and score for each.
|
||||
"""
|
||||
if "score_threshold" in kwargs:
|
||||
score_threshold = kwargs.pop("score_threshold")
|
||||
else:
|
||||
score_threshold = MAX_FLOAT
|
||||
scores, indices = self.vector_index.query(
|
||||
np.array([np.array(embedding).astype(np.float32)]).astype(np.float32),
|
||||
k=fetch_k if filter is None else fetch_k * 2,
|
||||
**kwargs,
|
||||
)
|
||||
results = self.process_index_results(
|
||||
ids=indices[0],
|
||||
scores=scores[0],
|
||||
filter=filter,
|
||||
k=fetch_k if filter is None else fetch_k * 2,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
embeddings = [
|
||||
self.embedding.embed_documents([doc.page_content])[0] for doc, _ in results
|
||||
]
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array([embedding], dtype=np.float32),
|
||||
embeddings,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
docs_and_scores = []
|
||||
for i in mmr_selected:
|
||||
docs_and_scores.append(results[i])
|
||||
return docs_and_scores
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch before filtering (if needed) to
|
||||
pass to MMR algorithm.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
embedding = self.embedding_function(query)
|
||||
docs = self.max_marginal_relevance_search_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
index_uri: str,
|
||||
index_type: str,
|
||||
dimensions: int,
|
||||
vector_type: np.dtype,
|
||||
*,
|
||||
metadatas: bool = True,
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
) -> None:
|
||||
tiledb_vs, tiledb = dependable_tiledb_import()
|
||||
with tiledb.scope_ctx(ctx_or_config=config):
|
||||
try:
|
||||
tiledb.group_create(index_uri)
|
||||
except tiledb.TileDBError as err:
|
||||
raise err
|
||||
group = tiledb.Group(index_uri, "w")
|
||||
vector_index_uri = get_vector_index_uri(group.uri)
|
||||
docs_uri = get_documents_array_uri(group.uri)
|
||||
if index_type == "FLAT":
|
||||
tiledb_vs.flat_index.create(
|
||||
uri=vector_index_uri,
|
||||
dimensions=dimensions,
|
||||
vector_type=vector_type,
|
||||
config=config,
|
||||
)
|
||||
elif index_type == "IVF_FLAT":
|
||||
tiledb_vs.ivf_flat_index.create(
|
||||
uri=vector_index_uri,
|
||||
dimensions=dimensions,
|
||||
vector_type=vector_type,
|
||||
config=config,
|
||||
)
|
||||
group.add(vector_index_uri, name=VECTOR_INDEX_NAME)
|
||||
|
||||
# Create TileDB array to store Documents
|
||||
# TODO add a Document store API to tiledb-vector-search to allow storing
|
||||
# different types of objects and metadata in a more generic way.
|
||||
dim = tiledb.Dim(
|
||||
name="id",
|
||||
domain=(0, MAX_UINT64 - 1),
|
||||
dtype=np.dtype(np.uint64),
|
||||
)
|
||||
dom = tiledb.Domain(dim)
|
||||
|
||||
text_attr = tiledb.Attr(name="text", dtype=np.dtype("U1"), var=True)
|
||||
attrs = [text_attr]
|
||||
if metadatas:
|
||||
metadata_attr = tiledb.Attr(name="metadata", dtype=np.uint8, var=True)
|
||||
attrs.append(metadata_attr)
|
||||
schema = tiledb.ArraySchema(
|
||||
domain=dom,
|
||||
sparse=True,
|
||||
allows_duplicates=False,
|
||||
attrs=attrs,
|
||||
)
|
||||
tiledb.Array.create(docs_uri, schema)
|
||||
group.add(docs_uri, name=DOCUMENTS_ARRAY_NAME)
|
||||
group.close()
|
||||
|
||||
@classmethod
|
||||
def __from(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embeddings: List[List[float]],
|
||||
embedding: Embeddings,
|
||||
index_uri: str,
|
||||
*,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
metric: str = DEFAULT_METRIC,
|
||||
index_type: str = "FLAT",
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
index_timestamp: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> TileDB:
|
||||
if metric not in INDEX_METRICS:
|
||||
raise ValueError(
|
||||
(
|
||||
f"Unsupported distance metric: {metric}. "
|
||||
f"Expected one of {list(INDEX_METRICS)}"
|
||||
)
|
||||
)
|
||||
tiledb_vs, tiledb = dependable_tiledb_import()
|
||||
input_vectors = np.array(embeddings).astype(np.float32)
|
||||
cls.create(
|
||||
index_uri=index_uri,
|
||||
index_type=index_type,
|
||||
dimensions=input_vectors.shape[1],
|
||||
vector_type=input_vectors.dtype,
|
||||
metadatas=metadatas is not None,
|
||||
config=config,
|
||||
)
|
||||
with tiledb.scope_ctx(ctx_or_config=config):
|
||||
if not embeddings:
|
||||
raise ValueError("embeddings must be provided to build a TileDB index")
|
||||
|
||||
vector_index_uri = get_vector_index_uri(index_uri)
|
||||
docs_uri = get_documents_array_uri(index_uri)
|
||||
if ids is None:
|
||||
ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
|
||||
external_ids = np.array(ids).astype(np.uint64)
|
||||
|
||||
tiledb_vs.ingestion.ingest(
|
||||
index_type=index_type,
|
||||
index_uri=vector_index_uri,
|
||||
input_vectors=input_vectors,
|
||||
external_ids=external_ids,
|
||||
index_timestamp=index_timestamp if index_timestamp != 0 else None,
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
with tiledb.open(docs_uri, "w") as A:
|
||||
if external_ids is None:
|
||||
external_ids = np.zeros(len(texts), dtype=np.uint64)
|
||||
for i in range(len(texts)):
|
||||
external_ids[i] = i
|
||||
data = {}
|
||||
data["text"] = np.array(texts)
|
||||
if metadatas is not None:
|
||||
metadata_attr = np.empty([len(metadatas)], dtype=object)
|
||||
i = 0
|
||||
for metadata in metadatas:
|
||||
metadata_attr[i] = np.frombuffer(
|
||||
pickle.dumps(metadata), dtype=np.uint8
|
||||
)
|
||||
i += 1
|
||||
data["metadata"] = metadata_attr
|
||||
|
||||
A[external_ids] = data
|
||||
return cls(
|
||||
embedding=embedding,
|
||||
index_uri=index_uri,
|
||||
metric=metric,
|
||||
config=config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(
|
||||
self, ids: Optional[List[str]] = None, timestamp: int = 0, **kwargs: Any
|
||||
) -> Optional[bool]:
|
||||
"""Delete by vector ID or other criteria.
|
||||
|
||||
Args:
|
||||
ids: List of ids to delete.
|
||||
timestamp: Optional timestamp to delete with.
|
||||
**kwargs: Other keyword arguments that subclasses might use.
|
||||
|
||||
Returns:
|
||||
Optional[bool]: True if deletion is successful,
|
||||
False otherwise, None if not implemented.
|
||||
"""
|
||||
|
||||
external_ids = np.array(ids).astype(np.uint64)
|
||||
self.vector_index.delete_batch(
|
||||
external_ids=external_ids, timestamp=timestamp if timestamp != 0 else None
|
||||
)
|
||||
return True
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
timestamp: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
"""Run more texts through the embeddings and add to the vectorstore.
|
||||
|
||||
Args:
|
||||
texts: Iterable of strings to add to the vectorstore.
|
||||
metadatas: Optional list of metadatas associated with the texts.
|
||||
ids: Optional ids of each text object.
|
||||
timestamp: Optional timestamp to write new texts with.
|
||||
kwargs: vectorstore specific parameters
|
||||
|
||||
Returns:
|
||||
List of ids from adding the texts into the vectorstore.
|
||||
"""
|
||||
tiledb_vs, tiledb = dependable_tiledb_import()
|
||||
embeddings = self.embedding.embed_documents(list(texts))
|
||||
if ids is None:
|
||||
ids = [str(random.randint(0, MAX_UINT64 - 1)) for _ in texts]
|
||||
|
||||
external_ids = np.array(ids).astype(np.uint64)
|
||||
vectors = np.empty((len(embeddings)), dtype="O")
|
||||
for i in range(len(embeddings)):
|
||||
vectors[i] = np.array(embeddings[i], dtype=np.float32)
|
||||
self.vector_index.update_batch(
|
||||
vectors=vectors,
|
||||
external_ids=external_ids,
|
||||
timestamp=timestamp if timestamp != 0 else None,
|
||||
)
|
||||
|
||||
docs = {}
|
||||
docs["text"] = np.array(texts)
|
||||
if metadatas is not None:
|
||||
metadata_attr = np.empty([len(metadatas)], dtype=object)
|
||||
i = 0
|
||||
for metadata in metadatas:
|
||||
metadata_attr[i] = np.frombuffer(pickle.dumps(metadata), dtype=np.uint8)
|
||||
i += 1
|
||||
docs["metadata"] = metadata_attr
|
||||
|
||||
docs_array = tiledb.open(
|
||||
self.docs_array_uri,
|
||||
"w",
|
||||
timestamp=timestamp if timestamp != 0 else None,
|
||||
config=self.config,
|
||||
)
|
||||
docs_array[external_ids] = docs
|
||||
docs_array.close()
|
||||
return ids
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls,
|
||||
texts: List[str],
|
||||
embedding: Embeddings,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
metric: str = DEFAULT_METRIC,
|
||||
index_uri: str = "/tmp/tiledb_array",
|
||||
index_type: str = "FLAT",
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
index_timestamp: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> TileDB:
|
||||
"""Construct a TileDB index from raw documents.
|
||||
|
||||
Args:
|
||||
texts: List of documents to index.
|
||||
embedding: Embedding function to use.
|
||||
metadatas: List of metadata dictionaries to associate with documents.
|
||||
ids: Optional ids of each text object.
|
||||
metric: Metric to use for indexing. Defaults to "euclidean".
|
||||
index_uri: The URI to write the TileDB arrays
|
||||
index_type: Optional, Vector index type ("FLAT", IVF_FLAT")
|
||||
config: Optional, TileDB config
|
||||
index_timestamp: Optional, timestamp to write new texts with.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import TileDB
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
index = TileDB.from_texts(texts, embeddings)
|
||||
"""
|
||||
embeddings = []
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
return cls.__from(
|
||||
texts=texts,
|
||||
embeddings=embeddings,
|
||||
embedding=embedding,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
metric=metric,
|
||||
index_uri=index_uri,
|
||||
index_type=index_type,
|
||||
config=config,
|
||||
index_timestamp=index_timestamp,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_embeddings(
|
||||
cls,
|
||||
text_embeddings: List[Tuple[str, List[float]]],
|
||||
embedding: Embeddings,
|
||||
index_uri: str,
|
||||
*,
|
||||
metadatas: Optional[List[dict]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
metric: str = DEFAULT_METRIC,
|
||||
index_type: str = "FLAT",
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
index_timestamp: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> TileDB:
|
||||
"""Construct TileDB index from embeddings.
|
||||
|
||||
Args:
|
||||
text_embeddings: List of tuples of (text, embedding)
|
||||
embedding: Embedding function to use.
|
||||
index_uri: The URI to write the TileDB arrays
|
||||
metadatas: List of metadata dictionaries to associate with documents.
|
||||
metric: Optional, Metric to use for indexing. Defaults to "euclidean".
|
||||
index_type: Optional, Vector index type ("FLAT", IVF_FLAT")
|
||||
config: Optional, TileDB config
|
||||
index_timestamp: Optional, timestamp to write new texts with.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import TileDB
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
embeddings = OpenAIEmbeddings()
|
||||
text_embeddings = embeddings.embed_documents(texts)
|
||||
text_embedding_pairs = list(zip(texts, text_embeddings))
|
||||
db = TileDB.from_embeddings(text_embedding_pairs, embeddings)
|
||||
"""
|
||||
texts = [t[0] for t in text_embeddings]
|
||||
embeddings = [t[1] for t in text_embeddings]
|
||||
|
||||
return cls.__from(
|
||||
texts=texts,
|
||||
embeddings=embeddings,
|
||||
embedding=embedding,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
metric=metric,
|
||||
index_uri=index_uri,
|
||||
index_type=index_type,
|
||||
config=config,
|
||||
index_timestamp=index_timestamp,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
index_uri: str,
|
||||
embedding: Embeddings,
|
||||
*,
|
||||
metric: str = DEFAULT_METRIC,
|
||||
config: Optional[Mapping[str, Any]] = None,
|
||||
timestamp: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> TileDB:
|
||||
"""Load a TileDB index from a URI.
|
||||
|
||||
Args:
|
||||
index_uri: The URI of the TileDB vector index.
|
||||
embedding: Embeddings to use when generating queries.
|
||||
metric: Optional, Metric to use for indexing. Defaults to "euclidean".
|
||||
config: Optional, TileDB config
|
||||
timestamp: Optional, timestamp to use for opening the arrays.
|
||||
"""
|
||||
return cls(
|
||||
embedding=embedding,
|
||||
index_uri=index_uri,
|
||||
metric=metric,
|
||||
config=config,
|
||||
timestamp=timestamp,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def consolidate_updates(self, **kwargs: Any) -> None:
|
||||
self.vector_index = self.vector_index.consolidate_updates(**kwargs)
|
@ -0,0 +1,358 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.tiledb import TileDB
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import (
|
||||
ConsistentFakeEmbeddings,
|
||||
FakeEmbeddings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb(tmp_path: Path) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
output = docsearch.similarity_search(
|
||||
"foo", k=1, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_vector_sim(tmp_path: Path) -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(query_vec, k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(
|
||||
query_vec, k=1, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_vector_sim_with_score_threshold(tmp_path: Path) -> None:
|
||||
"""Test vector similarity."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(query_vec, k=2, score_threshold=0.2)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_by_vector(
|
||||
query_vec, k=2, score_threshold=0.2, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_similarity_search_with_score_by_vector(tmp_path: Path) -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(query_vec, k=1)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(
|
||||
query_vec, k=1, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_similarity_search_with_score_by_vector_with_score_threshold(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test vector similarity with score by vector."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(
|
||||
query_vec,
|
||||
k=2,
|
||||
score_threshold=0.2,
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
assert output[0][1] < 0.2
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = FakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.similarity_search_with_score_by_vector(
|
||||
query_vec, k=2, score_threshold=0.2, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
assert output[0][1] < 0.2
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_mmr(tmp_path: Path) -> None:
|
||||
texts = ["foo", "foo", "fou", "foy"]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=3, lambda_mult=0.1
|
||||
)
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo")
|
||||
assert output[2][0] != Document(page_content="foo")
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=3, lambda_mult=0.1, nprobe=docsearch.vector_index.partitions
|
||||
)
|
||||
assert output[0][0] == Document(page_content="foo")
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo")
|
||||
assert output[2][0] != Document(page_content="foo")
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_mmr_with_metadatas_and_filter(tmp_path: Path) -> None:
|
||||
texts = ["foo", "foo", "fou", "foy"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=3, lambda_mult=0.1, filter={"page": 1}
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
|
||||
assert output[0][1] == 0.0
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec,
|
||||
k=3,
|
||||
lambda_mult=0.1,
|
||||
filter={"page": 1},
|
||||
nprobe=docsearch.vector_index.partitions,
|
||||
)
|
||||
assert len(output) == 1
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 1})
|
||||
assert output[0][1] == 0.0
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_mmr_with_metadatas_and_list_filter(tmp_path: Path) -> None:
|
||||
texts = ["foo", "fou", "foy", "foo"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/flat",
|
||||
index_type="FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec, k=3, lambda_mult=0.1, filter={"page": [0, 1, 2]}
|
||||
)
|
||||
assert len(output) == 3
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[2][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
|
||||
docsearch = TileDB.from_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
embedding=ConsistentFakeEmbeddings(),
|
||||
index_uri=f"{str(tmp_path)}/ivf_flat",
|
||||
index_type="IVF_FLAT",
|
||||
)
|
||||
query_vec = ConsistentFakeEmbeddings().embed_query(text="foo")
|
||||
output = docsearch.max_marginal_relevance_search_with_score_by_vector(
|
||||
query_vec,
|
||||
k=3,
|
||||
lambda_mult=0.1,
|
||||
filter={"page": [0, 1, 2]},
|
||||
nprobe=docsearch.vector_index.partitions,
|
||||
)
|
||||
assert len(output) == 3
|
||||
assert output[0][0] == Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[0][1] == 0.0
|
||||
assert output[1][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
assert output[2][0] != Document(page_content="foo", metadata={"page": 0})
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_flat_updates(tmp_path: Path) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
dimensions = 10
|
||||
index_uri = str(tmp_path)
|
||||
embedding = ConsistentFakeEmbeddings(dimensionality=dimensions)
|
||||
TileDB.create(
|
||||
index_uri=index_uri,
|
||||
index_type="FLAT",
|
||||
dimensions=dimensions,
|
||||
vector_type=np.dtype("float32"),
|
||||
metadatas=False,
|
||||
)
|
||||
docsearch = TileDB.load(
|
||||
index_uri=index_uri,
|
||||
embedding=embedding,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
assert output == []
|
||||
|
||||
docsearch.add_texts(texts=["foo", "bar", "baz"], ids=["1", "2", "3"])
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch.delete(["1", "3"])
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="bar")]
|
||||
output = docsearch.similarity_search("baz", k=1)
|
||||
assert output == [Document(page_content="bar")]
|
||||
|
||||
docsearch.add_texts(texts=["fooo", "bazz"], ids=["4", "5"])
|
||||
output = docsearch.similarity_search("fooo", k=1)
|
||||
assert output == [Document(page_content="fooo")]
|
||||
output = docsearch.similarity_search("bazz", k=1)
|
||||
assert output == [Document(page_content="bazz")]
|
||||
|
||||
docsearch.consolidate_updates()
|
||||
output = docsearch.similarity_search("fooo", k=1)
|
||||
assert output == [Document(page_content="fooo")]
|
||||
output = docsearch.similarity_search("bazz", k=1)
|
||||
assert output == [Document(page_content="bazz")]
|
||||
|
||||
|
||||
@pytest.mark.requires("tiledb-vector-search")
|
||||
def test_tiledb_ivf_flat_updates(tmp_path: Path) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
dimensions = 10
|
||||
index_uri = str(tmp_path)
|
||||
embedding = ConsistentFakeEmbeddings(dimensionality=dimensions)
|
||||
TileDB.create(
|
||||
index_uri=index_uri,
|
||||
index_type="IVF_FLAT",
|
||||
dimensions=dimensions,
|
||||
vector_type=np.dtype("float32"),
|
||||
metadatas=False,
|
||||
)
|
||||
docsearch = TileDB.load(
|
||||
index_uri=index_uri,
|
||||
embedding=embedding,
|
||||
)
|
||||
output = docsearch.similarity_search("foo", k=2)
|
||||
assert output == []
|
||||
|
||||
docsearch.add_texts(texts=["foo", "bar", "baz"], ids=["1", "2", "3"])
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
docsearch.delete(["1", "3"])
|
||||
output = docsearch.similarity_search("foo", k=1)
|
||||
assert output == [Document(page_content="bar")]
|
||||
output = docsearch.similarity_search("baz", k=1)
|
||||
assert output == [Document(page_content="bar")]
|
||||
|
||||
docsearch.add_texts(texts=["fooo", "bazz"], ids=["4", "5"])
|
||||
output = docsearch.similarity_search("fooo", k=1)
|
||||
assert output == [Document(page_content="fooo")]
|
||||
output = docsearch.similarity_search("bazz", k=1)
|
||||
assert output == [Document(page_content="bazz")]
|
||||
|
||||
docsearch.consolidate_updates()
|
||||
output = docsearch.similarity_search("fooo", k=1)
|
||||
assert output == [Document(page_content="fooo")]
|
||||
output = docsearch.similarity_search("bazz", k=1)
|
||||
assert output == [Document(page_content="bazz")]
|
Loading…
Reference in New Issue