Neptune graph and openCypher QA Chain (#8035)

## Description
This PR adds a graph class and an openCypher QA chain to work with the
Amazon Neptune database.

## Dependencies
`requests` which is included in the LangChain dependencies.

## Maintainers for Review
@krlawrence
@baskaryan

### Twitter handle
pjain7
pull/8046/head
Piyush Jain 1 year ago committed by GitHub
parent 995220b797
commit 31b7ddc12c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,52 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neptune Open Cypher QA Chain\n",
"This QA chain queries Neptune graph database using openCypher and returns human readable response\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.graphs.neptune_graph import NeptuneGraph\n",
"\n",
"\n",
"host = \"<neptune-host>\"\n",
"port = 80\n",
"use_https = False\n",
"\n",
"graph = NeptuneGraph(host=host, port=port, use_https=use_https)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain\n",
"\n",
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"\n",
"chain = NeptuneOpenCypherQAChain.from_llm(llm=llm, graph=graph)\n",
"\n",
"chain.run(\"how many outgoing routes does the Austin airport have?\")"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,141 @@
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from pydantic import Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain
from langchain.graphs import NeptuneGraph
from langchain.prompts.base import BasePromptTemplate
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class NeptuneOpenCypherQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating openCypher statements.
Example:
.. code-block:: python
chain = NeptuneOpenCypherQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.run(query)
"""
graph: NeptuneGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
**kwargs: Any,
) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
context = self.graph.query(generated_cypher)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -196,3 +196,21 @@ Helpful Answer:"""
SPARQL_QA_PROMPT = PromptTemplate(
input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
)
NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.
\n"""
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)

@ -3,6 +3,7 @@ from langchain.graphs.hugegraph import HugeGraph
from langchain.graphs.kuzu_graph import KuzuGraph
from langchain.graphs.nebula_graph import NebulaGraph
from langchain.graphs.neo4j_graph import Neo4jGraph
from langchain.graphs.neptune_graph import NeptuneGraph
from langchain.graphs.networkx_graph import NetworkxEntityGraph
from langchain.graphs.rdf_graph import RdfGraph
@ -10,6 +11,7 @@ __all__ = [
"NetworkxEntityGraph",
"Neo4jGraph",
"NebulaGraph",
"NeptuneGraph",
"KuzuGraph",
"HugeGraph",
"RdfGraph",

@ -0,0 +1,199 @@
import json
from typing import Any, Dict, List, Tuple, Union
import requests
class NeptuneQueryException(Exception):
"""A class to handle queries that fail to execute"""
def __init__(self, exception: Union[str, Dict]):
if isinstance(exception, dict):
self.message = exception["message"] if "message" in exception else "unknown"
self.details = exception["details"] if "details" in exception else "unknown"
else:
self.message = exception
self.details = "unknown"
def get_message(self) -> str:
return self.message
def get_details(self) -> Any:
return self.details
class NeptuneGraph:
"""Neptune wrapper for graph operations. This version
does not support Sigv4 signing of requests.
Example:
.. code-block:: python
graph = NeptuneGraph(
host='<my-cluster>',
port=8182
)
"""
def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None:
"""Create a new Neptune graph wrapper instance."""
if use_https:
self.summary_url = (
f"https://{host}:{port}/pg/statistics/summary?mode=detailed"
)
self.query_url = f"https://{host}:{port}/openCypher"
else:
self.summary_url = (
f"http://{host}:{port}/pg/statistics/summary?mode=detailed"
)
self.query_url = f"http://{host}:{port}/openCypher"
# Set schema
try:
self._refresh_schema()
except NeptuneQueryException:
raise ValueError("Could not get schema for Neptune database")
@property
def get_schema(self) -> str:
"""Returns the schema of the Neptune database"""
return self.schema
def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
"""Query Neptune database."""
response = requests.post(url=self.query_url, data={"query": query})
if response.ok:
results = json.loads(response.content.decode())
return results
else:
raise NeptuneQueryException(
{
"message": "The generated query failed to execute",
"details": response.content.decode(),
}
)
def _get_summary(self) -> Dict:
response = requests.get(url=self.summary_url)
if not response.ok:
raise NeptuneQueryException(
{
"message": (
"Summary API is not available for this instance of Neptune,"
"ensure the engine version is >=1.2.1.0"
),
"details": response.content.decode(),
}
)
try:
summary = response.json()["payload"]["graphSummary"]
except Exception:
raise NeptuneQueryException(
{
"message": "Summary API did not return a valid response.",
"details": response.content.decode(),
}
)
else:
return summary
def _get_labels(self) -> Tuple[List[str], List[str]]:
"""Get node and edge labels from the Neptune statistics summary"""
summary = self._get_summary()
n_labels = summary["nodeLabels"]
e_labels = summary["edgeLabels"]
return n_labels, e_labels
def _get_triples(self, e_labels: List[str]) -> List[str]:
triple_query = """
MATCH (a)-[e:{e_label}]->(b)
WITH a,e,b LIMIT 3000
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
LIMIT 10
"""
triple_template = "(:{a})-[:{e}]->(:{b})"
triple_schema = []
for label in e_labels:
q = triple_query.format(e_label=label)
data = self.query(q)
for d in data["results"]:
triple = triple_template.format(
a=d["from"][0], e=d["edge"], b=d["to"][0]
)
triple_schema.append(triple)
return triple_schema
def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
node_properties_query = """
MATCH (a:{n_label})
RETURN properties(a) AS props
LIMIT 100
"""
node_properties = []
for label in n_labels:
q = node_properties_query.format(n_label=label)
data = {"label": label, "properties": self.query(q)["results"]}
s = set({})
for p in data["properties"]:
for k, v in p["props"].items():
s.add((k, types[type(v).__name__]))
np = {
"properties": [{"property": k, "type": v} for k, v in s],
"labels": label,
}
node_properties.append(np)
return node_properties
def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
edge_properties_query = """
MATCH ()-[e:{e_label}]->()
RETURN properties(e) AS props
LIMIT 100
"""
edge_properties = []
for label in e_labels:
q = edge_properties_query.format(e_label=label)
data = {"label": label, "properties": self.query(q)["results"]}
s = set({})
for p in data["properties"]:
for k, v in p["props"].items():
s.add((k, types[type(v).__name__]))
ep = {
"type": label,
"properties": [{"property": k, "type": v} for k, v in s],
}
edge_properties.append(ep)
return edge_properties
def _refresh_schema(self) -> None:
"""
Refreshes the Neptune graph schema information.
"""
types = {
"str": "STRING",
"float": "DOUBLE",
"int": "INTEGER",
"list": "LIST",
"dict": "MAP",
}
n_labels, e_labels = self._get_labels()
triple_schema = self._get_triples(e_labels)
node_properties = self._get_node_properties(n_labels, types)
edge_properties = self._get_edge_properties(e_labels, types)
self.schema = f"""
Node properties are the following:
{node_properties}
Relationship properties are the following:
{edge_properties}
The relationships are the following:
{triple_schema}
"""
Loading…
Cancel
Save