From 76102971c056bb277bf394068c98fb05ee2fb07d Mon Sep 17 00:00:00 2001 From: Anthony Mahanna <43019056+aMahanna@users.noreply.github.com> Date: Mon, 24 Jul 2023 18:16:52 -0400 Subject: [PATCH] ArangoDB/AQL support for Graph QA Chain (#7880) **Description**: Serves as an introduction to LangChain's support for [ArangoDB](https://github.com/arangodb/arangodb), similar to https://github.com/hwchase17/langchain/pull/7165 and https://github.com/hwchase17/langchain/pull/4881 **Issue**: No issue has been created for this feature **Dependencies**: `python-arango` has been added as an optional dependency via the `CONTRIBUTING.md` guidelines **Twitter handle**: [at]arangodb - Integration test has been added - Notebook has been added: [graph_arangodb_qa.ipynb](https://github.com/amahanna/langchain/blob/master/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb) [![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amahanna/langchain/blob/master/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb) ``` docker run -p 8529:8529 -e ARANGO_ROOT_PASSWORD= arangodb/arangodb ``` ``` pip install git+https://github.com/amahanna/langchain.git ``` ```python from arango import ArangoClient from langchain.chat_models import ChatOpenAI from langchain.graphs import ArangoGraph from langchain.chains import ArangoGraphQAChain db = ArangoClient(hosts="localhost:8529").db(name="_system", username="root", password="", verify=True) graph = ArangoGraph(db) chain = ArangoGraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph) chain.run("Is Ned Stark alive?") ``` --------- Co-authored-by: Bagatur --- .../integrations/providers/arangodb.mdx | 23 + .../chains/additional/graph_arangodb_qa.ipynb | 819 ++++++++++++++++++ libs/langchain/langchain/chains/__init__.py | 2 + .../langchain/chains/graph_qa/arangodb.py | 229 +++++ .../langchain/chains/graph_qa/prompts.py | 117 +++ libs/langchain/langchain/graphs/__init__.py 
| 2 + .../langchain/graphs/arangodb_graph.py | 166 ++++ libs/langchain/poetry.lock | 29 +- libs/langchain/pyproject.toml | 2 + .../chains/test_graph_database_arangodb.py | 78 ++ 10 files changed, 1464 insertions(+), 3 deletions(-) create mode 100644 docs/extras/integrations/providers/arangodb.mdx create mode 100644 docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb create mode 100644 libs/langchain/langchain/chains/graph_qa/arangodb.py create mode 100644 libs/langchain/langchain/graphs/arangodb_graph.py create mode 100644 libs/langchain/tests/integration_tests/chains/test_graph_database_arangodb.py diff --git a/docs/extras/integrations/providers/arangodb.mdx b/docs/extras/integrations/providers/arangodb.mdx new file mode 100644 index 0000000000..e2650f374f --- /dev/null +++ b/docs/extras/integrations/providers/arangodb.mdx @@ -0,0 +1,23 @@ +# ArangoDB + +>[ArangoDB](https://github.com/arangodb/arangodb) is a scalable graph database system to drive value from connected data, faster. Native graphs, an integrated search engine, and JSON support, via a single query language. ArangoDB runs on-prem, in the cloud – anywhere. + +## Dependencies + +Install the [ArangoDB Python Driver](https://github.com/ArangoDB-Community/python-arango) package with +```bash +pip install python-arango +``` + +## Graph QA Chain + +Connect your ArangoDB Database with a Chat Model to get insights on your data. + +See the notebook example [here](/docs/modules/chains/additional/graph_arangodb_qa.html). 
+ +```python +from arango import ArangoClient + +from langchain.graphs import ArangoGraph +from langchain.chains import ArangoGraphQAChain +``` \ No newline at end of file diff --git a/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb b/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb new file mode 100644 index 0000000000..f7ab6c46e7 --- /dev/null +++ b/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb @@ -0,0 +1,819 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c94240f5", + "metadata": { + "id": "c94240f5" + }, + "source": [ + "# ArangoDB QA chain\n", + "\n", + "[![Open In Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hwchase17/langchain/blob/master/docs/extras/modules/chains/additional/graph_arangodb_qa.ipynb)\n", + "\n", + "This notebook shows how to use LLMs to provide a natural language interface to an [ArangoDB](https://github.com/arangodb/arangodb#readme) database." + ] + }, + { + "cell_type": "markdown", + "id": "dbc0ee68", + "metadata": { + "id": "dbc0ee68" + }, + "source": [ + "You can get a local ArangoDB instance running via the [ArangoDB Docker image](https://hub.docker.com/_/arangodb): \n", + "\n", + "```\n", + "docker run -p 8529:8529 -e ARANGO_ROOT_PASSWORD= arangodb/arangodb\n", + "```\n", + "\n", + "An alternative is to use the [ArangoDB Cloud Connector package](https://github.com/arangodb/adb-cloud-connector#readme) to get a temporary cloud instance running:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "izi6YoFC8KRH", + "metadata": { + "id": "izi6YoFC8KRH" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install python-arango # The ArangoDB Python Driver\n", + "!pip install adb-cloud-connector # The ArangoDB Cloud Instance provisioner\n", + "!pip install openai\n", + "!pip install langchain" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "62812aad", + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/" + }, + "id": "62812aad", + "outputId": "f7ed8346-d88b-40d1-eaff-68e97e0e157e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Log: requesting new credentials...\n", + "Succcess: new credentials acquired\n", + "{\n", + " \"dbName\": \"TUT3sp29s3pjf1io0h4cfdsq\",\n", + " \"username\": \"TUTo6nkwgzkizej3kysgdyeo8\",\n", + " \"password\": \"TUT9vx0qjqt42i9bq8uik4v9\",\n", + " \"hostname\": \"tutorials.arangodb.cloud\",\n", + " \"port\": 8529,\n", + " \"url\": \"https://tutorials.arangodb.cloud:8529\"\n", + "}\n" + ] + } + ], + "source": [ + "# Instantiate ArangoDB Database\n", + "import json\n", + "from arango import ArangoClient\n", + "from adb_cloud_connector import get_temp_credentials\n", + "\n", + "con = get_temp_credentials()\n", + "\n", + "db = ArangoClient(hosts=con[\"url\"]).db(\n", + " con[\"dbName\"], con[\"username\"], con[\"password\"], verify=True\n", + ")\n", + "\n", + "print(json.dumps(con, indent=2))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0928915d", + "metadata": { + "id": "0928915d" + }, + "outputs": [], + "source": [ + "# Instantiate the ArangoDB-LangChain Graph\n", + "from langchain.graphs import ArangoGraph\n", + "\n", + "graph = ArangoGraph(db)" + ] + }, + { + "cell_type": "markdown", + "id": "995ea9b9", + "metadata": { + "id": "995ea9b9" + }, + "source": [ + "## Populating the Database\n", + "\n", + "We will rely on the Python Driver to import our [GameOfThrones](https://github.com/arangodb/example-datasets/tree/master/GameOfThrones) data into our database." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "fedd26b9", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fedd26b9", + "outputId": "fc7d9067-e4f5-495e-cd0c-79135bc16fc2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'error': False,\n", + " 'created': 4,\n", + " 'errors': 0,\n", + " 'empty': 0,\n", + " 'updated': 0,\n", + " 'ignored': 0,\n", + " 'details': []}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "if db.has_graph(\"GameOfThrones\"):\n", + " db.delete_graph(\"GameOfThrones\", drop_collections=True)\n", + "\n", + "db.create_graph(\n", + " \"GameOfThrones\",\n", + " edge_definitions=[\n", + " {\n", + " \"edge_collection\": \"ChildOf\",\n", + " \"from_vertex_collections\": [\"Characters\"],\n", + " \"to_vertex_collections\": [\"Characters\"],\n", + " },\n", + " ],\n", + ")\n", + "\n", + "documents = [\n", + " {\n", + " \"_key\": \"NedStark\",\n", + " \"name\": \"Ned\",\n", + " \"surname\": \"Stark\",\n", + " \"alive\": True,\n", + " \"age\": 41,\n", + " \"gender\": \"male\",\n", + " },\n", + " {\n", + " \"_key\": \"CatelynStark\",\n", + " \"name\": \"Catelyn\",\n", + " \"surname\": \"Stark\",\n", + " \"alive\": False,\n", + " \"age\": 40,\n", + " \"gender\": \"female\",\n", + " },\n", + " {\n", + " \"_key\": \"AryaStark\",\n", + " \"name\": \"Arya\",\n", + " \"surname\": \"Stark\",\n", + " \"alive\": True,\n", + " \"age\": 11,\n", + " \"gender\": \"female\",\n", + " },\n", + " {\n", + " \"_key\": \"BranStark\",\n", + " \"name\": \"Bran\",\n", + " \"surname\": \"Stark\",\n", + " \"alive\": True,\n", + " \"age\": 10,\n", + " \"gender\": \"male\",\n", + " },\n", + "]\n", + "\n", + "edges = [\n", + " {\"_to\": \"Characters/NedStark\", \"_from\": \"Characters/AryaStark\"},\n", + " {\"_to\": \"Characters/NedStark\", \"_from\": \"Characters/BranStark\"},\n", + " {\"_to\": \"Characters/CatelynStark\", \"_from\": 
\"Characters/AryaStark\"},\n", + " {\"_to\": \"Characters/CatelynStark\", \"_from\": \"Characters/BranStark\"},\n", + "]\n", + "\n", + "db.collection(\"Characters\").import_bulk(documents)\n", + "db.collection(\"ChildOf\").import_bulk(edges)" + ] + }, + { + "cell_type": "markdown", + "id": "58c1a8ea", + "metadata": { + "id": "58c1a8ea" + }, + "source": [ + "## Getting & Setting the ArangoDB Schema\n", + "\n", + "An initial ArangoDB Schema is generated upon instantiating the `ArangoGraph` object. Below are the schema's getter & setter methods should you be interested in viewing or modifying the schema:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4e3de44f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4e3de44f", + "outputId": "6102f0c6-4a94-4e00-b93e-eb8de2f7d9d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"Graph Schema\": [],\n", + " \"Collection Schema\": []\n", + "}\n" + ] + } + ], + "source": [ + "# The schema should be empty here,\n", + "# since `graph` was initialized prior to ArangoDB Data ingestion (see above).\n", + "\n", + "import json\n", + "\n", + "print(json.dumps(graph.schema, indent=4))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1fe76ccd", + "metadata": { + "id": "1fe76ccd" + }, + "outputs": [], + "source": [ + "graph.set_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "mZ679anj_-Er", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mZ679anj_-Er", + "outputId": "e05229c7-bc61-4803-d720-47e3e9b2b350" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"Graph Schema\": [\n", + " {\n", + " \"graph_name\": \"GameOfThrones\",\n", + " \"edge_definitions\": [\n", + " {\n", + " \"edge_collection\": \"ChildOf\",\n", + " \"from_vertex_collections\": [\n", + " \"Characters\"\n", + " ],\n", + " 
\"to_vertex_collections\": [\n", + " \"Characters\"\n", + " ]\n", + " }\n", + " ]\n", + " }\n", + " ],\n", + " \"Collection Schema\": [\n", + " {\n", + " \"collection_name\": \"ChildOf\",\n", + " \"collection_type\": \"edge\",\n", + " \"edge_properties\": [\n", + " {\n", + " \"name\": \"_key\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_id\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_from\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_to\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_rev\",\n", + " \"type\": \"str\"\n", + " }\n", + " ],\n", + " \"example_edge\": {\n", + " \"_key\": \"266218884025\",\n", + " \"_id\": \"ChildOf/266218884025\",\n", + " \"_from\": \"Characters/AryaStark\",\n", + " \"_to\": \"Characters/NedStark\",\n", + " \"_rev\": \"_gVPKGSq---\"\n", + " }\n", + " },\n", + " {\n", + " \"collection_name\": \"Characters\",\n", + " \"collection_type\": \"document\",\n", + " \"document_properties\": [\n", + " {\n", + " \"name\": \"_key\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_id\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"_rev\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"name\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"surname\",\n", + " \"type\": \"str\"\n", + " },\n", + " {\n", + " \"name\": \"alive\",\n", + " \"type\": \"bool\"\n", + " },\n", + " {\n", + " \"name\": \"age\",\n", + " \"type\": \"int\"\n", + " },\n", + " {\n", + " \"name\": \"gender\",\n", + " \"type\": \"str\"\n", + " }\n", + " ],\n", + " \"example_document\": {\n", + " \"_key\": \"NedStark\",\n", + " \"_id\": \"Characters/NedStark\",\n", + " \"_rev\": \"_gVPKGPi---\",\n", + " \"name\": \"Ned\",\n", + " \"surname\": \"Stark\",\n", + " \"alive\": true,\n", + " \"age\": 41,\n", + " \"gender\": \"male\"\n", + " }\n", + " }\n", + " ]\n", + "}\n" + ] + } + ], + "source": [ + "# We can now 
view the generated schema\n", + "\n", + "import json\n", + "\n", + "print(json.dumps(graph.schema, indent=4))" + ] + }, + { + "cell_type": "markdown", + "id": "68a3c677", + "metadata": { + "id": "68a3c677" + }, + "source": [ + "## Querying the ArangoDB Database\n", + "\n", + "We can now use the ArangoDB Graph QA Chain to inquire about our data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "635c4018", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"your-key-here\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "7476ce98", + "metadata": { + "id": "7476ce98" + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.chains import ArangoGraphQAChain\n", + "\n", + "chain = ArangoGraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ef8ee27b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 261 + }, + "id": "ef8ee27b", + "outputId": "6008ee92-a3ec-4968-d48d-6a5b66403959" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "WITH Characters\n", + "FOR character IN Characters\n", + "FILTER character.name == \"Ned\" AND character.surname == \"Stark\"\n", + "RETURN character.alive\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[True]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Yes, Ned Stark is alive.'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Is Ned Stark alive?\")" + ] + 
}, + { + "cell_type": "code", + "execution_count": 11, + "id": "9CSig1BgA76q", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 261 + }, + "id": "9CSig1BgA76q", + "outputId": "3060cf15-68e0-4f8a-cdfd-68af3f0e5fbf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "WITH Characters\n", + "FOR character IN Characters\n", + "FILTER character.name == \"Arya\" && character.surname == \"Stark\"\n", + "RETURN character.age\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[11]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Arya Stark is 11 years old.'" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"How old is Arya Stark?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9Fzdic_pA_4y", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 298 + }, + "id": "9Fzdic_pA_4y", + "outputId": "9bd93580-964e-4c53-e273-6723dad5f375" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "WITH Characters, ChildOf\n", + "FOR v, e, p IN 1..1 OUTBOUND 'Characters/AryaStark' ChildOf\n", + " FILTER p.vertices[-1]._key == 'NedStark'\n", + " RETURN p\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[{'vertices': [{'_key': 'AryaStark', '_id': 'Characters/AryaStark', '_rev': '_gVPKGPi--B', 'name': 'Arya', 'surname': 'Stark', 'alive': True, 'age': 11, 'gender': 'female'}, {'_key': 'NedStark', '_id': 'Characters/NedStark', '_rev': '_gVPKGPi---', 
'name': 'Ned', 'surname': 'Stark', 'alive': True, 'age': 41, 'gender': 'male'}], 'edges': [{'_key': '266218884025', '_id': 'ChildOf/266218884025', '_from': 'Characters/AryaStark', '_to': 'Characters/NedStark', '_rev': '_gVPKGSq---'}], 'weights': [0, 1]}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Yes, Arya Stark and Ned Stark are related. According to the information retrieved from the database, there is a relationship between them. Arya Stark is the child of Ned Stark.'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Are Arya Stark and Ned Stark related?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "zq_oeDpAOXpF", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 261 + }, + "id": "zq_oeDpAOXpF", + "outputId": "a47f37b5-4d7b-41c6-fbc7-7b1abc25fa20" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "WITH Characters, ChildOf\n", + "FOR v, e IN 1..1 OUTBOUND 'Characters/AryaStark' ChildOf\n", + "FILTER v.alive == false\n", + "RETURN e\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[{'_key': '266218884027', '_id': 'ChildOf/266218884027', '_from': 'Characters/AryaStark', '_to': 'Characters/CatelynStark', '_rev': '_gVPKGSu---'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Yes, Arya Stark has a dead parent. 
The parent is Catelyn Stark.'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Does Arya Stark have a dead parent?\")" + ] + }, + { + "cell_type": "markdown", + "id": "Ob_3aGauGd7d", + "metadata": { + "id": "Ob_3aGauGd7d" + }, + "source": [ + "## Chain Modifiers" + ] + }, + { + "cell_type": "markdown", + "id": "3P490E2dGiBp", + "metadata": { + "id": "3P490E2dGiBp" + }, + "source": [ + "You can alter the values of the following `ArangoGraphQAChain` class variables to modify the behaviour of your chain results:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "1B9h3PvzJ41T", + "metadata": { + "id": "1B9h3PvzJ41T" + }, + "outputs": [], + "source": [ + "# Specify the maximum number of AQL Query Results to return\n", + "chain.top_k = 10\n", + "\n", + "# Specify whether or not to return the AQL Query in the output dictionary\n", + "chain.return_aql_query = True\n", + "\n", + "# Specify whether or not to return the AQL JSON Result in the output dictionary\n", + "chain.return_aql_result = True\n", + "\n", + "# Specify the maximum amount of AQL Generation attempts that should be made\n", + "chain.max_aql_generation_attempts = 5\n", + "\n", + "# Specify a set of AQL Query Examples, which are passed to\n", + "# the AQL Generation Prompt Template to promote few-shot-learning.\n", + "# Defaults to an empty string.\n", + "chain.aql_examples = \"\"\"\n", + "# Is Ned Stark alive?\n", + "RETURN DOCUMENT('Characters/NedStark').alive\n", + "\n", + "# Is Arya Stark the child of Ned Stark?\n", + "FOR e IN ChildOf\n", + " FILTER e._from == \"Characters/AryaStark\" AND e._to == \"Characters/NedStark\"\n", + " RETURN e\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "49cnjYV-PUv3", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 209 + }, + "id": "49cnjYV-PUv3", + "outputId": "f05f0a86-f922-47d1-91b9-5380ef1f996d" + }, + 
"outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "RETURN DOCUMENT('Characters/NedStark').alive\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[True]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Yes, according to the information in the database, Ned Stark is alive.'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Is Ned Stark alive?\")\n", + "\n", + "# chain(\"Is Ned Stark alive?\") # Returns a dictionary with the AQL Query & AQL Result" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "nWfALJ8dPczE", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 244 + }, + "id": "nWfALJ8dPczE", + "outputId": "1235baae-f3f7-438e-ef24-658cd17f727d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new ArangoGraphQAChain chain...\u001b[0m\n", + "AQL Query (1):\u001b[32;1m\u001b[1;3m\n", + "FOR e IN ChildOf\n", + " FILTER e._from == \"Characters/BranStark\" AND e._to == \"Characters/NedStark\"\n", + " RETURN e\n", + "\u001b[0m\n", + "AQL Result:\n", + "\u001b[32;1m\u001b[1;3m[{'_key': '266218884026', '_id': 'ChildOf/266218884026', '_from': 'Characters/BranStark', '_to': 'Characters/NedStark', '_rev': '_gVPKGSq--_'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'Yes, according to the information in the ArangoDB database, Bran Stark is indeed the child of Ned Stark.'" + ] + }, + "execution_count": 16, + "metadata": {}, + 
"output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Is Bran Stark the child of Ned Stark?\")" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "995ea9b9", + "58c1a8ea", + "68a3c677", + "Ob_3aGauGd7d" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index fdde1cce6d..c806db1d04 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -28,6 +28,7 @@ from langchain.chains.conversational_retrieval.base import ( ) from langchain.chains.example_generator import generate_example from langchain.chains.flare.base import FlareChain +from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain from langchain.chains.graph_qa.base import GraphQAChain from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain @@ -77,6 +78,7 @@ from langchain.chains.transform import TransformChain __all__ = [ "APIChain", "AnalyzeDocumentChain", + "ArangoGraphQAChain", "ChatVectorDBChain", "ConstitutionalChain", "ConversationChain", diff --git a/libs/langchain/langchain/chains/graph_qa/arangodb.py b/libs/langchain/langchain/chains/graph_qa/arangodb.py new file mode 100644 index 0000000000..42eba0b34f --- /dev/null +++ b/libs/langchain/langchain/chains/graph_qa/arangodb.py @@ -0,0 +1,229 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional + +from pydantic import 
Field + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import ( + AQL_FIX_PROMPT, + AQL_GENERATION_PROMPT, + AQL_QA_PROMPT, +) +from langchain.chains.llm import LLMChain +from langchain.graphs.arangodb_graph import ArangoGraph +from langchain.schema import BasePromptTemplate + + +class ArangoGraphQAChain(Chain): + """Chain for question-answering against a graph by generating AQL statements.""" + + graph: ArangoGraph = Field(exclude=True) + aql_generation_chain: LLMChain + aql_fix_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + # Specifies the maximum number of AQL Query Results to return + top_k = 10 + + # Specifies the set of AQL Query Examples that promote few-shot-learning + aql_examples = "" + + # Specify whether to return the AQL Query in the output dictionary + return_aql_query: bool = False + + # Specify whether to return the AQL JSON Result in the output dictionary + return_aql_result: bool = False + + # Specify the maximum amount of AQL Generation attempts that should be made + max_aql_generation_attempts = 3 + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + return [self.output_key] + + @property + def _chain_type(self) -> str: + return "graph_aql_chain" + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = AQL_QA_PROMPT, + aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT, + aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT, + **kwargs: Any, + ) -> ArangoGraphQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt) + aql_fix_chain = LLMChain(llm=llm, 
prompt=aql_fix_prompt) + + return cls( + qa_chain=qa_chain, + aql_generation_chain=aql_generation_chain, + aql_fix_chain=aql_fix_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """ + Generate an AQL statement from user input, use it to retrieve a response + from an ArangoDB Database instance, and respond to the user input + in natural language. + + Users can modify the following ArangoGraphQAChain Class Variables: + + :var top_k: The maximum number of AQL Query Results to return + :type top_k: int + + :var aql_examples: A set of AQL Query Examples that are passed to + the AQL Generation Prompt Template to promote few-shot-learning. + Defaults to an empty string. + :type aql_examples: str + + :var return_aql_query: Whether to return the AQL Query in the + output dictionary. Defaults to False. + :type return_aql_query: bool + + :var return_aql_result: Whether to return the AQL JSON Result in the + output dictionary. Defaults to False. + :type return_aql_result: bool + + :var max_aql_generation_attempts: The maximum amount of AQL + Generation attempts to be made prior to raising the last + AQL Query Execution Error. Defaults to 3. 
+ :type max_aql_generation_attempts: int + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + user_input = inputs[self.input_key] + + ######################### + # Generate AQL Query # + aql_generation_output = self.aql_generation_chain.run( + { + "adb_schema": self.graph.schema, + "aql_examples": self.aql_examples, + "user_input": user_input, + }, + callbacks=callbacks, + ) + ######################### + + aql_query = "" + aql_error = "" + aql_result = None + aql_generation_attempt = 1 + + while ( + aql_result is None + and aql_generation_attempt < self.max_aql_generation_attempts + 1 + ): + ##################### + # Extract AQL Query # + pattern = r"```(?i:aql)?(.*?)```" + matches = re.findall(pattern, aql_generation_output, re.DOTALL) + if not matches: + _run_manager.on_text( + "Invalid Response: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + aql_generation_output, color="red", end="\n", verbose=self.verbose + ) + raise ValueError(f"Response is Invalid: {aql_generation_output}") + + aql_query = matches[0] + ##################### + + _run_manager.on_text( + f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose + ) + _run_manager.on_text( + aql_query, color="green", end="\n", verbose=self.verbose + ) + + ##################### + # Execute AQL Query # + from arango import AQLQueryExecuteError + + try: + aql_result = self.graph.query(aql_query, self.top_k) + except AQLQueryExecuteError as e: + aql_error = e.error_message + + _run_manager.on_text( + "AQL Query Execution Error: ", end="\n", verbose=self.verbose + ) + _run_manager.on_text( + aql_error, color="yellow", end="\n\n", verbose=self.verbose + ) + + ######################## + # Retry AQL Generation # + aql_generation_output = self.aql_fix_chain.run( + { + "adb_schema": self.graph.schema, + "aql_query": aql_query, + "aql_error": aql_error, + }, + callbacks=callbacks, + ) + ######################## + + 
##################### + + aql_generation_attempt += 1 + + if aql_result is None: + m = f""" + Maximum amount of AQL Query Generation attempts reached. + Unable to execute the AQL Query due to the following error: + {aql_error} + """ + raise ValueError(m) + + _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(aql_result), color="green", end="\n", verbose=self.verbose + ) + + ######################## + # Interpret AQL Result # + result = self.qa_chain( + { + "adb_schema": self.graph.schema, + "user_input": user_input, + "aql_query": aql_query, + "aql_result": aql_result, + }, + callbacks=callbacks, + ) + ######################## + + # Return results # + result = {self.output_key: result[self.qa_chain.output_key]} + + if self.return_aql_query: + result["aql_query"] = aql_query + + if self.return_aql_result: + result["aql_result"] = aql_result + + return result diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index 9ae0757abb..0392a99955 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -198,6 +198,123 @@ SPARQL_QA_PROMPT = PromptTemplate( ) +AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input. + +You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query. + +You are given an `ArangoDB Schema`. It is a JSON Object containing: +1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships. +2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example. + +You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. 
If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used. + +Things you should do: +- Think step by step. +- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query. +- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required. +- Return the `AQL Query` wrapped in 3 backticks (```). +- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries. +- Only answer to requests related to generating an AQL Query. +- If a request is unrelated to generating AQL Query, say that you cannot help the user. + +Things you should not do: +- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`. +- Do not include any text except the generated AQL Query. +- Do not provide explanations or apologies in your responses. +- Do not generate an AQL Query that removes or deletes any data. + +Under no circumstance should you generate an AQL Query that deletes any data whatsoever. + +ArangoDB Schema: +{adb_schema} + +AQL Query Examples (Optional): +{aql_examples} + +User Input: +{user_input} + +AQL Query: +""" + +AQL_GENERATION_PROMPT = PromptTemplate( + input_variables=["adb_schema", "aql_examples", "user_input"], + template=AQL_GENERATION_TEMPLATE, +) + +AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query. + +You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`. + +The `AQL Error` explains why the `AQL Query` could not be executed in the database. +The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`. +For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`. 
+ +You are also given the `ArangoDB Schema`. It is a JSON Object containing: +1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships. +2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example. + +You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query. + +Remember to think step by step. + +ArangoDB Schema: +{adb_schema} + +AQL Query: +{aql_query} + +AQL Error: +{aql_error} + +Corrected AQL Query: +""" + +AQL_FIX_PROMPT = PromptTemplate( + input_variables=[ + "adb_schema", + "aql_query", + "aql_error", + ], + template=AQL_FIX_TEMPLATE, +) + +AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query. + +You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`. + +A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format. +You are responsible for creating an `Summary` based on the AQL Result. + +You are given the following information: +- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database. +- `User Input`: the original question/request of the user, which has been translated into an AQL Query. +- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query. +- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database. + +Remember to think step by step. + +Your `Summary` should sound like it is a response to the `User Input`. +Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`. 
import itertools
import os
from math import ceil
from typing import Any, Dict, List, Optional


class ArangoGraph:
    """ArangoDB wrapper for graph operations.

    Holds a python-arango database handle together with a schema summary
    (graphs plus per-collection properties and an example document) that is
    handed to the LLM prompts of the graph QA chain.
    """

    def __init__(self, db: Any) -> None:
        """Create a new ArangoDB graph wrapper instance.

        Args:
            db: An ``arango.database.Database`` instance.

        Raises:
            TypeError: If **db** is not an ``arango.database.Database``.
            ImportError: If ``python-arango`` is not installed.
        """
        # set_db() already (re)generates the schema, so it must not be
        # generated a second time here: schema generation queries every
        # non-system collection of the database.
        self.set_db(db)

    @property
    def db(self) -> Any:
        """The wrapped ``arango.database.Database`` instance."""
        return self.__db

    @property
    def schema(self) -> Dict[str, Any]:
        """The schema currently associated with the database."""
        return self.__schema

    def set_db(self, db: Any) -> None:
        """Replace the wrapped database and regenerate the schema from it.

        Args:
            db: An ``arango.database.Database`` instance.

        Raises:
            TypeError: If **db** is not an ``arango.database.Database``.
            ImportError: If ``python-arango`` is not installed.
        """
        try:
            from arango.database import Database
        except ImportError as e:
            raise ImportError(
                "Unable to import arango, please install with "
                "`pip install python-arango`."
            ) from e

        if not isinstance(db, Database):
            msg = "**db** parameter must inherit from arango.database.Database"
            raise TypeError(msg)

        self.__db = db
        self.set_schema()

    def set_schema(self, schema: Optional[Dict[str, Any]] = None) -> None:
        """Set the schema of the ArangoDB Database.

        Auto-generates the schema via ``generate_schema()`` if **schema**
        is None.
        """
        self.__schema = self.generate_schema() if schema is None else schema

    def generate_schema(
        self, sample_ratio: float = 0
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Generate the schema of the ArangoDB Database and return it.

        Args:
            sample_ratio: Ratio (0 to 1) of documents/edges, relative to each
                collection's size, that is read to render that collection's
                schema. At least one document is always requested.

        Returns:
            A dict with two keys: ``"Graph Schema"`` (name and edge
            definitions of every graph) and ``"Collection Schema"`` (one
            entry per non-system collection with its name, type, the
            distinct ``(name, type)`` properties seen in the sampled
            documents, and the last sampled document as an example, or
            ``None`` when the collection is empty).

        Raises:
            ValueError: If **sample_ratio** is outside of ``[0, 1]``.
        """
        if not 0 <= sample_ratio <= 1:
            raise ValueError("**sample_ratio** value must be in between 0 to 1")

        # Stores the Edge Relationships between each ArangoDB Document Collection
        graph_schema: List[Dict[str, Any]] = [
            {"graph_name": g["name"], "edge_definitions": g["edge_definitions"]}
            for g in self.db.graphs()
        ]

        # Bind parameters instead of string interpolation: collection names
        # may contain characters (e.g. "-") that are not valid in bare AQL.
        aql = "FOR doc IN @@col LIMIT @limit RETURN doc"

        # Stores the schema of every ArangoDB Document/Edge collection
        collection_schema: List[Dict[str, Any]] = []

        for collection in self.db.collections():
            if collection["system"]:
                continue

            # Extract collection name, type, and size
            col_name: str = collection["name"]
            col_type: str = collection["type"]
            col_size: int = self.db.collection(col_name).count()

            # Set number of ArangoDB documents/edges to retrieve
            limit_amount = ceil(sample_ratio * col_size) or 1

            properties: List[Dict[str, str]] = []
            seen = set()
            # Reset per collection so an empty collection can neither raise
            # on an unbound name nor inherit the previous collection's example.
            example: Optional[Dict[str, Any]] = None

            cursor = self.db.aql.execute(
                aql, bind_vars={"@col": col_name, "limit": limit_amount}
            )
            for doc in cursor:
                example = doc
                for key, value in doc.items():
                    prop = (key, type(value).__name__)
                    # Sampling several documents must not repeat a property.
                    if prop in seen:
                        continue
                    seen.add(prop)
                    properties.append({"name": prop[0], "type": prop[1]})

            collection_schema.append(
                {
                    "collection_name": col_name,
                    "collection_type": col_type,
                    f"{col_type}_properties": properties,
                    f"example_{col_type}": example,
                }
            )

        return {"Graph Schema": graph_schema, "Collection Schema": collection_schema}

    def query(
        self, query: str, top_k: Optional[int] = None, **kwargs: Any
    ) -> List[Dict[str, Any]]:
        """Query the ArangoDB database.

        Args:
            query: The AQL query to execute.
            top_k: Maximum number of results to return; ``None`` returns
                every result of the cursor.
            **kwargs: Forwarded unchanged to ``db.aql.execute``.

        Returns:
            The (at most **top_k**) results of the query as a list.
        """
        cursor = self.db.aql.execute(query, **kwargs)
        return list(itertools.islice(cursor, top_k))

    @classmethod
    def from_db_credentials(
        cls,
        url: Optional[str] = None,
        dbname: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> Any:
        """Convenience constructor that builds an ArangoGraph from credentials.

        Args:
            url: Arango DB url. Can be passed in as named arg or set as environment
                var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
            dbname: Arango DB name. Can be passed in as named arg or set as
                environment var ``ARANGODB_DBNAME``. Defaults to "_system".
            username: Can be passed in as named arg or set as environment var
                ``ARANGODB_USERNAME``. Defaults to "root".
            password: Can be passed in as named arg or set as environment var
                ``ARANGODB_PASSWORD``. Defaults to "".

        Returns:
            An instance of this class wrapping the connected database.
        """
        db = get_arangodb_client(
            url=url, dbname=dbname, username=username, password=password
        )
        return cls(db)


def get_arangodb_client(
    url: Optional[str] = None,
    dbname: Optional[str] = None,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> Any:
    """Convenience method that gets Arango DB from credentials.

    Args:
        url: Arango DB url. Can be passed in as named arg or set as environment
            var ``ARANGODB_URL``. Defaults to "http://localhost:8529".
        dbname: Arango DB name. Can be passed in as named arg or set as
            environment var ``ARANGODB_DBNAME``. Defaults to "_system".
        username: Can be passed in as named arg or set as environment var
            ``ARANGODB_USERNAME``. Defaults to "root".
        password: Can be passed in as named arg or set as environment var
            ``ARANGODB_PASSWORD``. Defaults to "".

    Returns:
        An arango.database.StandardDatabase.

    Raises:
        ImportError: If ``python-arango`` is not installed.
    """
    try:
        from arango import ArangoClient
    except ImportError as e:
        raise ImportError(
            "Unable to import arango, please install with `pip install python-arango`."
        ) from e

    _url = url or os.environ.get("ARANGODB_URL", "http://localhost:8529")
    _dbname = dbname or os.environ.get("ARANGODB_DBNAME", "_system")
    _username = username or os.environ.get("ARANGODB_USERNAME", "root")
    # An empty password is a legitimate explicit value, so only fall back to
    # the environment when no password was passed at all.
    if password is not None:
        _password = password
    else:
        _password = os.environ.get("ARANGODB_PASSWORD", "")

    return ArangoClient(_url).db(_dbname, _username, _password, verify=True)
"sphinx-rtd-theme", "types-pkg-resources", "types-requests", "types-setuptools"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -12461,7 +12484,7 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["O365", "aleph-alpha-client", "amadeus", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] +all = ["O365", "aleph-alpha-client", "amadeus", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "libdeeplake", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", 
"networkx", "nlpcloud", "nltk", "nomic", "octoai-sdk", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"] clarifai = ["clarifai"] cohere = ["cohere"] @@ -12477,4 +12500,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "c9c8d7972feb1de7b227f222f53b9a3f9b497c9aa5d4ef7666a9eafe0db423d6" +content-hash = "7a8847de4dd88e71b423ff148823523220a5649340178e8ab1f7bafb03a290d2" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index cc8d617c6e..97698f6238 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -124,6 +124,7 @@ langsmith = "~0.0.11" rank-bm25 = {version = "^0.2.2", optional = true} amadeus = {version = ">=8.1.0", optional = true} geopandas = {version = "^0.13.1", optional = true} +python-arango = {version = "^7.5.9", optional = true} [tool.poetry.group.test.dependencies] # The only dependencies that should be added are @@ -314,6 +315,7 @@ all = [ "octoai-sdk", "rdflib", "amadeus", + "python-arango", ] # An extra used to be able to add extended testing. 
"""Test Graph Database Chain."""
from typing import Any

from langchain.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain.graphs import ArangoGraph
from langchain.graphs.arangodb_graph import get_arangodb_client
from langchain.llms.openai import OpenAI


def populate_arangodb_database(db: Any) -> None:
    """Create the ``GameOfThrones`` graph with two characters and one edge.

    Does nothing when the graph already exists, so the tests can be re-run
    against the same database.
    """
    if db.has_graph("GameOfThrones"):
        return

    db.create_graph(
        "GameOfThrones",
        edge_definitions=[
            {
                "edge_collection": "ChildOf",
                "from_vertex_collections": ["Characters"],
                "to_vertex_collections": ["Characters"],
            },
        ],
    )

    documents = [
        {
            "_key": "NedStark",
            "name": "Ned",
            "surname": "Stark",
            "alive": True,
            "age": 41,
            "gender": "male",
        },
        {
            "_key": "AryaStark",
            "name": "Arya",
            "surname": "Stark",
            "alive": True,
            "age": 11,
            "gender": "female",
        },
    ]

    edges = [{"_to": "Characters/NedStark", "_from": "Characters/AryaStark"}]

    db.collection("Characters").import_bulk(documents)
    db.collection("ChildOf").import_bulk(edges)


def test_connect_arangodb() -> None:
    """Test that the ArangoDB database is correctly instantiated and connected."""
    graph = ArangoGraph(get_arangodb_client())

    sample_aql_result = graph.query("RETURN 'hello world'")
    # The expected value must match the string literal returned by the query
    # ("hello world", with a space).
    assert ["hello world"] == sample_aql_result


def test_aql_generation() -> None:
    """Test that AQL statement is correctly generated and executed."""
    db = get_arangodb_client()

    populate_arangodb_database(db)

    graph = ArangoGraph(db)
    chain = ArangoGraphQAChain.from_llm(OpenAI(temperature=0), graph=graph)
    chain.return_aql_result = True

    output = chain("Is Ned Stark alive?")
    assert output["aql_result"] == [True]
    assert "Yes" in output["result"]

    output = chain("How old is Arya Stark?")
    assert output["aql_result"] == [11]
    assert "11" in output["result"]

    output = chain("What is the relationship between Arya Stark and Ned Stark?")
    assert len(output["aql_result"]) == 1
    assert "child of" in output["result"]