forked from Archives/langchain
Vector store support for Cassandra (#6426)
This addresses #6291 by adding support for using Cassandra (and compatible databases, such as DataStax Astra DB) as a [Vector Store](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes).

A new class `Cassandra` is introduced, which complies with the contract and interface for a vector store, along with the corresponding integration test, a sample notebook and a modified dependency toml.

Dependencies: the implementation relies on the library `cassio`, which simplifies interacting with Cassandra for ML- and LLM-oriented workloads. CassIO, in turn, uses the `cassandra-driver` low-level driver to communicate with the database. The former is added as an optional dependency (+ in `extended_testing`); the latter was already in the project.

Integration testing relies on a locally-running instance of Cassandra. [Here](https://cassio.org/more_info/#use-a-local-vector-capable-cassandra) a detailed description can be found on how to compile and run it (at the time of writing the feature has not yet made it into a release).

During development of the integration tests, I added a new "fake embedding" class for what I consider a more controlled way of testing the MMR search method. Likewise, I had to amend what looked like a glitch in the behaviour of `ConsistentFakeEmbeddings`, whereby an `embed_query` call would bypass storage of the requested text in the class cache for use in later repeated invocations.

@dev2049 might be the right person to tag here for a review. Thank you!

---------

Co-authored-by: rlm <pexpresss31@gmail.com>
parent
cac6e45a67
commit
22af93d851
@ -0,0 +1,269 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "683953b3",
   "metadata": {},
   "source": [
    "# Cassandra\n",
    "\n",
    ">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n",
    "\n",
    "The newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n",
    "\n",
    "To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c41cad-08ef-4f72-a545-2151e4598efe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install \"cassio>=0.0.5\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e46bb0",
   "metadata": {},
   "source": [
    "### Please provide database connection parameters and secrets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36128a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
    "\n",
    "keyspace_name = input('\\nKeyspace name? ')\n",
    "\n",
    "if database_mode == 'A':\n",
    "    ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
    "    #\n",
    "    ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f22aac2",
   "metadata": {},
   "source": [
    "#### Depending on whether it is a local Cassandra or a cloud-based Astra DB, create the corresponding database connection \"Session\" object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677f8576",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cassandra.cluster import Cluster\n",
    "from cassandra.auth import PlainTextAuthProvider\n",
    "\n",
    "if database_mode == 'L':\n",
    "    cluster = Cluster()\n",
    "    session = cluster.connect()\n",
    "elif database_mode == 'A':\n",
    "    ASTRA_DB_CLIENT_ID = \"token\"\n",
    "    cluster = Cluster(\n",
    "        cloud={\n",
    "            \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n",
    "        },\n",
    "        auth_provider=PlainTextAuthProvider(\n",
    "            ASTRA_DB_CLIENT_ID,\n",
    "            ASTRA_DB_APPLICATION_TOKEN,\n",
    "        ),\n",
    "    )\n",
    "    session = cluster.connect()\n",
    "else:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "320af802-9271-46ee-948f-d2453933d44b",
   "metadata": {},
   "source": [
    "### Please provide your OpenAI access key\n",
    "\n",
    "We want to use `OpenAIEmbeddings`, so we have to get the OpenAI API Key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98a139b",
   "metadata": {},
   "source": [
    "### Creation and usage of the Vector Store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac9563e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import Cassandra\n",
    "from langchain.document_loaders import TextLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c3999a",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = TextLoader('../../../state_of_the_union.txt')\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embedding_function = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e104aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_name = 'my_vector_db_table'\n",
    "\n",
    "docsearch = Cassandra.from_documents(\n",
    "    documents=docs,\n",
    "    embedding=embedding_function,\n",
    "    session=session,\n",
    "    keyspace=keyspace_name,\n",
    "    table_name=table_name,\n",
    ")\n",
    "\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = docsearch.similarity_search(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f509ee02",
   "metadata": {},
   "outputs": [],
   "source": [
    "## if you already have an index, you can load it and use it like this:\n",
    "\n",
    "# docsearch_preexisting = Cassandra(\n",
    "#     embedding=embedding_function,\n",
    "#     session=session,\n",
    "#     keyspace=keyspace_name,\n",
    "#     table_name=table_name,\n",
    "# )\n",
    "\n",
    "# docsearch_preexisting.similarity_search(query, k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c608226",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(docs[0].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d46d1452",
   "metadata": {},
   "source": [
    "### Maximal Marginal Relevance Searches\n",
    "\n",
    "In addition to using similarity search in the retriever object, you can also use `mmr` as the retriever search type.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a359ed74",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = docsearch.as_retriever(search_type=\"mmr\")\n",
    "matched_docs = retriever.get_relevant_documents(query)\n",
    "for i, d in enumerate(matched_docs):\n",
    "    print(f\"\\n## Document {i}\\n\")\n",
    "    print(d.page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c477287",
   "metadata": {},
   "source": [
    "Or use `max_marginal_relevance_search` directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca82740",
   "metadata": {},
   "outputs": [],
   "source": [
    "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n",
    "for i, doc in enumerate(found_docs):\n",
    "    print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@ -0,0 +1,402 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations

import hashlib
import typing
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

import numpy as np

if typing.TYPE_CHECKING:
    from cassandra.cluster import Session

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance

CVST = TypeVar("CVST", bound="Cassandra")

# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None


def _hash(_input: str) -> str:
    """Use a deterministic hashing approach."""
    return hashlib.md5(_input.encode()).hexdigest()


class Cassandra(VectorStore):
    """Wrapper around Cassandra embeddings platform.

    There is no notion of a default table name, since each embedding
    function implies its own vector dimension, which is part of the schema.

    Example:
        .. code-block:: python

                from langchain.vectorstores import Cassandra
                from langchain.embeddings.openai import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                session = ...
                keyspace = 'my_keyspace'
                vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
    """

    _embedding_dimension: int | None

    def _getEmbeddingDimension(self) -> int:
        if self._embedding_dimension is None:
            self._embedding_dimension = len(
                self.embedding.embed_query("This is a sample sentence.")
            )
        return self._embedding_dimension

    def __init__(
        self,
        embedding: Embeddings,
        session: Session,
        keyspace: str,
        table_name: str,
        ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
    ) -> None:
        """Create a vector table."""
        try:
            from cassio.vector import VectorTable
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )
        self.embedding = embedding
        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name
        self.ttl_seconds = ttl_seconds
        #
        self._embedding_dimension = None
        #
        self.table = VectorTable(
            session=session,
            keyspace=keyspace,
            table=table_name,
            embedding_dimension=self._getEmbeddingDimension(),
            auto_id=False,  # the `add_texts` contract admits user-provided ids
        )

    def delete_collection(self) -> None:
        """
        Just an alias for `clear`
        (to better align with other VectorStore implementations).
        """
        self.clear()

    def clear(self) -> None:
        """Empty the collection."""
        self.table.clear()

    def delete_by_document_id(self, document_id: str) -> None:
        return self.table.delete(document_id)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        _texts = list(texts)  # lest it be a generator or something
        if ids is None:
            # unless otherwise specified, we have deterministic IDs:
            # re-inserting an existing document will not create a duplicate
            # (and will effectively update the metadata).
            ids = [_hash(text) for text in _texts]
        if metadatas is None:
            metadatas = [{} for _ in _texts]
        #
        ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
        #
        embedding_vectors = self.embedding.embed_documents(_texts)
        for text, embedding_vector, text_id, metadata in zip(
            _texts, embedding_vectors, ids, metadatas
        ):
            self.table.put(
                document=text,
                embedding_vector=embedding_vector,
                document_id=text_id,
                metadata=metadata,
                ttl_seconds=ttl_seconds,
            )
        #
        return ids

    # id-returning search facilities
    def similarity_search_with_score_id_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
    ) -> List[Tuple[Document, float, str]]:
        """Return docs most similar to embedding vector.

        No support for `filter` query (on metadata) along with vector search.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score, id), the most similar to the query vector.
        """
        hits = self.table.search(
            embedding_vector=embedding,
            top_k=k,
            metric="cos",
            metric_threshold=None,
        )
        # We stick to 'cos' distance as it can be normalized on a 0-1 axis
        # (1 = most relevant), as required by this class' contract.
        return [
            (
                Document(
                    page_content=hit["document"],
                    metadata=hit["metadata"],
                ),
                0.5 + 0.5 * hit["distance"],
                hit["document_id"],
            )
            for hit in hits
        ]

    def similarity_search_with_score_id(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, str]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_id_by_vector(
            embedding=embedding_vector,
            k=k,
        )

    # id-unaware search facilities
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector.

        No support for `filter` query (on metadata) along with vector search.

        Args:
            embedding (List[float]): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score), the most similar to the query vector.
        """
        return [
            (doc, score)
            for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
                embedding=embedding,
                k=k,
            )
        ]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Document]:
        #
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        **kwargs: Any,
    ) -> List[Document]:
        return [
            doc
            for doc, _ in self.similarity_search_with_score_by_vector(
                embedding,
                k,
            )
        ]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding_vector,
            k,
        )

    # Even though this is a `_`-method,
    # it is apparently used by the VectorStore parent class
    # in an exposed method (`similarity_search_with_relevance_scores`),
    # so we implement it here.
    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        return self.similarity_search_with_score(
            query,
            k,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        prefetchHits = self.table.search(
            embedding_vector=embedding,
            top_k=fetch_k,
            metric="cos",
            metric_threshold=None,
        )
        # let the mmr utility pick the *indices* in the above array
        mmrChosenIndices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [pfHit["embedding_vector"] for pfHit in prefetchHits],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmrHits = [
            pfHit
            for pfIndex, pfHit in enumerate(prefetchHits)
            if pfIndex in mmrChosenIndices
        ]
        return [
            Document(
                page_content=hit["document"],
                metadata=hit["metadata"],
            )
            for hit in mmrHits
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to the query AND
        diversity among the selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return.
            fetch_k: Number of Documents to fetch to pass to the MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results, with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Optional.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding_vector = self.embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
        )

    @classmethod
    def from_texts(
        cls: Type[CVST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> CVST:
        """Create a Cassandra vectorstore from raw texts.

        No support for specifying text IDs.

        Returns:
            a Cassandra vectorstore.
        """
        session: Session = kwargs["session"]
        keyspace: str = kwargs["keyspace"]
        table_name: str = kwargs["table_name"]
        cassandraStore = cls(
            embedding=embedding,
            session=session,
            keyspace=keyspace,
            table_name=table_name,
        )
        cassandraStore.add_texts(texts=texts, metadatas=metadatas)
        return cassandraStore

    @classmethod
    def from_documents(
        cls: Type[CVST],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> CVST:
        """Create a Cassandra vectorstore from a document list.

        No support for specifying text IDs.

        Returns:
            a Cassandra vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        session: Session = kwargs["session"]
        keyspace: str = kwargs["keyspace"]
        table_name: str = kwargs["table_name"]
        return cls.from_texts(
            texts=texts,
            metadatas=metadatas,
            embedding=embedding,
            session=session,
            keyspace=keyspace,
            table_name=table_name,
        )
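A note on the scores returned by the search methods above: CassIO is asked for the cosine metric, and the raw value (which lies in [-1, 1]) is mapped onto a 0-1 relevance axis with `0.5 + 0.5 * distance`, 1 meaning most relevant, as the `VectorStore` contract requires. A quick standalone check of that mapping:

```python
def normalized_relevance(cos_sim: float) -> float:
    # Map a cosine similarity in [-1, 1] onto a 0-1 relevance score
    # (1 = most relevant), mirroring `0.5 + 0.5 * hit["distance"]` above.
    return 0.5 + 0.5 * cos_sim


# identical vectors (similarity 1.0) -> relevance 1.0
print(normalized_relevance(1.0))   # 1.0
# orthogonal vectors (similarity 0.0) -> relevance 0.5
print(normalized_relevance(0.0))   # 0.5
# opposite vectors (similarity -1.0) -> relevance 0.0
print(normalized_relevance(-1.0))  # 0.0
```

This monotone mapping preserves the ranking, which is why the score-ordering assertion in the integration test below holds regardless of the normalization.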
@ -0,0 +1,135 @@
"""Test Cassandra functionality."""
from typing import List, Optional, Type

from cassandra.cluster import Cluster

from langchain.docstore.document import Document
from langchain.vectorstores import Cassandra
from tests.integration_tests.vectorstores.fake_embeddings import (
    AngularTwoDimensionalEmbeddings,
    ConsistentFakeEmbeddings,
    Embeddings,
)


def _vectorstore_from_texts(
    texts: List[str],
    metadatas: Optional[List[dict]] = None,
    embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
    drop: bool = True,
) -> Cassandra:
    keyspace = "vector_test_keyspace"
    table_name = "vector_test_table"
    # get db connection
    cluster = Cluster()
    session = cluster.connect()
    # ensure keyspace exists
    session.execute(
        (
            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
        )
    )
    # drop table if required
    if drop:
        session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
    #
    return Cassandra.from_texts(
        texts,
        embedding_class(),
        metadatas=metadatas,
        session=session,
        keyspace=keyspace,
        table_name=table_name,
    )


def test_cassandra() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = _vectorstore_from_texts(texts)
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]


def test_cassandra_with_score() -> None:
    """Test end to end construction and search with scores and IDs."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
    output = docsearch.similarity_search_with_score("foo", k=3)
    docs = [o[0] for o in output]
    scores = [o[1] for o in output]
    assert docs == [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert scores[0] > scores[1] > scores[2]


def test_cassandra_max_marginal_relevance_search() -> None:
    """
    Test end to end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

           ______ v2
          /      \\
         /        \\  v1
    v3  |     .    |  query
         \\        /  v0
          \\______/                 (N.B. very crude drawing)

    With fetch_k == 3 and k == 2, when the query is at (1, 0),
    one expects that v2 and v0 are returned (in some order).
    """
    texts = ["-0.125", "+0.125", "+0.25", "+1.0"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(
        texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
    )
    output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
    output_set = {
        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
    }
    assert output_set == {
        ("+0.25", 2),
        ("-0.125", 0),
    }


def test_cassandra_add_extra() -> None:
    """Test end to end construction with further insertions."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)

    docsearch.add_texts(texts, metadatas)
    texts2 = ["foo2", "bar2", "baz2"]
    docsearch.add_texts(texts2, metadatas)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6


def test_cassandra_no_drop() -> None:
    """Test end to end construction and re-opening the same index."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
    del docsearch

    texts2 = ["foo2", "bar2", "baz2"]
    docsearch = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6


# if __name__ == "__main__":
#     test_cassandra()
#     test_cassandra_with_score()
#     test_cassandra_max_marginal_relevance_search()
#     test_cassandra_add_extra()
#     test_cassandra_no_drop()
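The MMR test above relies on the newly added `AngularTwoDimensionalEmbeddings` fake, whose implementation is not shown in this diff. Judging from the test docstring (texts are numbers placing vectors on a circle, with the query "0.0" landing at (1, 0)), a plausible sketch of such an embedding, assuming angles expressed in units of π and a fallback to angle 0 for non-numeric texts:

```python
import math
from typing import List


class AngularTwoDimensionalEmbeddings:
    """Hypothetical sketch: map a text of the form '<angle>' (in units of pi)
    to a unit vector on the circle; non-numeric texts fall back to angle 0."""

    def embed_query(self, text: str) -> List[float]:
        try:
            angle = float(text)
        except ValueError:
            angle = 0.0
        return [math.cos(angle * math.pi), math.sin(angle * math.pi)]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(text) for text in texts]
```

Under this scheme, `fetch_k=3` prefetches the three vectors nearest the query (v0, v1, v2), and MMR then keeps a diverse pair out of them, matching the `{v2, v0}` assertion in the test.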