Mirror of https://github.com/hwchase17/langchain, synced 2024-11-06 03:20:49 +00:00
Vector store support for Cassandra (#6426)
This addresses #6291 by adding support for using Cassandra (and compatible databases, such as DataStax Astra DB) as a [Vector Store](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes).

A new class `Cassandra` is introduced, which complies with the contract and interface for a vector store, along with the corresponding integration test, a sample notebook and a modified dependency toml.

Dependencies: the implementation relies on the library `cassio`, which simplifies interacting with Cassandra for ML- and LLM-oriented workloads. CassIO, in turn, uses the `cassandra-driver` low-level driver to communicate with the database. The former is added as an optional dependency (and to `extended_testing`); the latter was already in the project.

Integration testing relies on a locally running instance of Cassandra. A detailed description of how to compile and run it can be found [here](https://cassio.org/more_info/#use-a-local-vector-capable-cassandra) (at the time of writing, the feature has not yet made it into a release).

During development of the integration tests, I added a new "fake embedding" class for what I consider a more controlled way of testing the MMR search method. Likewise, I amended what looked like a glitch in the behaviour of `ConsistentFakeEmbeddings`, whereby an `embed_query` call would bypass storage of the requested text in the class cache for use in later repeated invocations.

@dev2049 might be the right person to tag here for a review. Thank you!

---------

Co-authored-by: rlm <pexpresss31@gmail.com>
This commit is contained in:
parent cac6e45a67, commit 22af93d851
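For orientation before the diff itself, here is a minimal usage sketch of the new vector store, condensed from the notebook and tests included in this commit (not itself part of the diff). It assumes a locally running, vector-search-capable Cassandra node and an already-existing keyspace; the keyspace and table names below are placeholders.

from cassandra.cluster import Cluster
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Cassandra

# Connect to a local Cassandra cluster with vector-search support.
session = Cluster().connect()

# Build the store from raw texts; session, keyspace and table_name are passed as kwargs.
store = Cassandra.from_texts(
    ["foo", "bar", "baz"],
    OpenAIEmbeddings(),
    session=session,
    keyspace="my_keyspace",        # placeholder: the keyspace must already exist
    table_name="my_vector_table",  # placeholder table name
)

print(store.similarity_search("foo", k=1))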
@@ -0,0 +1,269 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "683953b3",
   "metadata": {},
   "source": [
    "# Cassandra\n",
    "\n",
    ">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n",
    "\n",
    "Newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n",
    "\n",
    "To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c41cad-08ef-4f72-a545-2151e4598efe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install \"cassio>=0.0.5\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e46bb0",
   "metadata": {},
   "source": [
    "### Please provide database connection parameters and secrets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36128a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "database_mode = (input('\\n(L)ocal Cassandra or (A)stra DB? ')).upper()\n",
    "\n",
    "keyspace_name = input('\\nKeyspace name? ')\n",
    "\n",
    "if database_mode == 'A':\n",
    "    ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
    "    #\n",
    "    ASTRA_DB_SECURE_BUNDLE_PATH = input('Full path to your Secure Connect Bundle? ')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f22aac2",
   "metadata": {},
   "source": [
    "#### Depending on whether you run a local Cassandra or a cloud-based Astra DB, create the corresponding database connection \"Session\" object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677f8576",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cassandra.cluster import Cluster\n",
    "from cassandra.auth import PlainTextAuthProvider\n",
    "\n",
    "if database_mode == 'L':\n",
    "    cluster = Cluster()\n",
    "    session = cluster.connect()\n",
    "elif database_mode == 'A':\n",
    "    ASTRA_DB_CLIENT_ID = \"token\"\n",
    "    cluster = Cluster(\n",
    "        cloud={\n",
    "            \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n",
    "        },\n",
    "        auth_provider=PlainTextAuthProvider(\n",
    "            ASTRA_DB_CLIENT_ID,\n",
    "            ASTRA_DB_APPLICATION_TOKEN,\n",
    "        ),\n",
    "    )\n",
    "    session = cluster.connect()\n",
    "else:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "320af802-9271-46ee-948f-d2453933d44b",
   "metadata": {},
   "source": [
    "### Please provide OpenAI access key\n",
    "\n",
    "We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['OPENAI_API_KEY'] = getpass.getpass('OpenAI API Key:')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98a139b",
   "metadata": {},
   "source": [
    "### Creation and usage of the Vector Store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac9563e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import Cassandra\n",
    "from langchain.document_loaders import TextLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c3999a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import TextLoader\n",
    "loader = TextLoader('../../../state_of_the_union.txt')\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embedding_function = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e104aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_name = 'my_vector_db_table'\n",
    "\n",
    "docsearch = Cassandra.from_documents(\n",
    "    documents=docs,\n",
    "    embedding=embedding_function,\n",
    "    session=session,\n",
    "    keyspace=keyspace_name,\n",
    "    table_name=table_name,\n",
    ")\n",
    "\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = docsearch.similarity_search(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f509ee02",
   "metadata": {},
   "outputs": [],
   "source": [
    "## if you already have an index, you can load it and use it like this:\n",
    "\n",
    "# docsearch_preexisting = Cassandra(\n",
    "# embedding=embedding_function,\n",
    "# session=session,\n",
    "# keyspace=keyspace_name,\n",
    "# table_name=table_name,\n",
    "# )\n",
    "\n",
    "# docsearch_preexisting.similarity_search(query, k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c608226",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(docs[0].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d46d1452",
   "metadata": {},
   "source": [
    "### Maximal Marginal Relevance Searches\n",
    "\n",
    "In addition to using similarity search in the retriever object, you can also use `mmr` as the retriever's search type.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a359ed74",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = docsearch.as_retriever(search_type=\"mmr\")\n",
    "matched_docs = retriever.get_relevant_documents(query)\n",
    "for i, d in enumerate(matched_docs):\n",
    "    print(f\"\\n## Document {i}\\n\")\n",
    "    print(d.page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c477287",
   "metadata": {},
   "source": [
    "Or use `max_marginal_relevance_search` directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca82740",
   "metadata": {},
   "outputs": [],
   "source": [
    "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n",
    "for i, doc in enumerate(found_docs):\n",
    "    print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
langchain/vectorstores/__init__.py
@@ -9,6 +9,7 @@ from langchain.vectorstores.atlas import AtlasDB
 from langchain.vectorstores.awadb import AwaDB
 from langchain.vectorstores.azuresearch import AzureSearch
 from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.cassandra import Cassandra
 from langchain.vectorstores.chroma import Chroma
 from langchain.vectorstores.clickhouse import Clickhouse, ClickhouseSettings
 from langchain.vectorstores.deeplake import DeepLake
@@ -43,6 +44,7 @@ __all__ = [
     "AtlasDB",
     "AwaDB",
     "AzureSearch",
+    "Cassandra",
     "Chroma",
     "Clickhouse",
     "ClickhouseSettings",

langchain/vectorstores/cassandra.py (new file, 402 lines)
@@ -0,0 +1,402 @@
"""Wrapper around Cassandra vector-store capabilities, based on cassIO."""
from __future__ import annotations

import hashlib
import typing
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar

import numpy as np

if typing.TYPE_CHECKING:
    from cassandra.cluster import Session

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance

CVST = TypeVar("CVST", bound="Cassandra")

# a positive number of seconds to expire entries, or None for no expiration.
CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS = None


def _hash(_input: str) -> str:
    """Use a deterministic hashing approach."""
    return hashlib.md5(_input.encode()).hexdigest()


class Cassandra(VectorStore):
    """Wrapper around Cassandra embeddings platform.

    There is no notion of a default table name, since each embedding
    function implies its own vector dimension, which is part of the schema.

    Example:
        .. code-block:: python

                from langchain.vectorstores import Cassandra
                from langchain.embeddings.openai import OpenAIEmbeddings

                embeddings = OpenAIEmbeddings()
                session = ...
                keyspace = 'my_keyspace'
                vectorstore = Cassandra(embeddings, session, keyspace, 'my_doc_archive')
    """

    _embedding_dimension: int | None

    def _getEmbeddingDimension(self) -> int:
        if self._embedding_dimension is None:
            self._embedding_dimension = len(
                self.embedding.embed_query("This is a sample sentence.")
            )
        return self._embedding_dimension

    def __init__(
        self,
        embedding: Embeddings,
        session: Session,
        keyspace: str,
        table_name: str,
        ttl_seconds: int | None = CASSANDRA_VECTORSTORE_DEFAULT_TTL_SECONDS,
    ) -> None:
        try:
            from cassio.vector import VectorTable
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Could not import cassio python package. "
                "Please install it with `pip install cassio`."
            )
        """Create a vector table."""
        self.embedding = embedding
        self.session = session
        self.keyspace = keyspace
        self.table_name = table_name
        self.ttl_seconds = ttl_seconds
        #
        self._embedding_dimension = None
        #
        self.table = VectorTable(
            session=session,
            keyspace=keyspace,
            table=table_name,
            embedding_dimension=self._getEmbeddingDimension(),
            auto_id=False,  # the `add_texts` contract admits user-provided ids
        )

    def delete_collection(self) -> None:
        """
        Just an alias for `clear`
        (to better align with other VectorStore implementations).
        """
        self.clear()

    def clear(self) -> None:
        """Empty the collection."""
        self.table.clear()

    def delete_by_document_id(self, document_id: str) -> None:
        return self.table.delete(document_id)

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts (Iterable[str]): Texts to add to the vectorstore.
            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
            ids (Optional[List[str]], optional): Optional list of IDs.

        Returns:
            List[str]: List of IDs of the added texts.
        """
        _texts = list(texts)  # lest it be a generator or something
        if ids is None:
            # unless otherwise specified, we have deterministic IDs:
            # re-inserting an existing document will not create a duplicate.
            # (and effectively update the metadata)
            ids = [_hash(text) for text in _texts]
        if metadatas is None:
            metadatas = [{} for _ in _texts]
        #
        ttl_seconds = kwargs.get("ttl_seconds", self.ttl_seconds)
        #
        embedding_vectors = self.embedding.embed_documents(_texts)
        for text, embedding_vector, text_id, metadata in zip(
            _texts, embedding_vectors, ids, metadatas
        ):
            self.table.put(
                document=text,
                embedding_vector=embedding_vector,
                document_id=text_id,
                metadata=metadata,
                ttl_seconds=ttl_seconds,
            )
        #
        return ids

    # id-returning search facilities
    def similarity_search_with_score_id_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
    ) -> List[Tuple[Document, float, str]]:
        """Return docs most similar to embedding vector.

        No support for `filter` query (on metadata) along with vector search.

        Args:
            embedding (str): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score, id), the most similar to the query vector.
        """
        hits = self.table.search(
            embedding_vector=embedding,
            top_k=k,
            metric="cos",
            metric_threshold=None,
        )
        # We stick to 'cos' distance as it can be normalized on a 0-1 axis
        # (1=most relevant), as required by this class' contract.
        return [
            (
                Document(
                    page_content=hit["document"],
                    metadata=hit["metadata"],
                ),
                0.5 + 0.5 * hit["distance"],
                hit["document_id"],
            )
            for hit in hits
        ]

    def similarity_search_with_score_id(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, str]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_id_by_vector(
            embedding=embedding_vector,
            k=k,
        )

    # id-unaware search facilities
    def similarity_search_with_score_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector.

        No support for `filter` query (on metadata) along with vector search.

        Args:
            embedding (str): Embedding to look up documents similar to.
            k (int): Number of Documents to return. Defaults to 4.
        Returns:
            List of (Document, score), the most similar to the query vector.
        """
        return [
            (doc, score)
            for (doc, score, docId) in self.similarity_search_with_score_id_by_vector(
                embedding=embedding,
                k=k,
            )
        ]

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Document]:
        #
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_by_vector(
            embedding_vector,
            k,
            **kwargs,
        )

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        **kwargs: Any,
    ) -> List[Document]:
        return [
            doc
            for doc, _ in self.similarity_search_with_score_by_vector(
                embedding,
                k,
            )
        ]

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        embedding_vector = self.embedding.embed_query(query)
        return self.similarity_search_with_score_by_vector(
            embedding_vector,
            k,
        )

    # Even though this is a `_`-method,
    # it is apparently used by VectorSearch parent class
    # in an exposed method (`similarity_search_with_relevance_scores`).
    # So we implement it (hmm).
    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        return self.similarity_search_with_score(
            query,
            k,
            **kwargs,
        )

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.
        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        prefetchHits = self.table.search(
            embedding_vector=embedding,
            top_k=fetch_k,
            metric="cos",
            metric_threshold=None,
        )
        # let the mmr utility pick the *indices* in the above array
        mmrChosenIndices = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            [pfHit["embedding_vector"] for pfHit in prefetchHits],
            k=k,
            lambda_mult=lambda_mult,
        )
        mmrHits = [
            pfHit
            for pfIndex, pfHit in enumerate(prefetchHits)
            if pfIndex in mmrChosenIndices
        ]
        return [
            Document(
                page_content=hit["document"],
                metadata=hit["metadata"],
            )
            for hit in mmrHits
        ]

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.
        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.
        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Optional.
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        embedding_vector = self.embedding.embed_query(query)
        return self.max_marginal_relevance_search_by_vector(
            embedding_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
        )

    @classmethod
    def from_texts(
        cls: Type[CVST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> CVST:
        """Create a Cassandra vectorstore from raw texts.

        No support for specifying text IDs

        Returns:
            a Cassandra vectorstore.
        """
        session: Session = kwargs["session"]
        keyspace: str = kwargs["keyspace"]
        table_name: str = kwargs["table_name"]
        cassandraStore = cls(
            embedding=embedding,
            session=session,
            keyspace=keyspace,
            table_name=table_name,
        )
        cassandraStore.add_texts(texts=texts, metadatas=metadatas)
        return cassandraStore

    @classmethod
    def from_documents(
        cls: Type[CVST],
        documents: List[Document],
        embedding: Embeddings,
        **kwargs: Any,
    ) -> CVST:
        """Create a Cassandra vectorstore from a document list.

        No support for specifying text IDs

        Returns:
            a Cassandra vectorstore.
        """
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        session: Session = kwargs["session"]
        keyspace: str = kwargs["keyspace"]
        table_name: str = kwargs["table_name"]
        return cls.from_texts(
            texts=texts,
            metadatas=metadatas,
            embedding=embedding,
            session=session,
            keyspace=keyspace,
            table_name=table_name,
        )
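A note on the scores returned by `similarity_search_with_score_id_by_vector` above: the raw value of the "cos" metric is rescaled with `0.5 + 0.5 * distance`. A small sketch of that mapping, under the assumption that cassIO's "cos" search reports a cosine-similarity-like value in [-1, 1]:

# Assumed raw "cos" values in [-1, 1]; the rescaling maps them onto the 0-1 relevance axis.
for raw in (1.0, 0.0, -1.0):
    print(raw, "->", 0.5 + 0.5 * raw)  # 1.0 -> 1.0 (most relevant), -1.0 -> 0.0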
tests/integration_tests/vectorstores/fake_embeddings.py
@@ -1,4 +1,5 @@
 """Fake Embedding class for testing purposes."""
+import math
 from typing import List

 from langchain.embeddings.base import Embeddings
@@ -45,3 +46,29 @@ class ConsistentFakeEmbeddings(FakeEmbeddings):
         if text not in self.known_texts:
             return [float(1.0)] * 9 + [float(0.0)]
         return [float(1.0)] * 9 + [float(self.known_texts.index(text))]
+
+
+class AngularTwoDimensionalEmbeddings(Embeddings):
+    """
+    From angles (as strings in units of pi) to unit embedding vectors on a circle.
+    """
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Make a list of texts into a list of embedding vectors.
+        """
+        return [self.embed_query(text) for text in texts]
+
+    def embed_query(self, text: str) -> List[float]:
+        """
+        Convert input text to a 'vector' (list of floats).
+        If the text is a number, use it as the angle for the
+        unit vector in units of pi.
+        Any other input text becomes the singular result [0, 0] !
+        """
+        try:
+            angle = float(text)
+            return [math.cos(angle * math.pi), math.sin(angle * math.pi)]
+        except ValueError:
+            # Assume: just test string, no attention is paid to values.
+            return [0.0, 0.0]
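To make the geometry of the MMR test below easier to follow, here is a small sketch of the unit vectors that `AngularTwoDimensionalEmbeddings` (defined above) produces for the angle strings used in `test_cassandra_max_marginal_relevance_search`; the query "0.0" maps to (1, 0).

import math

# Same rule as AngularTwoDimensionalEmbeddings.embed_query:
# angle strings (in units of pi) become points on the unit circle.
for text in ["0.0", "-0.125", "+0.125", "+0.25", "+1.0"]:
    angle = float(text)
    print(f"{text:>7} -> ({math.cos(angle * math.pi):+.3f}, {math.sin(angle * math.pi):+.3f})")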
tests/integration_tests/vectorstores/test_cassandra.py (new file, 135 lines)
@@ -0,0 +1,135 @@
"""Test Cassandra functionality."""
from typing import List, Optional, Type

from cassandra.cluster import Cluster

from langchain.docstore.document import Document
from langchain.vectorstores import Cassandra
from tests.integration_tests.vectorstores.fake_embeddings import (
    AngularTwoDimensionalEmbeddings,
    ConsistentFakeEmbeddings,
    Embeddings,
)


def _vectorstore_from_texts(
    texts: List[str],
    metadatas: Optional[List[dict]] = None,
    embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
    drop: bool = True,
) -> Cassandra:
    keyspace = "vector_test_keyspace"
    table_name = "vector_test_table"
    # get db connection
    cluster = Cluster()
    session = cluster.connect()
    # ensure keyspace exists
    session.execute(
        (
            f"CREATE KEYSPACE IF NOT EXISTS {keyspace} "
            f"WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}"
        )
    )
    # drop table if required
    if drop:
        session.execute(f"DROP TABLE IF EXISTS {keyspace}.{table_name}")
    #
    return Cassandra.from_texts(
        texts,
        embedding_class(),
        metadatas=metadatas,
        session=session,
        keyspace=keyspace,
        table_name=table_name,
    )


def test_cassandra() -> None:
    """Test end to end construction and search."""
    texts = ["foo", "bar", "baz"]
    docsearch = _vectorstore_from_texts(texts)
    output = docsearch.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo")]


def test_cassandra_with_score() -> None:
    """Test end to end construction and search with scores and IDs."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
    output = docsearch.similarity_search_with_score("foo", k=3)
    docs = [o[0] for o in output]
    scores = [o[1] for o in output]
    assert docs == [
        Document(page_content="foo", metadata={"page": 0}),
        Document(page_content="bar", metadata={"page": 1}),
        Document(page_content="baz", metadata={"page": 2}),
    ]
    assert scores[0] > scores[1] > scores[2]


def test_cassandra_max_marginal_relevance_search() -> None:
    """
    Test end to end construction and MMR search.
    The embedding function used here ensures `texts` become
    the following vectors on a circle (numbered v0 through v3):

              ______ v2
             /      \
            /        \ v1
        v3 |    .     | query
            \        / v0
             \______/          (N.B. very crude drawing)

    With fetch_k==3 and k==2, when query is at (1, 0),
    one expects that v2 and v0 are returned (in some order).
    """
    texts = ["-0.125", "+0.125", "+0.25", "+1.0"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(
        texts, metadatas=metadatas, embedding_class=AngularTwoDimensionalEmbeddings
    )
    output = docsearch.max_marginal_relevance_search("0.0", k=2, fetch_k=3)
    output_set = {
        (mmr_doc.page_content, mmr_doc.metadata["page"]) for mmr_doc in output
    }
    assert output_set == {
        ("+0.25", 2),
        ("-0.125", 0),
    }


def test_cassandra_add_extra() -> None:
    """Test end to end construction with further insertions."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)

    docsearch.add_texts(texts, metadatas)
    texts2 = ["foo2", "bar2", "baz2"]
    docsearch.add_texts(texts2, metadatas)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6


def test_cassandra_no_drop() -> None:
    """Test end to end construction and re-opening the same index."""
    texts = ["foo", "bar", "baz"]
    metadatas = [{"page": i} for i in range(len(texts))]
    docsearch = _vectorstore_from_texts(texts, metadatas=metadatas)
    del docsearch

    texts2 = ["foo2", "bar2", "baz2"]
    docsearch = _vectorstore_from_texts(texts2, metadatas=metadatas, drop=False)

    output = docsearch.similarity_search("foo", k=10)
    assert len(output) == 6


# if __name__ == "__main__":
# test_cassandra()
# test_cassandra_with_score()
# test_cassandra_max_marginal_relevance_search()
# test_cassandra_add_extra()
# test_cassandra_no_drop()