diff --git a/docs/modules/chains/examples/graph_cypher_qa.ipynb b/docs/modules/chains/examples/graph_cypher_qa.ipynb new file mode 100644 index 0000000000..b93bf64ee7 --- /dev/null +++ b/docs/modules/chains/examples/graph_cypher_qa.ipynb @@ -0,0 +1,230 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c94240f5", + "metadata": {}, + "source": [ + "# GraphCypherQAChain\n", + "\n", + "This notebook shows how to use LLMs to provide a natural language interface to a graph database that you can query with the Cypher query language." + ] + }, + { + "cell_type": "markdown", + "id": "dbc0ee68", + "metadata": {}, + "source": [ + "You will need to have a running Neo4j instance. One option is to create a [free Neo4j database instance in their Aura cloud service](https://neo4j.com/cloud/platform/aura-graph-database/). You can also run the database locally using the [Neo4j Desktop application](https://neo4j.com/download/), or by running a docker container.\n", + "You can run a local docker container by executing the following script:\n", + "\n", + "```\n", + "docker run \\\n", + " --name neo4j \\\n", + " -p 7474:7474 -p 7687:7687 \\\n", + " -d \\\n", + " -e NEO4J_AUTH=neo4j/pleaseletmein \\\n", + " -e NEO4J_PLUGINS=\\[\\\"apoc\\\"\\] \\\n", + " neo4j:latest\n", + "```\n", + "\n", + "If you are using the docker container, you need to wait a couple of seconds for the database to start." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "62812aad", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI\n", + "from langchain.chains import GraphCypherQAChain\n", + "from langchain.graphs import Neo4jGraph" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0928915d", + "metadata": {}, + "outputs": [], + "source": [ + "graph = Neo4jGraph(\n", + " url=\"bolt://localhost:7687\", username=\"neo4j\", password=\"pleaseletmein\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "995ea9b9", + "metadata": {}, + "source": [ + "## Seeding the database\n", + "\n", + "Assuming your database is empty, you can populate it using the Cypher query language. The following Cypher statement is idempotent, which means the database information will be the same whether you run it once or multiple times." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fedd26b9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph.query(\n", + " \"\"\"\n", + "MERGE (m:Movie {name:\"Top Gun\"})\n", + "WITH m\n", + "UNWIND [\"Tom Cruise\", \"Val Kilmer\", \"Anthony Edwards\", \"Meg Ryan\"] AS actor\n", + "MERGE (a:Actor {name:actor})\n", + "MERGE (a)-[:ACTED_IN]->(m)\n", + "\"\"\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "58c1a8ea", + "metadata": {}, + "source": [ + "## Refresh graph schema information\n", + "If the schema of the database changes, you can refresh the schema information needed to generate Cypher statements." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4e3de44f", + "metadata": {}, + "outputs": [], + "source": [ + "graph.refresh_schema()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1fe76ccd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Node properties are the following:\n", + " [{'properties': [{'property': 'name', 'type': 'STRING'}], 'labels': 'Movie'}, {'properties': [{'property': 'name', 'type': 'STRING'}], 'labels': 'Actor'}]\n", + " Relationship properties are the following:\n", + " []\n", + " The relationships are the following:\n", + " ['(:Actor)-[:ACTED_IN]->(:Movie)']\n", + " \n" + ] + } + ], + "source": [ + "print(graph.get_schema)" + ] + }, + { + "cell_type": "markdown", + "id": "68a3c677", + "metadata": {}, + "source": [ + "## Querying the graph\n", + "\n", + "We can now use the graph Cypher QA chain to ask questions of the graph." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7476ce98", + "metadata": {}, + "outputs": [], + "source": [ + "chain = GraphCypherQAChain.from_llm(\n", + " ChatOpenAI(temperature=0), graph=graph, verbose=True\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ef8ee27b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\u001b[1m> Entering new GraphCypherQAChain chain...\u001b[0m\n", + "Generated Cypher:\n", + "\u001b[32;1m\u001b[1;3mMATCH (a:Actor)-[:ACTED_IN]->(m:Movie {name: 'Top Gun'})\n", + "RETURN a.name\u001b[0m\n", + "Full Context:\n", + "\u001b[32;1m\u001b[1;3m[{'a.name': 'Tom Cruise'}, {'a.name': 'Val Kilmer'}, {'a.name': 'Anthony Edwards'}, {'a.name': 'Meg Ryan'}]\u001b[0m\n", + "\n", + "\u001b[1m> Finished chain.\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "'Tom Cruise, Val Kilmer, Anthony Edwards, and Meg Ryan played in Top Gun.'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.run(\"Who played in Top Gun?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4825316", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 05af5f3e31..196803d36f 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -10,6 +10,7 @@ from langchain.chains.conversational_retrieval.base import ( ) from langchain.chains.flare.base import FlareChain from langchain.chains.graph_qa.base import GraphQAChain +from langchain.chains.graph_qa.cypher import GraphCypherQAChain from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.llm import LLMChain from langchain.chains.llm_bash.base import LLMBashChain @@ -58,6 +59,7 @@ __all__ = [ "HypotheticalDocumentEmbedder", "ChatVectorDBChain", "GraphQAChain", + "GraphCypherQAChain", "ConstitutionalChain", "QAGenerationChain", "RetrievalQA", diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py new file mode 100644 
index 0000000000..b06fb9ce5f --- /dev/null +++ b/langchain/chains/graph_qa/cypher.py @@ -0,0 +1,90 @@ +"""Question answering over a graph.""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from pydantic import Field + +from langchain.base_language import BaseLanguageModel +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, PROMPT +from langchain.chains.llm import LLMChain +from langchain.graphs.neo4j_graph import Neo4jGraph +from langchain.prompts.base import BasePromptTemplate + + +class GraphCypherQAChain(Chain): + """Chain for question-answering against a graph by generating Cypher statements.""" + + graph: Neo4jGraph = Field(exclude=True) + cypher_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = PROMPT, + cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT, + **kwargs: Any, + ) -> GraphCypherQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + + return cls( + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + + generated_cypher = self.cypher_generation_chain.run( + {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + context = self.graph.query(generated_cypher) + + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + result = self.qa_chain( + {"question": question, "context": context}, + callbacks=callbacks, + ) + return {self.output_key: result[self.qa_chain.output_key]} diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index 6fdf524764..5526c67c47 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -32,3 +32,19 @@ Helpful Answer:""" PROMPT = PromptTemplate( template=prompt_template, input_variables=["context", "question"] ) + +CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database. +Instructions: +Use only the provided relationship types and properties in the schema. +Do not use any other relationship types or properties that are not provided. +Schema: +{schema} +Note: Do not include any explanations or apologies in your responses. 
+Do not respond to any questions that might ask anything else than for you to construct a Cypher statement. +Do not include any text except the generated Cypher statement. + +The question is: +{question}""" +CYPHER_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE +) diff --git a/langchain/graphs/__init__.py b/langchain/graphs/__init__.py index 68851c6ddb..72b0976a3c 100644 --- a/langchain/graphs/__init__.py +++ b/langchain/graphs/__init__.py @@ -1,4 +1,5 @@ """Graph implementations.""" +from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.graphs.networkx_graph import NetworkxEntityGraph -__all__ = ["NetworkxEntityGraph"] +__all__ = ["NetworkxEntityGraph", "Neo4jGraph"] diff --git a/langchain/graphs/neo4j_graph.py b/langchain/graphs/neo4j_graph.py new file mode 100644 index 0000000000..4433771a25 --- /dev/null +++ b/langchain/graphs/neo4j_graph.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, List + +node_properties_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "node" +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {labels: nodeLabels, properties: properties} AS output + +""" + +rel_properties_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE NOT type = "RELATIONSHIP" AND elementType = "relationship" +WITH label AS nodeLabels, collect({property:property, type:type}) AS properties +RETURN {type: nodeLabels, properties: properties} AS output +""" + +rel_query = """ +CALL apoc.meta.data() +YIELD label, other, elementType, type, property +WHERE type = "RELATIONSHIP" AND elementType = "node" +RETURN "(:" + label + ")-[:" + property + "]->(:" + toString(other[0]) + ")" AS output +""" + + +class Neo4jGraph: + """Neo4j wrapper for graph operations.""" + + def __init__( + self, url: str, username: str, password: str, database: str = "neo4j" + ) -> None: + """Create a new Neo4j graph wrapper instance.""" + try: + import neo4j + except ImportError: + raise ValueError( + "Could not import neo4j python package. " + "Please install it with `pip install neo4j`." + ) + + self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password)) + self._database = database + self.schema = "" + # Verify connection + try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Neo4j database. " + "Please ensure that the username and password are correct" + ) + # Set schema + try: + self.refresh_schema() + except neo4j.exceptions.ClientError: + raise ValueError( + "Could not use APOC procedures. " + "Please install the APOC plugin in Neo4j." 
+ ) + + @property + def get_schema(self) -> str: + """Returns the schema of the Neo4j database""" + return self.schema + + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: + """Query Neo4j database.""" + from neo4j.exceptions import CypherSyntaxError + + with self._driver.session(database=self._database) as session: + try: + data = session.run(query, params) + # Hard limit of 50 results + return [r.data() for r in data][:50] + except CypherSyntaxError as e: + raise ValueError("Generated Cypher Statement is not valid\n" f"{e}") + + def refresh_schema(self) -> None: + """ + Refreshes the Neo4j graph schema information. + """ + node_properties = self.query(node_properties_query) + relationships_properties = self.query(rel_properties_query) + relationships = self.query(rel_query) + + self.schema = f""" + Node properties are the following: + {[el['output'] for el in node_properties]} + Relationship properties are the following: + {[el['output'] for el in relationships_properties]} + The relationships are the following: + {[el['output'] for el in relationships]} + """ diff --git a/poetry.lock b/poetry.lock index a8bf8cae9e..4a59e44bd7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "absl-py" @@ -4808,6 +4808,24 @@ nbformat = "*" sphinx = ">=1.8" traitlets = ">=5" +[[package]] +name = "neo4j" +version = "5.8.1" +description = "Neo4j Bolt driver for Python" +category = "main" +optional = true +python-versions = ">=3.7" +files = [ + {file = "neo4j-5.8.1.tar.gz", hash = "sha256:79c947f402e9f8624587add7b8af742b38cbcdf364d48021c5bff9220457965b"}, +] + +[package.dependencies] +pytz = "*" + +[package.extras] +numpy = ["numpy (>=1.7.0,<2.0.0)"] +pandas = ["numpy (>=1.7.0,<2.0.0)", "pandas (>=1.1.0,<3.0.0)"] + [[package]] name = "nest-asyncio" version = "1.5.6" @@ -6635,6 +6653,7 @@ files = [ {file = "pylance-0.4.12-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:2b86fb8dccc03094c0db37bef0d91bda60e8eb0d1eddf245c6971450c8d8a53f"}, {file = "pylance-0.4.12-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bc82914b13204187d673b5f3d45f93219c38a0e9d0542ba251074f639669789"}, {file = "pylance-0.4.12-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5a4bcce77f99ecd4cbebbadb01e58d5d8138d40eb56bdcdbc3b20b0475e7a472"}, + {file = "pylance-0.4.12-cp38-abi3-win_amd64.whl", hash = "sha256:9616931c5300030adb9626d22515710a127d1e46a46737a7a0f980b52f13627c"}, ] [package.dependencies] @@ -10359,14 +10378,14 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\ cffi = ["cffi (>=1.11)"] [extras] -all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "azure-cosmos", "azure-identity", "beautifulsoup4", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "gptcache", "hnswlib", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "lark", "lxml", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "protobuf", "psycopg2-binary", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "sentence-transformers", "spacy", "steamship", "tensorflow-text", "tiktoken", "torch", "transformers", 
"weaviate-client", "wikipedia", "wolframalpha"] -azure = ["azure-core", "azure-cosmos", "azure-identity", "openai"] +all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j"] +azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"] cohere = ["cohere"] embeddings = ["sentence-transformers"] -extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "chardet", "gql", "html2text", "jq", "lxml", "pandas", "pdfminer-six", "psychicapi", "pymupdf", "pypdf", "pypdfium2", "requests-toolbelt", "telethon", "tqdm", "zep-python"] -hnswlib = ["docarray", "hnswlib", "protobuf"] +extended-testing = ["beautifulsoup4", "chardet", "jq", "pdfminer-six", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text"] +hnswlib = ["docarray", "protobuf", "hnswlib"] in-memory-store = ["docarray"] -llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] +llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] openai = ["openai", "tiktoken"] qdrant = ["qdrant-client"] text-helpers = ["chardet"] @@ -10374,4 +10393,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "086b4d4d5ca5d0be9d12105f926d667926170bca706a6c6ee152637389d2a22d" +content-hash = "e41e0253ccc1137f3f4a8f3627fb5165f611dfd57e65960c95421e868f4defae" diff --git a/pyproject.toml b/pyproject.toml index a0ea369a48..0e8f1684da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,12 +88,12 @@ pypdfium2 = {version = "^4.10.0", optional = true} gql = {version = "^3.4.1", optional = true} pandas = {version = "^2.0.1", optional = true} telethon = {version = "^1.28.5", optional = true} +neo4j = {version = "^5.8.1", optional = true} psychicapi = {version = "^0.2", optional = true} zep-python = {version="^0.25", optional=true} chardet = {version="^5.1.0", optional=true} requests-toolbelt = {version = "^1.0.0", optional = true} - [tool.poetry.group.docs.dependencies] autodoc_pydantic = "^1.8.0" myst_parser = "^0.18.1" @@ -185,7 +185,67 @@ in_memory_store = ["docarray"] hnswlib = ["docarray", "protobuf", "hnswlib"] embeddings = ["sentence-transformers"] azure = ["azure-identity", "azure-cosmos", "openai", "azure-core"] -all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "weaviate-client", "redis", 
"google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "protobuf", "hnswlib", "steamship", "pdfminer-six", "lxml"] +all = [ + "anthropic", + "cohere", + "openai", + "nlpcloud", + "huggingface_hub", + "jina", + "manifest-ml", + "elasticsearch", + "opensearch-py", + "google-search-results", + "faiss-cpu", + "sentence-transformers", + "transformers", + "spacy", + "nltk", + "wikipedia", + "beautifulsoup4", + "tiktoken", + "torch", + "jinja2", + "pinecone-client", + "pinecone-text", + "weaviate-client", + "redis", + "google-api-python-client", + "wolframalpha", + "qdrant-client", + "tensorflow-text", + "pypdf", + "networkx", + "nomic", + "aleph-alpha-client", + "deeplake", + "pgvector", + "psycopg2-binary", + "pyowm", + "pytesseract", + "html2text", + "atlassian-python-api", + "gptcache", + "duckduckgo-search", + "arxiv", + "azure-identity", + "clickhouse-connect", + "azure-cosmos", + "lancedb", + "lark", + "pexpect", + "pyvespa", + "O365", + "jq", + "docarray", + "protobuf", + "hnswlib", + "steamship", + "pdfminer-six", + "lxml", + "requests-toolbelt", + "neo4j", +] # An extra used to be able to add extended testing. # Please use new-line on formatting to make it easier to add new packages without diff --git a/tests/integration_tests/chains/test_graph_database.py b/tests/integration_tests/chains/test_graph_database.py new file mode 100644 index 0000000000..10a00a2d0f --- /dev/null +++ b/tests/integration_tests/chains/test_graph_database.py @@ -0,0 +1,60 @@ +"""Test Graph Database Chain.""" +import os + +from langchain.chains.graph_qa.cypher import GraphCypherQAChain +from langchain.graphs import Neo4jGraph +from langchain.llms.openai import OpenAI + + +def test_connect_neo4j() -> None: + """Test that Neo4j database is correctly instantiated and connected.""" + url = os.environ.get("NEO4J_URL") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + + output = graph.query( + """ + RETURN "test" AS output + """ + ) + expected_output = [{"output": "test"}] + assert output == expected_output + + +def test_cypher_generating_run() -> None: + """Test that Cypher statement is correctly generated and executed.""" + url = os.environ.get("NEO4J_URL") + username = os.environ.get("NEO4J_USERNAME") + password = os.environ.get("NEO4J_PASSWORD") + assert url is not None + assert username is not None + assert password is not None + + graph = Neo4jGraph( + url=url, + username=username, + password=password, + ) + # Delete all nodes in the graph + graph.query("MATCH (n) DETACH DELETE n") + # Create two nodes and a relationship + graph.query( + "CREATE (a:Actor {name:'Bruce Willis'})" + "-[:ACTED_IN]->(:Movie {title: 'Pulp Fiction'})" + ) + # Refresh schema information + graph.refresh_schema() + + chain = GraphCypherQAChain.from_llm(OpenAI(temperature=0), graph=graph) + output = chain.run("Who played in Pulp Fiction?") + expected_output = " Bruce Willis played in Pulp Fiction." + assert output == expected_output