From 22af93d8516a4ecc05e2c814ad5660c0b6427625 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Tue, 20 Jun 2023 19:46:20 +0200 Subject: [PATCH] Vector store support for Cassandra (#6426) This addresses #6291 adding support for using Cassandra (and compatible databases, such as DataStax Astra DB) as a [Vector Store](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes). A new class `Cassandra` is introduced, which complies with the contract and interface for a vector store, along with the corresponding integration test, a sample notebook and modified dependency toml. Dependencies: the implementation relies on the library `cassio`, which simplifies interacting with Cassandra for ML- and LLM-oriented workloads. CassIO, in turn, uses the `cassandra-driver` low-lever drivers to communicate with the database. The former is added as optional dependency (+ in `extended_testing`), the latter was already in the project. Integration testing relies on a locally-running instance of Cassandra. [Here](https://cassio.org/more_info/#use-a-local-vector-capable-cassandra) a detailed description can be found on how to compile and run it (at the time of writing the feature has not made it yet to a release). During development of the integration tests, I added a new "fake embedding" class for what I consider a more controlled way of testing the MMR search method. Likewise, I had to amend what looked like a glitch in the behaviour of `ConsistentFakeEmbeddings` whereby an `embed_query` call would have bypassed storage of the requested text in the class cache for use in later repeated invocations. @dev2049 might be the right person to tag here for a review. Thank you! --------- Co-authored-by: rlm --- .../vectorstores/integrations/cassandra.ipynb | 269 ++++++++++++ langchain/vectorstores/__init__.py | 2 + langchain/vectorstores/cassandra.py | 402 ++++++++++++++++++ .../vectorstores/fake_embeddings.py | 27 ++ .../vectorstores/test_cassandra.py | 135 ++++++ 5 files changed, 835 insertions(+) create mode 100644 docs/extras/modules/data_connection/vectorstores/integrations/cassandra.ipynb create mode 100644 langchain/vectorstores/cassandra.py create mode 100644 tests/integration_tests/vectorstores/test_cassandra.py diff --git a/docs/extras/modules/data_connection/vectorstores/integrations/cassandra.ipynb b/docs/extras/modules/data_connection/vectorstores/integrations/cassandra.ipynb new file mode 100644 index 0000000000..406caa548f --- /dev/null +++ b/docs/extras/modules/data_connection/vectorstores/integrations/cassandra.ipynb @@ -0,0 +1,269 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "683953b3", + "metadata": {}, + "source": [ + "# Cassandra\n", + "\n", + ">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n", + "\n", + "Newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n", + "\n", + "To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4c41cad-08ef-4f72-a545-2151e4598efe", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install \"cassio>=0.0.5\"" + ] + }, + { + "cell_type": "markdown", + "id": "b7e46bb0", + "metadata": {}, + "source": [ + "### Please provide database connection parameters and secrets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36128a32", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import getpass\n", + "\n", + "database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n", + "\n", + "keyspace_name = input('\\nKeyspace name? ')\n", + "\n", + "if database_mode == 'A':\n", + " ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n", + " #\n", + " ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')" + ] + }, + { + "cell_type": "markdown", + "id": "4f22aac2", + "metadata": {}, + "source": [ + "#### depending on whether local or cloud-based Astra DB, create the corresponding database connection \"Session\" object" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "677f8576", + "metadata": {}, + "outputs": [], + "source": [ + "from cassandra.cluster import Cluster\n", + "from cassandra.auth import PlainTextAuthProvider\n", + "\n", + "if database_mode == 'L':\n", + " cluster = Cluster()\n", + " session = cluster.connect()\n", + "elif database_mode == 'A':\n", + " ASTRA_DB_CLIENT_ID = \"token\"\n", + " cluster = Cluster(\n", + " cloud={\n", + " \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n", + " },\n", + " auth_provider=PlainTextAuthProvider(\n", + " ASTRA_DB_CLIENT_ID,\n", + " ASTRA_DB_APPLICATION_TOKEN,\n", + " ),\n", + " )\n", + " session = cluster.connect()\n", + "else:\n", + " raise NotImplementedError" + ] + }, + { + "cell_type": "markdown", + "id": "320af802-9271-46ee-948f-d2453933d44b", + "metadata": {}, + "source": [ + "### Please provide OpenAI access key\n", + "\n", + "We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7", + "metadata": {}, + "outputs": [], + "source": [ + "os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')" + ] + }, + { + "cell_type": "markdown", + "id": "e98a139b", + "metadata": {}, + "source": [ + "### Creation and usage of the Vector Store" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac9563e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from langchain.embeddings.openai import OpenAIEmbeddings\n", + "from langchain.text_splitter import CharacterTextSplitter\n", + "from langchain.vectorstores import Cassandra\n", + "from langchain.document_loaders import TextLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3c3999a", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.document_loaders import TextLoader\n", + "loader = TextLoader('../../../state_of_the_union.txt')\n", + "documents = loader.load()\n", + "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", + "docs = text_splitter.split_documents(documents)\n", + "\n", + "embedding_function = OpenAIEmbeddings()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e104aee", + "metadata": {}, + "outputs": [], + "source": [ + "table_name = 'my_vector_db_table'\n", + "\n", + "docsearch = Cassandra.from_documents(\n", + " documents=docs,\n", + " embedding=embedding_function,\n", + " session=session,\n", + " keyspace=keyspace_name,\n", + " table_name=table_name,\n", + ")\n", + "\n", + "query = \"What did the president say about Ketanji Brown Jackson\"\n", + "docs = docsearch.similarity_search(query)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f509ee02", + "metadata": {}, + "outputs": [], + "source": [ + "## if you already have an index, you can load it and use it like this:\n", + "\n", + "# docsearch_preexisting = Cassandra(\n", + "# embedding=embedding_function,\n", + "# session=session,\n", + "# keyspace=keyspace_name,\n", + "# table_name=table_name,\n", + "# )\n", + "\n", + "# docsearch_preexisting.similarity_search(query, k=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c608226", + "metadata": {}, + "outputs": [], + "source": [ + "print(docs[0].page_content)" + ] + }, + { + "cell_type": "markdown", + "id": "d46d1452", + "metadata": {}, + "source": [ + "### Maximal Marginal Relevance Searches\n", + "\n", + "In addition to using similarity search in the retriever object, you can also use `mmr` as retriever.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a359ed74", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = docsearch.as_retriever(search_type=\"mmr\")\n", + "matched_docs = retriever.get_relevant_documents(query)\n", + "for i, d in enumerate(matched_docs):\n", + " print(f\"\\n## Document {i}\\n\")\n", + " print(d.page_content)" + ] + }, + { + "cell_type": "markdown", + "id": "7c477287", + "metadata": {}, + "source": [ + "Or use `max_marginal_relevance_search` directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ca82740", + "metadata": {}, + "outputs": [], + "source": [ + "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n", + "for i, doc in enumerate(found_docs):\n", + " print(f\"{i + 1}.\", doc.page_content, \"\\n\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/vectorstores/__init__.py b/langchain/vectorstores/__init__.py index 2ca7f7a31c..e0f37135fc 100644 --- a/langchain/vectorstores/__init__.py +++ b/langchain/vectorstores/__init__.py @@ -9,6 +9,7 @@ from langchain.vectorstores.atlas import AtlasDB from langchain.vectorstores.awadb import AwaDB from langchain.vectorstores.azuresearch import AzureSearch from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.cassandra import Cassandra from langchain.vectorstores.chroma import Chroma from langchain.vectorstores.clickhouse import Clickhouse, ClickhouseSettings from langchain.vectorstores.deeplake import DeepLake @@ -43,6 +44,7 @@ __all__ = [ "AtlasDB", "AwaDB", "AzureSearch", + "Cassandra", "Chroma", "Clickhouse", "ClickhouseSettings", diff --git a/langchain/vectorstores/cassandra.py b/langchain/vectorstores/cassandra.py new file mode 100644 index 0000000000..4c861d2429 --- /dev/null +++ b/langchain/vectorstores/cassandra.py @@ -0,0 +1,402 @@ +"""Wrapper around Cassandra vector-store capabilities, based on cassIO.""" +from __future__ import annotations + +import hashlib +import typing +from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar + +import numpy as np + +if typing.TYPE_CHECKING: + from cassandra.cluster import Session + +from langchain.docstore.document import Document +from langchain.embeddings.base import Embeddings +from langchain.vectorstores.base import VectorStore +from langchain.vectorstores.utils import maximal_marginal_relevance + +CVST = TypeVar("CVST", bound="Cassandra") + +# a positive number of seconds to expire entries, or None for no expiration. +CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None + + +def _hash(_input: str) -> str: + """Use a deterministic hashing approach.""" + return hashlib.md5(_input.encode()).hexdigest() + + +class Cassandra(VectorStore): + """Wrapper around Cassandra embeddings platform. + + There is no notion of a default table name, since each embedding + function implies its own vector dimension, which is part of the schema. + + Example: + .. code-block:: python + + from langchain.vectorstores import Cassandra + from langchain.embeddings.openai import OpenAIEmbeddings + + embeddings = OpenAIEmbeddings() + session = ... + keyspace = 'my_keyspace' + vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive') + """ + + _embedding_dimension: int | None + + def _getEmbeddingDimension(self) -> int: + if self._embedding_dimension is None: + self._embedding_dimension = len( + self.embedding.embed_query("This is a sample sentence.") + ) + return self._embedding_dimension + + def __init__( + self, + embedding: Embeddings, + session: Session, + keyspace: str, + table_name: str, + ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS, + ) -> None: + try: + from cassio.vector import VectorTable + except (ImportError, ModuleNotFoundError): + raise ImportError( + "Could not import cassio python package. " + "Please install it with `pip install cassio`." + ) + """Create a vector table.""" + self.embedding = embedding + self.session = session + self.keyspace = keyspace + self.table_name = table_name + self.ttl_seconds = ttl_seconds + # + self._embedding_dimension = None + # + self.table = VectorTable( + session=session, + keyspace=keyspace, + table=table_name, + embedding_dimension=self._getEmbeddingDimension(), + auto_id=False, # the `add_texts` contract admits user-provided ids + ) + + def delete_collection(self) -> None: + """ + Just an alias for `clear` + (to better align with other VectorStore implementations). + """ + self.clear() + + def clear(self) -> None: + """Empty the collection.""" + self.table.clear() + + def delete_by_document_id(self, document_id: str) -> None: + return self.table.delete(document_id) + + def add_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts (Iterable[str]): Texts to add to the vectorstore. + metadatas (Optional[List[dict]], optional): Optional list of metadatas. + ids (Optional[List[str]], optional): Optional list of IDs. + + Returns: + List[str]: List of IDs of the added texts. + """ + _texts = list(texts) # lest it be a generator or something + if ids is None: + # unless otherwise specified, we have deterministic IDs: + # re-inserting an existing document will not create a duplicate. + # (and effectively update the metadata) + ids = [_hash(text) for text in _texts] + if metadatas is None: + metadatas = [{} for _ in _texts] + # + ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds) + # + embedding_vectors = self.embedding.embed_documents(_texts) + for text, embedding_vector, text_id, metadata in zip( + _texts, embedding_vectors, ids, metadatas + ): + self.table.put( + document=text, + embedding_vector=embedding_vector, + document_id=text_id, + metadata=metadata, + ttl_seconds=ttl_seconds, + ) + # + return ids + + # id-returning search facilities + def similarity_search_with_score_id_by_vector( + self, + embedding: List[float], + k: int = 4, + ) -> List[Tuple[Document, float, str]]: + """Return docs most similar to embedding vector. + + No support for `filter` query (on metadata) along with vector search. + + Args: + embedding (str): Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + Returns: + List of (Document, score, id), the most similar to the query vector. + """ + hits = self.table.search( + embedding_vector=embedding, + top_k=k, + metric="cos", + metric_threshold=None, + ) + # We stick to 'cos' distance as it can be normalized on a 0-1 axis + # (1=most relevant), as required by this class' contract. + return [ + ( + Document( + page_content=hit["document"], + metadata=hit["metadata"], + ), + 0.5 + 0.5 * hit["distance"], + hit["document_id"], + ) + for hit in hits + ] + + def similarity_search_with_score_id( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float, str]]: + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_with_score_id_by_vector( + embedding=embedding_vector, + k=k, + ) + + # id-unaware search facilities + def similarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to embedding vector. + + No support for `filter` query (on metadata) along with vector search. + + Args: + embedding (str): Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + Returns: + List of (Document, score), the most similar to the query vector. + """ + return [ + (doc, score) + for (doc, score, docId) in self.similarity_search_with_score_id_by_vector( + embedding=embedding, + k=k, + ) + ] + + def similarity_search( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Document]: + # + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k, + **kwargs, + ) + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + **kwargs: Any, + ) -> List[Document]: + return [ + doc + for doc, _ in self.similarity_search_with_score_by_vector( + embedding, + k, + ) + ] + + def similarity_search_with_score( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + embedding_vector = self.embedding.embed_query(query) + return self.similarity_search_with_score_by_vector( + embedding_vector, + k, + ) + + # Even though this is a `_`-method, + # it is apparently used by VectorSearch parent class + # in an exposed method (`similarity_search_with_relevance_scores`). + # So we implement it (hmm). + def _similarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + return self.similarity_search_with_score( + query, + k, + **kwargs, + ) + + def max_marginal_relevance_search_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Returns: + List of Documents selected by maximal marginal relevance. + """ + prefetchHits = self.table.search( + embedding_vector=embedding, + top_k=fetch_k, + metric="cos", + metric_threshold=None, + ) + # let the mmr utility pick the *indices* in the above array + mmrChosenIndices = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + [pfHit["embedding_vector"] for pfHit in prefetchHits], + k=k, + lambda_mult=lambda_mult, + ) + mmrHits = [ + pfHit + for pfIndex, pfHit in enumerate(prefetchHits) + if pfIndex in mmrChosenIndices + ] + return [ + Document( + page_content=hit["document"], + metadata=hit["metadata"], + ) + for hit in mmrHits + ] + + def max_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. + fetch_k: Number of Documents to fetch to pass to MMR algorithm. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Optional. + Returns: + List of Documents selected by maximal marginal relevance. + """ + embedding_vector = self.embedding.embed_query(query) + return self.max_marginal_relevance_search_by_vector( + embedding_vector, + k, + fetch_k, + lambda_mult=lambda_mult, + ) + + @classmethod + def from_texts( + cls: Type[CVST], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + **kwargs: Any, + ) -> CVST: + """Create a Cassandra vectorstore from raw texts. + + No support for specifying text IDs + + Returns: + a Cassandra vectorstore. + """ + session: Session = kwargs["session"] + keyspace: str = kwargs["keyspace"] + table_name: str = kwargs["table_name"] + cassandraStore = cls( + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ) + cassandraStore.add_texts(texts=texts, metadatas=metadatas) + return cassandraStore + + @classmethod + def from_documents( + cls: Type[CVST], + documents: List[Document], + embedding: Embeddings, + **kwargs: Any, + ) -> CVST: + """Create a Cassandra vectorstore from a document list. + + No support for specifying text IDs + + Returns: + a Cassandra vectorstore. + """ + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + session: Session = kwargs["session"] + keyspace: str = kwargs["keyspace"] + table_name: str = kwargs["table_name"] + return cls.from_texts( + texts=texts, + metadatas=metadatas, + embedding=embedding, + session=session, + keyspace=keyspace, + table_name=table_name, + ) diff --git a/tests/integration_tests/vectorstores/fake_embeddings.py b/tests/integration_tests/vectorstores/fake_embeddings.py index d6914e8aaa..c818b35dce 100644 --- a/tests/integration_tests/vectorstores/fake_embeddings.py +++ b/tests/integration_tests/vectorstores/fake_embeddings.py @@ -1,4 +1,5 @@ """Fake Embedding class for testing purposes.""" +import math from typing import List from langchain.embeddings.base import Embeddings @@ -45,3 +46,29 @@ class ConsistentFakeEmbeddings(FakeEmbeddings): if text not in self.known_texts: return [float(1.0)] * 9 + [float(0.0)] return [float(1.0)] * 9 + [float(self.known_texts.index(text))] + + +class AngularTwoDimensionalEmbeddings(Embeddings): + """ + From angles (as strings in units of pi) to unit embedding vectors on a circle. + """ + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """ + Make a list of texts into a list of embedding vectors. + """ + return [self.embed_query(text) for text in texts] + + def embed_query(self, text: str) -> List[float]: + """ + Convert input text to a 'vector' (list of floats). + If the text is a number, use it as the angle for the + unit vector in units of pi. + Any other input text becomes the singular result [0, 0] ! + """ + try: + angle = float(text) + return [math.cos(angle * math.pi), math.sin(angle * math.pi)] + except ValueError: + # Assume: just test string, no attention is paid to values. + return [0.0, 0.0] diff --git a/tests/integration_tests/vectorstores/test_cassandra.py b/tests/integration_tests/vectorstores/test_cassandra.py new file mode 100644 index 0000000000..775547e079 --- /dev/null +++ b/tests/integration_tests/vectorstores/test_cassandra.py @@ -0,0 +1,135 @@ +"""Test Cassandra functionality.""" +from typing import List, Optional, Type + +from cassandra.cluster import Cluster + +from langchain.docstore.document import Document +from langchain.vectorstores import Cassandra +from tests.integration_tests.vectorstores.fake_embeddings import ( + AngularTwoDimensionalEmbeddings, + ConsistentFakeEmbeddings, + Embeddings, +) + + +def _vectorstore_from_texts( + texts: List[str], + metadatas: Optional[List[dict]] = None, + embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings, + drop: bool = True, +) -> Cassandra: + keyspace = "vector_test_keyspace" + table_name = "vector_test_table" + # get db connection + cluster = Cluster() + session = cluster.connect() + # ensure keyspace exists + session.execute( + ( + f"CREATE KEYSPACE IF NOT EXISTS {keyspace} " + f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}" + ) + ) + # drop table if required + if drop: + session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}") + # + return Cassandra.from_texts( + texts, + embedding_class(), + metadatas=metadatas, + session=session, + keyspace=keyspace, + table_name=table_name, + ) + + +def test_cassandra() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = _vectorstore_from_texts(texts) + output = docsearch.similarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + +def test_cassandra_with_score() -> None: + """Test end to end construction and search with scores and IDs.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) + output = docsearch.similarity_search_with_score("foo", k=3) + docs = [o[0] for o in output] + scores = [o[1] for o in output] + assert docs == [ + Document(page_content="foo", metadata={"page": 0}), + Document(page_content="bar", metadata={"page": 1}), + Document(page_content="baz", metadata={"page": 2}), + ] + assert scores[0] > scores[1] > scores[2] + + +def test_cassandra_max_marginal_relevance_search() -> None: + """ + Test end to end construction and MMR search. + The embedding function used here ensures `texts` become + the following vectors on a circle (numbered v0 through v3): + + ______ v2 + / \ + / \ v1 + v3 | . | query + \ / v0 + \______/ (N.B. very crude drawing) + + With fetch_k==3 and k==2, when query is at (1, ), + one expects that v2 and v0 are returned (in some order). + """ + texts = ["-0.125", "+0.125", "+0.25", "+1.0"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _vectorstore_from_texts( + texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings + ) + output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3) + output_set = { + (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output + } + assert output_set == { + ("+0.25", 2), + ("-0.125", 0), + } + + +def test_cassandra_add_extra() -> None: + """Test end to end construction with further insertions.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) + + docsearch.add_texts(texts, metadatas) + texts2 = ["foo2", "bar2", "baz2"] + docsearch.add_texts(texts2, metadatas) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +def test_cassandra_no_drop() -> None: + """Test end to end construction and re-opening the same index.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": i} for i in range(len(texts))] + docsearch = _vectorstore_from_texts(texts, metadatas=metadatas) + del docsearch + + texts2 = ["foo2", "bar2", "baz2"] + docsearch = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False) + + output = docsearch.similarity_search("foo", k=10) + assert len(output) == 6 + + +# if __name__ == "__main__": +# test_cassandra() +# test_cassandra_with_score() +# test_cassandra_max_marginal_relevance_search() +# test_cassandra_add_extra() +# test_cassandra_no_drop()