From 415d38ae622c340cdf2c38fc7fb44565a72b6903 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Wed, 13 Sep 2023 23:18:39 +0200 Subject: [PATCH] Cassandra Vector Store, add metadata filtering + improvements (#9280) This PR addresses a few minor issues with the Cassandra vector store implementation and extends the store to support Metadata search. Thanks to the latest cassIO library (>=0.1.0), metadata filtering is available in the store. Further, - the "relevance" score is prevented from being flipped in the [0,1] interval, thus ensuring that 1 corresponds to the closest vector (this is related to how the underlying cassIO class returns the cosine difference); - bumped the cassIO package version both in the notebooks and the pyproject.toml; - adjusted the textfile location for the vector-store example after the reshuffling of the Langchain repo dir structure; - added demonstration of metadata filtering in the Cassandra vector store notebook; - better docstring for the Cassandra vector store class; - fixed test flakiness and removed offending out-of-place escape chars from a test module docstring; To my knowledge all relevant tests pass and mypy+black+ruff don't complain. (mypy gives unrelated errors in other modules, which clearly don't depend on the content of this PR). Thank you! Stefano --------- Co-authored-by: Bagatur --- .../cassandra_chat_message_history.ipynb | 4 +- .../integrations/vectorstores/cassandra.ipynb | 55 +++++++++++++- .../langchain/vectorstores/cassandra.py | 74 +++++++++++++++---- .../vectorstores/test_cassandra.py | 18 +++-- 4 files changed, 123 insertions(+), 28 deletions(-) diff --git a/docs/extras/integrations/memory/cassandra_chat_message_history.ipynb b/docs/extras/integrations/memory/cassandra_chat_message_history.ipynb index 65ee1e5e2a..9fa2a6293c 100644 --- a/docs/extras/integrations/memory/cassandra_chat_message_history.ipynb +++ b/docs/extras/integrations/memory/cassandra_chat_message_history.ipynb @@ -23,7 +23,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install \"cassio>=0.0.7\"" + "!pip install \"cassio>=0.1.0\"" ] }, { @@ -155,7 +155,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/extras/integrations/vectorstores/cassandra.ipynb b/docs/extras/integrations/vectorstores/cassandra.ipynb index b689ea74f9..fa8b4a570d 100644 --- a/docs/extras/integrations/vectorstores/cassandra.ipynb +++ b/docs/extras/integrations/vectorstores/cassandra.ipynb @@ -23,7 +23,7 @@ }, "outputs": [], "source": [ - "!pip install \"cassio>=0.0.7\"" + "!pip install \"cassio>=0.1.0\"" ] }, { @@ -152,7 +152,9 @@ "source": [ "from langchain.document_loaders import TextLoader\n", "\n", - "loader = TextLoader(\"../../../state_of_the_union.txt\")\n", + "SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n", + "\n", + "loader = TextLoader(SOURCE_FILE_NAME)\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "docs = text_splitter.split_documents(documents)\n", @@ -197,7 +199,7 @@ "# table_name=table_name,\n", "# )\n", "\n", - "# docsearch_preexisting.similarity_search(query, k=2)" + "# docs = docsearch_preexisting.similarity_search(query, k=2)" ] }, { @@ -253,6 +255,51 @@ "for i, doc in enumerate(found_docs):\n", " print(f\"{i + 1}.\", doc.page_content, \"\\n\")" ] + }, + { + "cell_type": "markdown", + "id": "da791c5f", + "metadata": {}, + "source": [ + "### Metadata filtering\n", + "\n", + "You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n", + "\n", + "Since only one files was inserted, this is just a demonstration of how filters are passed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "93f132fa", + "metadata": {}, + "outputs": [], + "source": [ + "filter = {\"source\": SOURCE_FILE_NAME}\n", + "filtered_docs = docsearch.similarity_search(query, filter=filter, k=5)\n", + "print(f\"{len(filtered_docs)} documents retrieved.\")\n", + "print(f\"{filtered_docs[0].page_content[:64]} ...\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b413ec4", + "metadata": {}, + "outputs": [], + "source": [ + "filter = {\"source\": \"nonexisting_file.txt\"}\n", + "filtered_docs2 = docsearch.similarity_search(query, filter=filter)\n", + "print(f\"{len(filtered_docs2)} documents retrieved.\")" + ] + }, + { + "cell_type": "markdown", + "id": "a0fea764", + "metadata": {}, + "source": [ + "Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with Langchain." + ] } ], "metadata": { @@ -271,7 +318,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/libs/langchain/langchain/vectorstores/cassandra.py b/libs/langchain/langchain/vectorstores/cassandra.py index cc6541b5f7..083f8b90f6 100644 --- a/libs/langchain/langchain/vectorstores/cassandra.py +++ b/libs/langchain/langchain/vectorstores/cassandra.py @@ -2,7 +2,18 @@ from __future__ import annotations import typing import uuid -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) import numpy as np @@ -18,11 +29,12 @@ CVST = TypeVar("CVST", bound="Cassandra") class Cassandra(VectorStore): - """`Cassandra` vector store. + """Wrapper around Apache Cassandra(R) for vector-store workloads. - It based on the Cassandra vector-store capabilities, based on cassIO. - There is no notion of a default table name, since each embedding - function implies its own vector dimension, which is part of the schema. + To use it, you need a recent installation of the `cassio` library + and a Cassandra cluster / Astra DB instance supporting vector capabilities. + + Visit the cassio.org website for extensive quickstarts and code examples. Example: .. code-block:: python @@ -31,12 +43,20 @@ class Cassandra(VectorStore): from langchain.embeddings.openai import OpenAIEmbeddings embeddings = OpenAIEmbeddings() - session = ... - keyspace = 'my_keyspace' - vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive') + session = ... # create your Cassandra session object + keyspace = 'my_keyspace' # the keyspace should exist already + table_name = 'my_vector_store' + vectorstore = Cassandra(embeddings, session, keyspace, table_name) """ - _embedding_dimension: int | None + _embedding_dimension: Union[int, None] + + @staticmethod + def _filter_to_metadata(filter_dict: Optional[Dict[str, str]]) -> Dict[str, Any]: + if filter_dict is None: + return {} + else: + return filter_dict def _get_embedding_dimension(self) -> int: if self._embedding_dimension is None: @@ -81,8 +101,18 @@ class Cassandra(VectorStore): def embeddings(self) -> Embeddings: return self.embedding + @staticmethod + def _dont_flip_the_cos_score(distance: float) -> float: + # the identity + return distance + def _select_relevance_score_fn(self) -> Callable[[float], float]: - return self._cosine_relevance_score_fn + """ + The underlying VectorTable already returns a "score proper", + i.e. one in [0, 1] where higher means more *similar*, + so here the final score transformation is not reversing the interval: + """ + return self._dont_flip_the_cos_score def delete_collection(self) -> None: """ @@ -172,22 +202,24 @@ class Cassandra(VectorStore): self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, str]] = None, ) -> List[Tuple[Document, float, str]]: """Return docs most similar to embedding vector. - No support for `filter` query (on metadata) along with vector search. - Args: embedding (str): Embedding to look up documents similar to. k (int): Number of Documents to return. Defaults to 4. Returns: List of (Document, score, id), the most similar to the query vector. """ + search_metadata = self._filter_to_metadata(filter) + # hits = self.table.search( embedding_vector=embedding, top_k=k, metric="cos", metric_threshold=None, + metadata=search_metadata, ) # We stick to 'cos' distance as it can be normalized on a 0-1 axis # (1=most relevant), as required by this class' contract. @@ -207,11 +239,13 @@ class Cassandra(VectorStore): self, query: str, k: int = 4, + filter: Optional[Dict[str, str]] = None, ) -> List[Tuple[Document, float, str]]: embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_id_by_vector( embedding=embedding_vector, k=k, + filter=filter, ) # id-unaware search facilities @@ -219,11 +253,10 @@ class Cassandra(VectorStore): self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, str]] = None, ) -> List[Tuple[Document, float]]: """Return docs most similar to embedding vector. - No support for `filter` query (on metadata) along with vector search. - Args: embedding (str): Embedding to look up documents similar to. k (int): Number of Documents to return. Defaults to 4. @@ -235,6 +268,7 @@ class Cassandra(VectorStore): for (doc, score, docId) in self.similarity_search_with_score_id_by_vector( embedding=embedding, k=k, + filter=filter, ) ] @@ -242,18 +276,21 @@ class Cassandra(VectorStore): self, query: str, k: int = 4, + filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: embedding_vector = self.embedding.embed_query(query) return self.similarity_search_by_vector( embedding_vector, k, + filter=filter, ) def similarity_search_by_vector( self, embedding: List[float], k: int = 4, + filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: return [ @@ -261,6 +298,7 @@ class Cassandra(VectorStore): for doc, _ in self.similarity_search_with_score_by_vector( embedding, k, + filter=filter, ) ] @@ -268,11 +306,13 @@ class Cassandra(VectorStore): self, query: str, k: int = 4, + filter: Optional[Dict[str, str]] = None, ) -> List[Tuple[Document, float]]: embedding_vector = self.embedding.embed_query(query) return self.similarity_search_with_score_by_vector( embedding_vector, k, + filter=filter, ) def max_marginal_relevance_search_by_vector( @@ -281,6 +321,7 @@ class Cassandra(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -296,11 +337,14 @@ class Cassandra(VectorStore): Returns: List of Documents selected by maximal marginal relevance. """ + search_metadata = self._filter_to_metadata(filter) + prefetchHits = self.table.search( embedding_vector=embedding, top_k=fetch_k, metric="cos", metric_threshold=None, + metadata=search_metadata, ) # let the mmr utility pick the *indices* in the above array mmrChosenIndices = maximal_marginal_relevance( @@ -328,6 +372,7 @@ class Cassandra(VectorStore): k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: """Return docs selected using the maximal marginal relevance. @@ -350,6 +395,7 @@ class Cassandra(VectorStore): k, fetch_k, lambda_mult=lambda_mult, + filter=filter, ) @classmethod diff --git a/libs/langchain/tests/integration_tests/vectorstores/test_cassandra.py b/libs/langchain/tests/integration_tests/vectorstores/test_cassandra.py index 443dd73efc..e0c0403301 100644 --- a/libs/langchain/tests/integration_tests/vectorstores/test_cassandra.py +++ b/libs/langchain/tests/integration_tests/vectorstores/test_cassandra.py @@ -1,4 +1,5 @@ """Test Cassandra functionality.""" +import time from typing import List, Optional, Type from cassandra.cluster import Cluster @@ -61,9 +62,9 @@ def test_cassandra_with_score() -> None: docs = [o[0] for o in output] scores = [o[1] for o in output] assert docs == [ - Document(page_content="foo", metadata={"page": 0}), - Document(page_content="bar", metadata={"page": 1}), - Document(page_content="baz", metadata={"page": 2}), + Document(page_content="foo", metadata={"page": "0.0"}), + Document(page_content="bar", metadata={"page": "1.0"}), + Document(page_content="baz", metadata={"page": "2.0"}), ] assert scores[0] > scores[1] > scores[2] @@ -76,10 +77,10 @@ def test_cassandra_max_marginal_relevance_search() -> None: ______ v2 / \ - / \ v1 + / | v1 v3 | . | query - \ / v0 - \______/ (N.B. very crude drawing) + | / v0 + |______/ (N.B. very crude drawing) With fetch_k==3 and k==2, when query is at (1, ), one expects that v2 and v0 are returned (in some order). @@ -94,8 +95,8 @@ def test_cassandra_max_marginal_relevance_search() -> None: (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output } assert output_set == { - ("+0.25", 2), - ("-0.124", 0), + ("+0.25", "2.0"), + ("-0.124", "0.0"), } @@ -150,6 +151,7 @@ def test_cassandra_delete() -> None: assert len(output) == 1 docsearch.clear() + time.sleep(0.3) output = docsearch.similarity_search("foo", k=10) assert len(output) == 0