LanceDB integration update (#22869)

Added:

- [x] relevance search (with and without scores)
- [x] maximal marginal relevance (MMR) search
- [x] image ingestion
- [x] filtering support
- [x] hybrid search with reranking

Checked with `make test`, `make lint_diff`, and `make format`.
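
For the filtering support listed above, a dict filter is converted into a SQL-style string by joining equality terms with `AND`. A minimal standalone sketch of that conversion, mirroring the `to_lance_filter` helper added in this diff (the metadata keys and values below are hypothetical, chosen only for illustration):

```python
from typing import Dict


def to_lance_filter(filter: Dict[str, str]) -> str:
    """Join key/value pairs into an AND-ed SQL equality filter string."""
    return " AND ".join(f"{k} = '{v}'" for k, v in filter.items())


# Hypothetical metadata filter, used only for illustration.
sql = to_lance_filter(
    {"metadata.source": "state_of_the_union.txt", "metadata.year": "2022"}
)
print(sql)
```

Note that values are interpolated without escaping, so a value containing a single quote would produce a malformed filter string; callers with such values may prefer to pass a hand-written SQL string filter instead.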
pull/23068/head
Raghav Dixit 2 weeks ago committed by GitHub
parent 62c8a67f56
commit 55705c0f5e

@ -12,6 +12,16 @@
"This notebook shows how to use functionality related to the `LanceDB` vector database based on the Lance data format."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1051ba9",
"metadata": {},
"outputs": [],
"source": [
"! pip install tantivy"
]
},
{
"cell_type": "code",
"execution_count": null,
@ -29,7 +39,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install lancedb"
"! pip install lancedb==0.6.13"
]
},
{
@ -42,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "a0361f5c-e6f4-45f4-b829-11680cf03cec",
"metadata": {
"tags": []
@ -57,7 +67,17 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"id": "d114ed78",
"metadata": {},
"outputs": [],
"source": [
"! rm -rf /tmp/lancedb"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a3c3999a",
"metadata": {},
"outputs": [],
@ -94,49 +114,38 @@
" embedding=embeddings,\n",
" table_name='langchain_test'\n",
" )\n",
"```\n"
"```\n",
"\n",
"You can also pass `region`, `api_key`, and `uri` to the `from_documents()` classmethod.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 4,
"id": "6e104aee",
"metadata": {},
"outputs": [],
"source": [
"docsearch = LanceDB.from_documents(documents, embeddings)\n",
"query = \"What did the president say about Ketanji Brown Jackson\"\n",
"docs = docsearch.similarity_search(query)"
]
},
{
"cell_type": "markdown",
"id": "f5e1cdfd",
"metadata": {},
"source": [
"Additionally, to explore the table you can load it into a DataFrame or save it as a CSV file: \n",
"```python\n",
"tbl = docsearch.get_table()\n",
"print(\"tbl:\", tbl)\n",
"pd_df = tbl.to_pandas()\n",
"# pd_df.to_csv(\"docsearch.csv\", index=False)\n",
"from lancedb.rerankers import LinearCombinationReranker\n",
"\n",
"# you can also create a new vector store object using an older connection object:\n",
"vector_store = LanceDB(connection=tbl, embedding=embeddings)\n",
"```"
"reranker = LinearCombinationReranker(weight=0.3)\n",
"\n",
"docsearch = LanceDB.from_documents(documents, embeddings, reranker=reranker)\n",
"query = \"What did the president say about Ketanji Brown Jackson\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9c608226",
"execution_count": 31,
"id": "259c7988",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
"relevance score - 0.7066475030191711\n",
"text- They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
"\n",
"Officer Mora was 27 years old. \n",
"\n",
@ -156,62 +165,112 @@
"\n",
"Thats why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers. \n",
"\n",
"Thats why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope. \n",
"\n",
"We should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities. \n",
"\n",
"I ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe. \n",
"\n",
"And I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and cant be traced. \n",
"\n",
"And I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon? \n",
"\n",
"Ban assault weapons and high-capacity magazines. \n",
"\n",
"Repeal the liability shield that makes gun manufacturers the only industry in America that cant be sued. \n",
"Thats why the American Rescue \n"
]
}
],
"source": [
"docs = docsearch.similarity_search_with_relevance_scores(query)\n",
"print(\"relevance score - \", docs[0][1])\n",
"print(\"text- \", docs[0][0].page_content[:1000])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "9fa29dae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"distance - 0.30000001192092896\n",
"text- My administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free. \n",
"\n",
"These laws dont infringe on the Second Amendment. They save lives. \n",
"Our troops in Iraq and Afghanistan faced many dangers. \n",
"\n",
"The most fundamental right in America is the right to vote and to have it counted. And its under assault. \n",
"One was stationed at bases and breathing in toxic smoke from “burn pits” that incinerated wastes of war—medical and hazard material, jet fuel, and more. \n",
"\n",
"In state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections. \n",
"When they came home, many of the worlds fittest and best trained warriors were never the same. \n",
"\n",
"We cannot let this happen. \n",
"Headaches. Numbness. Dizziness. \n",
"\n",
"Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \n",
"A cancer that would put them in a flag-draped coffin. \n",
"\n",
"Tonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \n",
"I know. \n",
"\n",
"One of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \n",
"One of those soldiers was my son Major Beau Biden. \n",
"\n",
"And I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence. \n",
"We dont know for sure if a burn pit was the cause of his brain cancer, or the diseases of so many of our troops. \n",
"\n",
"A former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since shes been nominated, shes received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans. \n",
"But Im committed to finding out everything we can. \n",
"\n",
"And if we are to advance liberty and justice, we need to secure the Border and fix the immigration system. \n",
"Committed to military families like Danielle Robinson from Ohio. \n",
"\n",
"We can do both. At our border, weve installed new technology like cutting-edge scanners to better detect drug smuggling. \n",
"The widow of Sergeant First Class Heath Robinson. \n",
"\n",
"Weve set up joint patrols with Mexico and Guatemala to catch more human traffickers. \n",
"He was born a soldier. Army National Guard. Combat medic in Kosovo and Iraq. \n",
"\n",
"Were putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.\n"
"Stationed near Baghdad, just ya\n"
]
}
],
"source": [
"print(docs[0].page_content)"
"docs = docsearch.similarity_search_with_score(query=\"Headaches\", query_type=\"hybrid\")\n",
"print(\"distance - \", docs[0][1])\n",
"print(\"text- \", docs[0][0].page_content[:1000])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a359ed74",
"execution_count": 8,
"id": "e70ad201",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"reranker : <lancedb.rerankers.linear_combination.LinearCombinationReranker object at 0x107ef1130>\n"
]
}
],
"source": [
"print(\"reranker : \", docsearch._reranker)"
]
},
{
"cell_type": "markdown",
"id": "f5e1cdfd",
"metadata": {},
"source": [
"Additionally, to explore the table you can load it into a DataFrame or save it as a CSV file: \n",
"```python\n",
"tbl = docsearch.get_table()\n",
"print(\"tbl:\", tbl)\n",
"pd_df = tbl.to_pandas()\n",
"# pd_df.to_csv(\"docsearch.csv\", index=False)\n",
"\n",
"# you can also create a new vector store object using an older connection object:\n",
"vector_store = LanceDB(connection=tbl, embedding=embeddings)\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9c608226",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"metadata : {'source': '../../how_to/state_of_the_union.txt'}\n",
"\n",
"SQL filtering :\n",
"\n",
"They were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n",
"\n",
"Officer Mora was 27 years old. \n",
@ -275,17 +334,211 @@
}
],
"source": [
"docs = docsearch.similarity_search(\n",
" query=query, filter={\"metadata.source\": \"../../how_to/state_of_the_union.txt\"}\n",
")\n",
"\n",
"print(\"metadata :\", docs[0].metadata)\n",
"\n",
"# or you can directly supply SQL string filters :\n",
"\n",
"print(\"\\nSQL filtering :\\n\")\n",
"docs = docsearch.similarity_search(query=query, filter=\"text LIKE '%Officer Rivera%'\")\n",
"print(docs[0].page_content)"
]
},
{
"cell_type": "markdown",
"id": "9a173c94",
"metadata": {},
"source": [
"## Adding images "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d749a3f-df17-4a8a-b256-08a3bbc74cb6",
"id": "05f669d7",
"metadata": {},
"outputs": [],
"source": [
"! pip install -U langchain-experimental"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ed69810",
"metadata": {},
"outputs": [],
"source": [
"! pip install open_clip_torch torch"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2cacb5ee",
"metadata": {},
"outputs": [],
"source": [
"! rm -rf '/tmp/multimmodal_lance'"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "b3456e2c",
"metadata": {},
"outputs": [],
"source": [
"from langchain_experimental.open_clip import OpenCLIPEmbeddings"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3848eba2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import requests\n",
"\n",
"# List of image URLs to download\n",
"image_urls = [\n",
" \"https://github.com/raghavdixit99/assets/assets/34462078/abf47cc4-d979-4aaa-83be-53a2115bf318\",\n",
" \"https://github.com/raghavdixit99/assets/assets/34462078/93be928e-522b-4e37-889d-d4efd54b2112\",\n",
"]\n",
"\n",
"texts = [\"bird\", \"dragon\"]\n",
"\n",
"# Directory to save images\n",
"dir_name = \"./photos/\"\n",
"\n",
"# Create directory if it doesn't exist\n",
"os.makedirs(dir_name, exist_ok=True)\n",
"\n",
"image_uris = []\n",
"# Download and save each image\n",
"for i, url in enumerate(image_urls, start=1):\n",
" response = requests.get(url)\n",
" path = os.path.join(dir_name, f\"image{i}.jpg\")\n",
" image_uris.append(path)\n",
" with open(path, \"wb\") as f:\n",
" f.write(response.content)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "3d62c2a0",
"metadata": {},
"outputs": [],
"source": [
"print(docs[0].metadata)"
"from langchain_community.vectorstores import LanceDB\n",
"\n",
"vec_store = LanceDB(\n",
" table_name=\"multimodal_test\",\n",
" embedding=OpenCLIPEmbeddings(),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "ebbb4881",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['b673620b-01f0-42ca-a92e-d033bb92c0a6',\n",
" '99c3a5b0-b577-417a-8177-92f4a655dbfb']"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec_store.add_images(uris=image_uris)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3c29dea3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['f7adde5d-a4a3-402b-9e73-088b230722c3',\n",
" 'cbed59da-0aec-4bff-8820-9e59d81a2140']"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec_store.add_texts(texts)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "8b2f25ce",
"metadata": {},
"outputs": [],
"source": [
"img_embed = vec_store._embedding.embed_query(\"bird\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "87a24079",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='bird', metadata={'id': 'f7adde5d-a4a3-402b-9e73-088b230722c3'})"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec_store.similarity_search_by_vector(img_embed)[0]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "78557867",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LanceTable(connection=LanceDBConnection(/tmp/lancedb), name=\"multimodal_test\")"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec_store._table"
]
}
],
@ -305,7 +558,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.2"
}
},
"nbformat": 4,

@ -1,21 +1,32 @@
from __future__ import annotations
import base64
import os
import uuid
import warnings
from typing import Any, Iterable, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional, Type
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import guard_import
from langchain_core.vectorstores import VectorStore
from langchain_community.vectorstores.utils import maximal_marginal_relevance
DEFAULT_K = 4 # Number of Documents to return.
def import_lancedb() -> Any:
"""Import lancedb package."""
return guard_import("lancedb")
def to_lance_filter(filter: Dict[str, str]) -> str:
"""Converts a dict filter to a LanceDB filter string."""
return " AND ".join([f"{k} = '{v}'" for k, v in filter.items()])
class LanceDB(VectorStore):
"""`LanceDB` vector store.
@ -55,6 +66,11 @@ class LanceDB(VectorStore):
api_key: Optional[str] = None,
region: Optional[str] = None,
mode: Optional[str] = "overwrite",
table: Optional[Any] = None,
distance: Optional[str] = "l2",
reranker: Optional[Any] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
limit: int = DEFAULT_K,
):
"""Initialize with Lance DB vectorstore"""
lancedb = guard_import("lancedb")
@ -62,10 +78,22 @@ class LanceDB(VectorStore):
self._vector_key = vector_key
self._id_key = id_key
self._text_key = text_key
self._table_name = table_name
self.api_key = api_key or os.getenv("LANCE_API_KEY") if api_key != "" else None
self.region = region
self.mode = mode
self.distance = distance
self.override_relevance_score_fn = relevance_score_fn
self.limit = limit
self._fts_index = None
if isinstance(reranker, lancedb.rerankers.Reranker):
self._reranker = reranker
elif reranker is None:
self._reranker = None
else:
raise ValueError(
"`reranker` has to be a lancedb.rerankers.Reranker object."
)
if isinstance(uri, str) and self.api_key is None:
if uri.startswith("db://"):
@ -96,6 +124,52 @@ class LanceDB(VectorStore):
"api key provided with local uri.\
The data will be stored locally"
)
if table is not None:
try:
assert isinstance(
table, (lancedb.db.LanceTable, lancedb.remote.table.RemoteTable)
)
self._table = table
self._table_name = (
table.name if hasattr(table, "name") else "remote_table"
)
except AssertionError:
raise ValueError(
"""`table` has to be a lancedb.db.LanceTable or
lancedb.remote.table.RemoteTable object."""
)
else:
self._table = self.get_table(table_name, set_default=True)
def results_to_docs(self, results: Any, score: bool = False) -> Any:
columns = results.schema.names
if "_distance" in columns:
score_col = "_distance"
elif "_relevance_score" in columns:
score_col = "_relevance_score"
else:
score_col = None
if score_col is None or not score:
return [
Document(
page_content=results[self._text_key][idx].as_py(),
metadata=results["metadata"][idx].as_py(),
)
for idx in range(len(results))
]
elif score_col and score:
return [
(
Document(
page_content=results[self._text_key][idx].as_py(),
metadata=results["metadata"][idx].as_py(),
),
results[score_col][idx].as_py(),
)
for idx in range(len(results))
]
@property
def embeddings(self) -> Optional[Embeddings]:
@ -114,11 +188,11 @@ class LanceDB(VectorStore):
texts: Iterable of strings to add to the vectorstore.
metadatas: Optional list of metadatas associated with the texts.
ids: Optional list of ids to associate with the texts.
ids: Optional list of ids to associate with the texts.
Returns:
List of ids of the added texts.
"""
# Embed texts and create documents
docs = []
ids = ids or [str(uuid.uuid4()) for _ in texts]
embeddings = self._embedding.embed_documents(list(texts)) # type: ignore
@ -134,14 +208,19 @@ class LanceDB(VectorStore):
}
)
if self._table_name in self._connection.table_names():
tbl = self._connection.open_table(self._table_name)
tbl = self.get_table()
if tbl is None:
tbl = self._connection.create_table(self._table_name, data=docs)
self._table = tbl
else:
if self.api_key is None:
tbl.add(docs, mode=self.mode)
else:
tbl.add(docs)
else:
self._connection.create_table(self._table_name, data=docs)
self._fts_index = None
return ids
def get_table(
@ -164,14 +243,18 @@ class LanceDB(VectorStore):
"""
if name is not None:
try:
if set_default:
self._table_name = name
return self._connection.open_table(name)
except Exception:
raise ValueError(f"Table {name} not found in the database")
if set_default:
self._table_name = name
_name = self._table_name
else:
_name = name
else:
return self._connection.open_table(self._table_name)
_name = self._table_name
try:
return self._connection.open_table(_name)
except Exception:
return None
def create_index(
self,
@ -181,6 +264,7 @@ class LanceDB(VectorStore):
num_sub_vectors: Optional[int] = 96,
index_cache_size: Optional[int] = None,
metric: Optional[str] = "L2",
name: Optional[str] = None,
) -> None:
"""
Create a scalar (for non-vector columns) or a vector index on a table.
@ -191,11 +275,15 @@ class LanceDB(VectorStore):
col_name: Provide if you want to create index on a non-vector column.
metric: Provide the metric to use for vector index. Defaults to 'L2'
choice of metrics: 'L2', 'dot', 'cosine'
num_partitions: Number of partitions to use for the index. Defaults to 256.
num_sub_vectors: Number of sub-vectors to use for the index. Defaults to 96.
index_cache_size: Size of the index cache. Defaults to None.
name: Name of the table to create index on. Defaults to None.
Returns:
None
"""
tbl = self.get_table()
tbl = self.get_table(name)
if vector_col:
tbl.create_index(
@ -210,8 +298,205 @@ class LanceDB(VectorStore):
else:
raise ValueError("Provide either vector_col or col_name")
def encode_image(self, uri: str) -> str:
"""Get base64 string from image URI."""
with open(uri, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def add_images(
self,
uris: List[str],
metadatas: Optional[List[dict]] = None,
ids: Optional[List[str]] = None,
**kwargs: Any,
) -> List[str]:
"""Run more images through the embeddings and add to the vectorstore.
Args:
uris (List[str]): File paths to the images.
metadatas (Optional[List[dict]], optional): Optional list of metadatas.
ids (Optional[List[str]], optional): Optional list of IDs.
Returns:
List[str]: List of IDs of the added images.
"""
tbl = self.get_table()
# Map from uris to b64 encoded strings
b64_texts = [self.encode_image(uri=uri) for uri in uris]
# Populate IDs
if ids is None:
ids = [str(uuid.uuid4()) for _ in uris]
embeddings = None
# Set embeddings
if self._embedding is not None and hasattr(self._embedding, "embed_image"):
embeddings = self._embedding.embed_image(uris=uris)
else:
raise ValueError(
"embedding object should be provided and must have embed_image method."
)
data = []
for idx, emb in enumerate(embeddings):
metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
data.append(
{
self._vector_key: emb,
self._id_key: ids[idx],
self._text_key: b64_texts[idx],
"metadata": metadata,
}
)
if tbl is None:
tbl = self._connection.create_table(self._table_name, data=data)
self._table = tbl
else:
tbl.add(data)
return ids
def _query(
self,
query: Any,
k: Optional[int] = None,
filter: Optional[Any] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Any:
if k is None:
k = self.limit
tbl = self.get_table(name)
if isinstance(filter, dict):
filter = to_lance_filter(filter)
prefilter = kwargs.get("prefilter", False)
query_type = kwargs.get("query_type", "vector")
lance_query = (
tbl.search(query=query, vector_column_name=self._vector_key)
.limit(k)
.where(filter, prefilter=prefilter)
)
if query_type == "hybrid" and self._reranker is not None:
lance_query.rerank(reranker=self._reranker)
docs = lance_query.to_arrow()
if len(docs) == 0:
warnings.warn("No results found for the query.")
return docs
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
The 'correct' relevance function
may differ depending on a few things, including:
- the distance / similarity metric used by the VectorStore
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
- embedding dimensionality
- etc.
"""
if self.override_relevance_score_fn:
return self.override_relevance_score_fn
if self.distance == "cosine":
return self._cosine_relevance_score_fn
elif self.distance == "l2":
return self._euclidean_relevance_score_fn
elif self.distance == "ip":
return self._max_inner_product_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
f" for distance metric of type: {self.distance}. "
"Consider providing relevance_score_fn to LanceDB constructor."
)
def similarity_search_by_vector(
self,
embedding: List[float],
k: Optional[int] = None,
filter: Optional[Dict[str, str]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""
Return documents most similar to the query vector.
"""
if k is None:
k = self.limit
res = self._query(embedding, k, filter=filter, name=name, **kwargs)
return self.results_to_docs(res, score=kwargs.pop("score", False))
def similarity_search_by_vector_with_relevance_scores(
self,
embedding: List[float],
k: Optional[int] = None,
filter: Optional[Dict[str, str]] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> Any:
"""
Return documents most similar to the query vector with relevance scores.
"""
if k is None:
k = self.limit
relevance_score_fn = self._select_relevance_score_fn()
docs_and_scores = self.similarity_search_by_vector(
embedding, k, score=True, filter=filter, name=name, **kwargs
)
return [
(doc, relevance_score_fn(float(score))) for doc, score in docs_and_scores
]
def similarity_search_with_score(
self,
query: str,
k: Optional[int] = None,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> Any:
"""Return documents most similar to the query with relevance scores."""
if k is None:
k = self.limit
score = kwargs.get("score", True)
name = kwargs.get("name", None)
query_type = kwargs.get("query_type", "vector")
if self._embedding is None:
raise ValueError("search needs an embedding function to be specified.")
if query_type == "fts" or query_type == "hybrid":
if self.api_key is None and self._fts_index is None:
tbl = self.get_table(name)
self._fts_index = tbl.create_fts_index(self._text_key, replace=True)
if query_type == "hybrid":
embedding = self._embedding.embed_query(query)
_query = (embedding, query)
else:
_query = query # type: ignore
res = self._query(_query, k, filter=filter, name=name, **kwargs)
return self.results_to_docs(res, score=score)
else:
raise NotImplementedError(
"Full-text/hybrid search is not supported in LanceDB Cloud yet."
)
else:
embedding = self._embedding.embed_query(query)
res = self._query(embedding, k, filter=filter, **kwargs)
return self.results_to_docs(res, score=score)
def similarity_search(
self, query: str, k: int = 4, name: Optional[str] = None, **kwargs: Any
self,
query: str,
k: Optional[int] = None,
name: Optional[str] = None,
filter: Optional[Any] = None,
fts: Optional[bool] = False,
**kwargs: Any,
) -> List[Document]:
"""Return documents most similar to the query
@ -227,60 +512,118 @@ class LanceDB(VectorStore):
Returns:
List of documents most similar to the query.
"""
res = self.similarity_search_with_score(
query=query, k=k, name=name, filter=filter, fts=fts, score=False, **kwargs
)
return res
Examples:
def max_marginal_relevance_search(
self,
query: str,
k: Optional[int] = None,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
.. code-block:: python
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
# Retrieve documents with filtering based on a metadata file_type
vector_store.as_retriever(search_kwargs={"k": 4, "filter":{
'sql_filter':"file_type='notice'",
'prefilter': True
}
})
# Retrieve documents with filtering on a specific file name
vector_store.as_retriever(search_kwargs={"k": 4, "filter":{
'sql_filter':"source='my-file.txt'",
'prefilter': True
}
})
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self._embedding.embed_query(query) # type: ignore
tbl = self.get_table(name)
filters = kwargs.pop("filter", {})
sql_filter = filters.pop("sql_filter", None)
prefilter = filters.pop("prefilter", False)
docs = (
tbl.search(embedding, vector_column_name=self._vector_key)
.where(sql_filter, prefilter=prefilter)
.limit(k)
.to_arrow()
)
columns = docs.schema.names
return [
Document(
page_content=docs[self._text_key][idx].as_py(),
metadata={
col: docs[col][idx].as_py()
for col in columns
if col != self._text_key
},
if k is None:
k = self.limit
if self._embedding is None:
raise ValueError(
"For MMR search, you must specify an embedding function on creation."
)
for idx in range(len(docs))
]
embedding = self._embedding.embed_query(query)
docs = self.max_marginal_relevance_search_by_vector(
embedding,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
return docs
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: Optional[int] = None,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
Returns:
List of Documents selected by maximal marginal relevance.
"""
results = self._query(
query=embedding,
k=fetch_k,
filter=filter,
**kwargs,
)
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
results["vector"].to_pylist(),
k=k or self.limit,
lambda_mult=lambda_mult,
)
candidates = self.results_to_docs(results)
selected_results = [r for i, r in enumerate(candidates) if i in mmr_selected]
return selected_results
@classmethod
def from_texts(
cls,
cls: Type[LanceDB],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
connection: Any = None,
connection: Optional[Any] = None,
vector_key: Optional[str] = "vector",
id_key: Optional[str] = "id",
text_key: Optional[str] = "text",
table_name: Optional[str] = "vectorstore",
api_key: Optional[str] = None,
region: Optional[str] = None,
mode: Optional[str] = "overwrite",
distance: Optional[str] = "l2",
reranker: Optional[Any] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
**kwargs: Any,
) -> LanceDB:
instance = LanceDB(
@ -290,8 +633,15 @@ class LanceDB(VectorStore):
id_key=id_key,
text_key=text_key,
table_name=table_name,
api_key=api_key,
region=region,
mode=mode,
distance=distance,
reranker=reranker,
relevance_score_fn=relevance_score_fn,
**kwargs,
)
instance.add_texts(texts, metadatas=metadatas, **kwargs)
instance.add_texts(texts, metadatas=metadatas)
return instance

@ -11,7 +11,7 @@ def import_lancedb() -> Any:
import lancedb
except ImportError as e:
raise ImportError(
"Could not import pinecone lancedb package. "
"Could not import lancedb package. "
"Please install it with `pip install lancedb`."
) from e
return lancedb
@ -56,3 +56,51 @@ def test_lancedb_add_texts() -> None:
result = store.similarity_search("text 2")
result_texts = [doc.page_content for doc in result]
assert "text 2" in result_texts
@pytest.mark.requires("lancedb")
def test_mmr() -> None:
embeddings = FakeEmbeddings()
store = LanceDB(embedding=embeddings)
store.add_texts(["text 1", "text 2", "item 3"])
result = store.max_marginal_relevance_search(query="text")
result_texts = [doc.page_content for doc in result]
assert "text 1" in result_texts
result = store.max_marginal_relevance_search_by_vector(
embeddings.embed_query("text")
)
result_texts = [doc.page_content for doc in result]
assert "text 1" in result_texts
@pytest.mark.requires("lancedb")
def test_lancedb_delete() -> None:
embeddings = FakeEmbeddings()
store = LanceDB(embedding=embeddings)
store.add_texts(["text 1", "text 2", "item 3"])
store.delete(filter="text = 'text 1'")
assert store.get_table().count_rows() == 2
@pytest.mark.requires("lancedb")
def test_lancedb_all_searches() -> None:
embeddings = FakeEmbeddings()
store = LanceDB(embedding=embeddings)
store.add_texts(["text 1", "text 2", "item 3"])
result_1 = store.similarity_search_with_relevance_scores(
"text 1", distance="cosine"
)
assert len(result_1[0]) == 2
assert "text 1" in result_1[0][0].page_content
result_2 = store.similarity_search_by_vector(embeddings.embed_query("text 1"))
assert "text 1" in result_2[0].page_content
result_3 = store.similarity_search_by_vector_with_relevance_scores(
embeddings.embed_query("text 1")
)
assert len(result_3[0]) == 2 # type: ignore
assert "text 1" in result_3[0][0].page_content # type: ignore
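
As background for what `test_mmr` exercises: maximal marginal relevance greedily picks the candidate that maximizes `lambda_mult * sim(query, c) - (1 - lambda_mult) * max sim(c, selected)`. Below is a standalone pure-Python sketch of that selection rule. The names `cosine` and `mmr_select` are illustrative only; this is not the `maximal_marginal_relevance` helper imported in this diff, which may differ in details.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def mmr_select(
    query: Sequence[float],
    candidates: Sequence[Sequence[float]],
    k: int = 4,
    lambda_mult: float = 0.5,
) -> List[int]:
    """Return indices of up to k candidates chosen greedily by MMR."""
    selected: List[int] = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:
        best_idx, best_score = remaining[0], -math.inf
        for i in remaining:
            relevance = cosine(query, candidates[i])
            # Similarity to the closest already-selected item (0.0 if none yet).
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected),
                default=0.0,
            )
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
        remaining.remove(best_idx)
    return selected


# Toy vectors, made up for illustration.
query = [1.0, 0.0]
candidates = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
diverse = mmr_select(query, candidates, k=2, lambda_mult=0.3)
relevant = mmr_select(query, candidates, k=2, lambda_mult=1.0)
print(diverse, relevant)
```

With a low `lambda_mult` the second pick skips the near-duplicate of the first result in favor of the dissimilar vector; with `lambda_mult=1.0` the selection reduces to plain similarity ranking.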
