{ "cells": [ { "cell_type": "markdown", "id": "683953b3", "metadata": {}, "source": [ "# Cassandra\n", "\n", ">[Apache Cassandra®](https://cassandra.apache.org) is a NoSQL, row-oriented, highly scalable and highly available database.\n", "\n", "The newest Cassandra releases natively [support](https://cwiki.apache.org/confluence/display/CASSANDRA/CEP-30%3A+Approximate+Nearest+Neighbor(ANN)+Vector+Search+via+Storage-Attached+Indexes) Vector Similarity Search.\n", "\n", "To run this notebook you need either a running Cassandra cluster equipped with Vector Search capabilities (in pre-release at the time of writing) or a DataStax Astra DB instance running in the cloud (you can get one for free at [datastax.com](https://astra.datastax.com)). Check [cassio.org](https://cassio.org/start_here/) for more information." ] }, { "cell_type": "code", "execution_count": null, "id": "b4c41cad-08ef-4f72-a545-2151e4598efe", "metadata": { "tags": [] }, "outputs": [], "source": [ "!pip install \"cassio>=0.0.7\"" ] }, { "cell_type": "markdown", "id": "b7e46bb0", "metadata": {}, "source": [ "### Please provide database connection parameters and secrets:" ] }, { "cell_type": "code", "execution_count": null, "id": "36128a32", "metadata": {}, "outputs": [], "source": [ "import os\n", "import getpass\n", "\n", "database_mode = (input(\"\\n(C)assandra or (A)stra DB? \")).strip().upper()\n", "\n", "keyspace_name = input(\"\\nKeyspace name? \")\n", "\n", "if database_mode == \"A\":\n", "    ASTRA_DB_APPLICATION_TOKEN = getpass.getpass('\\nAstra DB Token (\"AstraCS:...\") ')\n", "    ASTRA_DB_SECURE_BUNDLE_PATH = input(\"Full path to your Secure Connect Bundle? \")\n", "elif database_mode == \"C\":\n", "    CASSANDRA_CONTACT_POINTS = input(\n", "        \"Contact points? 
(comma-separated, empty for localhost) \"\n", "    ).strip()" ] }, { "cell_type": "markdown", "id": "4f22aac2", "metadata": {}, "source": [ "#### Depending on whether you use a local Cassandra cluster or cloud-based Astra DB, create the corresponding database connection \"Session\" object" ] }, { "cell_type": "code", "execution_count": null, "id": "677f8576", "metadata": {}, "outputs": [], "source": [ "from cassandra.cluster import Cluster\n", "from cassandra.auth import PlainTextAuthProvider\n", "\n", "if database_mode == \"C\":\n", "    if CASSANDRA_CONTACT_POINTS:\n", "        cluster = Cluster(\n", "            [cp.strip() for cp in CASSANDRA_CONTACT_POINTS.split(\",\") if cp.strip()]\n", "        )\n", "    else:\n", "        cluster = Cluster()\n", "    session = cluster.connect()\n", "elif database_mode == \"A\":\n", "    ASTRA_DB_CLIENT_ID = \"token\"\n", "    cluster = Cluster(\n", "        cloud={\n", "            \"secure_connect_bundle\": ASTRA_DB_SECURE_BUNDLE_PATH,\n", "        },\n", "        auth_provider=PlainTextAuthProvider(\n", "            ASTRA_DB_CLIENT_ID,\n", "            ASTRA_DB_APPLICATION_TOKEN,\n", "        ),\n", "    )\n", "    session = cluster.connect()\n", "else:\n", "    raise NotImplementedError" ] }, { "cell_type": "markdown", "id": "320af802-9271-46ee-948f-d2453933d44b", "metadata": {}, "source": [ "### Please provide an OpenAI API key\n", "\n", "This notebook uses `OpenAIEmbeddings`, so an OpenAI API key is required." 
] }, { "cell_type": "code", "execution_count": null, "id": "ffea66e4-bc23-46a9-9580-b348dfe7b7a7", "metadata": {}, "outputs": [], "source": [ "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API Key:\")" ] }, { "cell_type": "markdown", "id": "e98a139b", "metadata": {}, "source": [ "### Creation and usage of the Vector Store" ] }, { "cell_type": "code", "execution_count": null, "id": "aac9563e", "metadata": { "tags": [] }, "outputs": [], "source": [ "from langchain.embeddings.openai import OpenAIEmbeddings\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.vectorstores import Cassandra\n", "from langchain.document_loaders import TextLoader" ] }, { "cell_type": "code", "execution_count": null, "id": "a3c3999a", "metadata": {}, "outputs": [], "source": [ "loader = TextLoader(\"../../../state_of_the_union.txt\")\n", "documents = loader.load()\n", "text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n", "docs = text_splitter.split_documents(documents)\n", "\n", "embedding_function = OpenAIEmbeddings()" ] }, { "cell_type": "code", "execution_count": null, "id": "6e104aee", "metadata": {}, "outputs": [], "source": [ "table_name = \"my_vector_db_table\"\n", "\n", "docsearch = Cassandra.from_documents(\n", "    documents=docs,\n", "    embedding=embedding_function,\n", "    session=session,\n", "    keyspace=keyspace_name,\n", "    table_name=table_name,\n", ")\n", "\n", "query = \"What did the president say about Ketanji Brown Jackson\"\n", "docs = docsearch.similarity_search(query)" ] }, { "cell_type": "code", "execution_count": null, "id": "f509ee02", "metadata": {}, "outputs": [], "source": [ "# If you already have an index, you can load it and use it like this:\n", "\n", "# docsearch_preexisting = Cassandra(\n", "#     embedding=embedding_function,\n", "#     session=session,\n", "#     keyspace=keyspace_name,\n", "#     table_name=table_name,\n", "# )\n", "\n", "# 
docsearch_preexisting.similarity_search(query, k=2)" ] }, { "cell_type": "code", "execution_count": null, "id": "9c608226", "metadata": {}, "outputs": [], "source": [ "print(docs[0].page_content)" ] }, { "cell_type": "markdown", "id": "d46d1452", "metadata": {}, "source": [ "### Maximal Marginal Relevance Searches\n", "\n", "In addition to similarity search, the retriever object can also use Maximal Marginal Relevance (`mmr`) as its search type." ] }, { "cell_type": "code", "execution_count": null, "id": "a359ed74", "metadata": {}, "outputs": [], "source": [ "retriever = docsearch.as_retriever(search_type=\"mmr\")\n", "matched_docs = retriever.get_relevant_documents(query)\n", "for i, d in enumerate(matched_docs):\n", "    print(f\"\\n## Document {i}\\n\")\n", "    print(d.page_content)" ] }, { "cell_type": "markdown", "id": "7c477287", "metadata": {}, "source": [ "Or use `max_marginal_relevance_search` directly:" ] }, { "cell_type": "code", "execution_count": null, "id": "9ca82740", "metadata": {}, "outputs": [], "source": [ "found_docs = docsearch.max_marginal_relevance_search(query, k=2, fetch_k=10)\n", "for i, doc in enumerate(found_docs):\n", "    print(f\"{i + 1}.\", doc.page_content, \"\\n\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }