mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
415d38ae62
This PR addresses a few minor issues with the Cassandra vector store implementation and extends the store to support Metadata search.

Thanks to the latest cassIO library (>=0.1.0), metadata filtering is available in the store. Further,
- the "relevance" score is prevented from being flipped in the [0,1] interval, thus ensuring that 1 corresponds to the closest vector (this is related to how the underlying cassIO class returns the cosine difference);
- bumped the cassIO package version both in the notebooks and the pyproject.toml;
- adjusted the textfile location for the vector-store example after the reshuffling of the Langchain repo dir structure;
- added demonstration of metadata filtering in the Cassandra vector store notebook;
- better docstring for the Cassandra vector store class;
- fixed test flakiness and removed offending out-of-place escape chars from a test module docstring;

To my knowledge all relevant tests pass and mypy+black+ruff don't complain. (mypy gives unrelated errors in other modules, which clearly don't depend on the content of this PR).

Thank you!
Stefano

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
327 lines
9.0 KiB
Plaintext
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "683953b3",
   "metadata": {},
   "source": [
    "# Cassandra\n",
    "\n",
    ">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n",
    "\n",
    "The newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n",
    "\n",
    "To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c41cad-08ef-4f72-a545-2151e4598efe",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip install \"cassio>=0.1.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e46bb0",
   "metadata": {},
   "source": [
    "### Please provide database connection parameters and secrets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36128a32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import getpass\n",
    "\n",
    "database_mode = (input(\"\\n(C)assandra or (A)stra DB? \")).upper()\n",
    "\n",
    "keyspace_name = input(\"\\nKeyspace name? \")\n",
    "\n",
    "if database_mode == \"A\":\n",
    "    ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n",
    "    #\n",
    "    ASTRA_DB_SECURE_BUNDLE_PATH = input(\"Full path to your Secure Connect Bundle? \")\n",
    "elif database_mode == \"C\":\n",
    "    CASSANDRA_CONTACT_POINTS = input(\n",
    "        \"Contact points? (comma-separated, empty for localhost) \"\n",
    "    ).strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f22aac2",
   "metadata": {},
   "source": [
    "#### Depending on whether you use a local Cassandra cluster or cloud-based Astra DB, create the corresponding database connection \"Session\" object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677f8576",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cassandra.cluster import Cluster\n",
    "from cassandra.auth import PlainTextAuthProvider\n",
    "\n",
    "if database_mode == \"C\":\n",
    "    if CASSANDRA_CONTACT_POINTS:\n",
    "        cluster = Cluster(\n",
    "            [cp.strip() for cp in CASSANDRA_CONTACT_POINTS.split(\",\") if cp.strip()]\n",
    "        )\n",
    "    else:\n",
    "        cluster = Cluster()\n",
    "    session = cluster.connect()\n",
    "elif database_mode == \"A\":\n",
    "    ASTRA_DB_CLIENT_ID = \"token\"\n",
    "    cluster = Cluster(\n",
    "        cloud={\n",
    "            \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n",
    "        },\n",
    "        auth_provider=PlainTextAuthProvider(\n",
    "            ASTRA_DB_CLIENT_ID,\n",
    "            ASTRA_DB_APPLICATION_TOKEN,\n",
    "        ),\n",
    "    )\n",
    "    session = cluster.connect()\n",
    "else:\n",
    "    raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "320af802-9271-46ee-948f-d2453933d44b",
   "metadata": {},
   "source": [
    "### Please provide OpenAI access key\n",
    "\n",
    "We want to use `OpenAIEmbeddings`, so we need an OpenAI API key."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98a139b",
   "metadata": {},
   "source": [
    "### Creation and usage of the Vector Store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aac9563e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
    "from langchain.text_splitter import CharacterTextSplitter\n",
    "from langchain.vectorstores import Cassandra\n",
    "from langchain.document_loaders import TextLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3c3999a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import TextLoader\n",
    "\n",
    "SOURCE_FILE_NAME = \"../../modules/state_of_the_union.txt\"\n",
    "\n",
    "loader = TextLoader(SOURCE_FILE_NAME)\n",
    "documents = loader.load()\n",
    "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "docs = text_splitter.split_documents(documents)\n",
    "\n",
    "embedding_function = OpenAIEmbeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e104aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_name = \"my_vector_db_table\"\n",
    "\n",
    "docsearch = Cassandra.from_documents(\n",
    "    documents=docs,\n",
    "    embedding=embedding_function,\n",
    "    session=session,\n",
    "    keyspace=keyspace_name,\n",
    "    table_name=table_name,\n",
    ")\n",
    "\n",
    "query = \"What did the president say about Ketanji Brown Jackson\"\n",
    "docs = docsearch.similarity_search(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f509ee02",
   "metadata": {},
   "outputs": [],
   "source": [
    "## if you already have an index, you can load it and use it like this:\n",
    "\n",
    "# docsearch_preexisting = Cassandra(\n",
    "#     embedding=embedding_function,\n",
    "#     session=session,\n",
    "#     keyspace=keyspace_name,\n",
    "#     table_name=table_name,\n",
    "# )\n",
    "\n",
    "# docs = docsearch_preexisting.similarity_search(query, k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c608226",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(docs[0].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d46d1452",
   "metadata": {},
   "source": [
    "### Maximal Marginal Relevance Searches\n",
    "\n",
    "In addition to using similarity search in the retriever object, you can also use `mmr` as the retriever's search type.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a359ed74",
   "metadata": {},
   "outputs": [],
   "source": [
    "retriever = docsearch.as_retriever(search_type=\"mmr\")\n",
    "matched_docs = retriever.get_relevant_documents(query)\n",
    "for i, d in enumerate(matched_docs):\n",
    "    print(f\"\\n## Document {i}\\n\")\n",
    "    print(d.page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c477287",
   "metadata": {},
   "source": [
    "Or use `max_marginal_relevance_search` directly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ca82740",
   "metadata": {},
   "outputs": [],
   "source": [
    "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n",
    "for i, doc in enumerate(found_docs):\n",
    "    print(f\"{i + 1}.\", doc.page_content, \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da791c5f",
   "metadata": {},
   "source": [
    "### Metadata filtering\n",
    "\n",
    "You can specify filtering on metadata when running searches in the vector store. By default, when inserting documents, the only metadata is the `\"source\"` (but you can customize the metadata at insertion time).\n",
    "\n",
    "Since only one file was inserted, this is just a demonstration of how filters are passed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f132fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_filter = {\"source\": SOURCE_FILE_NAME}\n",
    "filtered_docs = docsearch.similarity_search(query, filter=search_filter, k=5)\n",
    "print(f\"{len(filtered_docs)} documents retrieved.\")\n",
    "print(f\"{filtered_docs[0].page_content[:64]} ...\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b413ec4",
   "metadata": {},
   "outputs": [],
   "source": [
    "search_filter = {\"source\": \"nonexisting_file.txt\"}\n",
    "filtered_docs2 = docsearch.similarity_search(query, filter=search_filter)\n",
    "print(f\"{len(filtered_docs2)} documents retrieved.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0fea764",
   "metadata": {},
   "source": [
    "Please visit the [cassIO documentation](https://cassio.org/frameworks/langchain/about/) for more on using vector stores with LangChain."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}