Cassandra Vector Store, add metadata filtering + improvements (#9280)

This PR addresses a few minor issues with the Cassandra vector store
implementation and extends the store to support Metadata search.

Thanks to the latest cassIO library (>=0.1.0), metadata filtering is
available in the store.

Further,
- the "relevance" score is prevented from being flipped in the [0,1]
interval, thus ensuring that 1 corresponds to the closest vector (this
is related to how the underlying cassIO class returns the cosine
difference);
- bumped the cassIO package version both in the notebooks and the
pyproject.toml;
- adjusted the textfile location for the vector-store example after the
reshuffling of the Langchain repo dir structure;
- added demonstration of metadata filtering in the Cassandra vector
store notebook;
- better docstring for the Cassandra vector store class;
- fixed test flakiness and removed offending out-of-place escape chars
from a test module docstring;

To my knowledge all relevant tests pass and mypy+black+ruff don't
complain. (mypy gives unrelated errors in other modules, which clearly
don't depend on the content of this PR).

Thank you!
Stefano

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/10594/head
Stefano Lottini 1 year ago committed by GitHub
parent 49694f6a3f
commit 415d38ae62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,7 +23,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.7\""
"!pip install \"cassio>=0.1.0\""
]
},
{
@ -155,7 +155,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,

@ -23,7 +23,7 @@
},
"outputs": [],
"source": [
"!pip install \"cassio>=0.0.7\""
"!pip install \"cassio>=0.1.0\""
]
},
{
@ -152,7 +152,9 @@
"source": [
"from langchain.document_loaders import TextLoader\n",
"\n",
"loader = TextLoader(\"../../../state_of_the_union.txt\")\n",
"SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n",
"\n",
"loader = TextLoader(SOURCE_FILE_NAME)\n",
"documents = loader.load()\n",
"text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
"docs = text_splitter.split_documents(documents)\n",
@ -197,7 +199,7 @@
"# table_name=table_name,\n",
"# )\n",
"\n",
"# docsearch_preexisting.similarity_search(query, k=2)"
"# docs = docsearch_preexisting.similarity_search(query, k=2)"
]
},
{
@ -253,6 +255,51 @@
"for i, doc in enumerate(found_docs):\n",
" print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
]
},
{
"cell_type": "markdown",
"id": "da791c5f",
"metadata": {},
"source": [
"### Metadata filtering\n",
"\n",
"You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n",
"\n",
"Since only one files was inserted, this is just a demonstration of how filters are passed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "93f132fa",
"metadata": {},
"outputs": [],
"source": [
"filter = {\"source\": SOURCE_FILE_NAME}\n",
"filtered_docs = docsearch.similarity_search(query, filter=filter, k=5)\n",
"print(f\"{len(filtered_docs)} documents retrieved.\")\n",
"print(f\"{filtered_docs[0].page_content[:64]} ...\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b413ec4",
"metadata": {},
"outputs": [],
"source": [
"filter = {\"source\": \"nonexisting_file.txt\"}\n",
"filtered_docs2 = docsearch.similarity_search(query, filter=filter)\n",
"print(f\"{len(filtered_docs2)} documents retrieved.\")"
]
},
{
"cell_type": "markdown",
"id": "a0fea764",
"metadata": {},
"source": [
"Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with Langchain."
]
}
],
"metadata": {
@ -271,7 +318,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,

@ -2,7 +2,18 @@ from __future__ import annotations
import typing
import uuid
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import numpy as np
@ -18,11 +29,12 @@ CVST = TypeVar("CVST", bound="Cassandra")
class Cassandra(VectorStore):
"""`Cassandra` vector store.
"""Wrapper around Apache Cassandra(R) for vector-store workloads.
It based on the Cassandra vector-store capabilities, based on cassIO.
There is no notion of a default table name, since each embedding
function implies its own vector dimension, which is part of the schema.
To use it, you need a recent installation of the `cassio` library
and a Cassandra cluster / Astra DB instance supporting vector capabilities.
Visit the cassio.org website for extensive quickstarts and code examples.
Example:
.. code-block:: python
@ -31,12 +43,20 @@ class Cassandra(VectorStore):
from langchain.embeddings.openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
session = ...
keyspace = 'my_keyspace'
vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
session = ... # create your Cassandra session object
keyspace = 'my_keyspace' # the keyspace should exist already
table_name = 'my_vector_store'
vectorstore = Cassandra(embeddings, session, keyspace, table_name)
"""
_embedding_dimension: int | None
_embedding_dimension: Union[int, None]
@staticmethod
def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]:
if filter_dict is None:
return {}
else:
return filter_dict
def _get_embedding_dimension(self) -> int:
if self._embedding_dimension is None:
@ -81,8 +101,18 @@ class Cassandra(VectorStore):
def embeddings(self) -> Embeddings:
return self.embedding
@staticmethod
def _dont_flip_the_cos_score(distance: float) -> float:
# the identity
return distance
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._cosine_relevance_score_fn
"""
The underlying VectorTable already returns a "score proper",
i.e. one in [0, 1] where higher means more *similar*,
so here the final score transformation is not reversing the interval:
"""
return self._dont_flip_the_cos_score
def delete_collection(self) -> None:
"""
@ -172,22 +202,24 @@ class Cassandra(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float, str]]:
"""Return docs most similar to embedding vector.
No support for `filter` query (on metadata) along with vector search.
Args:
embedding (str): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
Returns:
List of (Document, score, id), the most similar to the query vector.
"""
search_metadata = self._filter_to_metadata(filter)
#
hits = self.table.search(
embedding_vector=embedding,
top_k=k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
)
# We stick to 'cos' distance as it can be normalized on a 0-1 axis
# (1=most relevant), as required by this class' contract.
@ -207,11 +239,13 @@ class Cassandra(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float, str]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_id_by_vector(
embedding=embedding_vector,
k=k,
filter=filter,
)
# id-unaware search facilities
@ -219,11 +253,10 @@ class Cassandra(VectorStore):
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector.
No support for `filter` query (on metadata) along with vector search.
Args:
embedding (str): Embedding to look up documents similar to.
k (int): Number of Documents to return. Defaults to 4.
@ -235,6 +268,7 @@ class Cassandra(VectorStore):
for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
embedding=embedding,
k=k,
filter=filter,
)
]
@ -242,18 +276,21 @@ class Cassandra(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_by_vector(
embedding_vector,
k,
filter=filter,
)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
return [
@ -261,6 +298,7 @@ class Cassandra(VectorStore):
for doc, _ in self.similarity_search_with_score_by_vector(
embedding,
k,
filter=filter,
)
]
@ -268,11 +306,13 @@ class Cassandra(VectorStore):
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, str]] = None,
) -> List[Tuple[Document, float]]:
embedding_vector = self.embedding.embed_query(query)
return self.similarity_search_with_score_by_vector(
embedding_vector,
k,
filter=filter,
)
def max_marginal_relevance_search_by_vector(
@ -281,6 +321,7 @@ class Cassandra(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -296,11 +337,14 @@ class Cassandra(VectorStore):
Returns:
List of Documents selected by maximal marginal relevance.
"""
search_metadata = self._filter_to_metadata(filter)
prefetchHits = self.table.search(
embedding_vector=embedding,
top_k=fetch_k,
metric="cos",
metric_threshold=None,
metadata=search_metadata,
)
# let the mmr utility pick the *indices* in the above array
mmrChosenIndices = maximal_marginal_relevance(
@ -328,6 +372,7 @@ class Cassandra(VectorStore):
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, str]] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
@ -350,6 +395,7 @@ class Cassandra(VectorStore):
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
)
@classmethod

@ -1,4 +1,5 @@
"""Test Cassandra functionality."""
import time
from typing import List, Optional, Type
from cassandra.cluster import Cluster
@ -61,9 +62,9 @@ def test_cassandra_with_score() -> None:
docs = [o[0] for o in output]
scores = [o[1] for o in output]
assert docs == [
Document(page_content="foo", metadata={"page": 0}),
Document(page_content="bar", metadata={"page": 1}),
Document(page_content="baz", metadata={"page": 2}),
Document(page_content="foo", metadata={"page": "0.0"}),
Document(page_content="bar", metadata={"page": "1.0"}),
Document(page_content="baz", metadata={"page": "2.0"}),
]
assert scores[0] > scores[1] > scores[2]
@ -76,10 +77,10 @@ def test_cassandra_max_marginal_relevance_search() -> None:
______ v2
/ \
/ \ v1
/ | v1
v3 | . | query
\ / v0
\______/ (N.B. very crude drawing)
| / v0
|______/ (N.B. very crude drawing)
With fetch_k==3 and k==2, when query is at (1, ),
one expects that v2 and v0 are returned (in some order).
@ -94,8 +95,8 @@ def test_cassandra_max_marginal_relevance_search() -> None:
(mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
}
assert output_set == {
("+0.25", 2),
("-0.124", 0),
("+0.25", "2.0"),
("-0.124", "0.0"),
}
@ -150,6 +151,7 @@ def test_cassandra_delete() -> None:
assert len(output) == 1
docsearch.clear()
time.sleep(0.3)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0

Loading…
Cancel
Save