mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Neptune graph and openCypher QA Chain (#8035)
## Description This PR adds a graph class and an openCypher QA chain to work with the Amazon Neptune database. ## Dependencies `requests` which is included in the LangChain dependencies. ## Maintainers for Review @krlawrence @baskaryan ### Twitter handle pjain7
This commit is contained in:
parent
995220b797
commit
31b7ddc12c
@ -0,0 +1,52 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Neptune Open Cypher QA Chain\n",
|
||||||
|
"This QA chain queries a Neptune graph database using openCypher and returns a human-readable response.\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.graphs.neptune_graph import NeptuneGraph\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"host = \"<neptune-host>\"\n",
|
||||||
|
"port = 80\n",
|
||||||
|
"use_https = False\n",
|
||||||
|
"\n",
|
||||||
|
"graph = NeptuneGraph(host=host, port=port, use_https=use_https)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from langchain.chat_models import ChatOpenAI\n",
|
||||||
|
"from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain\n",
|
||||||
|
"\n",
|
||||||
|
"llm = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
|
||||||
|
"\n",
|
||||||
|
"chain = NeptuneOpenCypherQAChain.from_llm(llm=llm, graph=graph)\n",
|
||||||
|
"\n",
|
||||||
|
"chain.run(\"how many outgoing routes does the Austin airport have?\")"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"language_info": {
|
||||||
|
"name": "python"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
141
langchain/chains/graph_qa/neptune_cypher.py
Normal file
141
langchain/chains/graph_qa/neptune_cypher.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from langchain.base_language import BaseLanguageModel
|
||||||
|
from langchain.callbacks.manager import CallbackManagerForChainRun
|
||||||
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.chains.graph_qa.prompts import (
|
||||||
|
CYPHER_QA_PROMPT,
|
||||||
|
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
||||||
|
)
|
||||||
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.graphs import NeptuneGraph
|
||||||
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
|
||||||
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
|
||||||
|
def extract_cypher(text: str) -> str:
    """Return the Cypher query contained in an LLM response.

    If ``text`` contains a span enclosed in triple backticks, the content of
    the first such span is returned; otherwise ``text`` is returned unchanged.

    A fence that opens with a ``cypher`` / ``opencypher`` language tag on its
    own line (e.g. "```cypher") has that tag line removed, because the tag is
    Markdown metadata rather than part of the query and would make the
    statement invalid if sent to the database.

    Args:
        text: Raw text produced by the Cypher generation step.

    Returns:
        The extracted query text, or ``text`` itself when no fenced span exists.
    """
    # Non-greedy so that only the first fenced span is captured; DOTALL lets
    # the span run across newlines.
    fenced = re.search(r"```(.*?)```", text, re.DOTALL)
    if fenced is None:
        return text

    # Strip only a tag that sits directly after the opening fence and is
    # followed by a line break; everything else is returned untouched.
    return re.sub(
        r"\A(?:open)?cypher[ \t]*\r?\n", "", fenced.group(1), flags=re.IGNORECASE
    )
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneOpenCypherQAChain(Chain):
    """Question-answering chain backed by a Neptune graph.

    The chain asks an LLM to write an openCypher statement for the user's
    question, runs that statement against the graph, and then (unless
    ``return_direct`` is set) asks a second LLM chain to phrase an answer
    from the query result.

    Example:
        .. code-block:: python

            chain = NeptuneOpenCypherQAChain.from_llm(
                llm=llm,
                graph=graph
            )
            response = chain.run(query)
    """

    graph: NeptuneGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    # Declared on the model but not referenced anywhere in this class.
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""

    @property
    def input_keys(self) -> List[str]:
        """Name of the single input this chain expects.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Name of the single output this chain always produces.

        :meta private:
        """
        return [self.output_key]

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> NeptuneOpenCypherQAChain:
        """Build the chain from one LLM, used for both generation and answering.

        Extra keyword arguments (for example ``graph``) are forwarded to the
        class constructor.
        """
        return cls(
            qa_chain=LLMChain(llm=llm, prompt=qa_prompt),
            cypher_generation_chain=LLMChain(llm=llm, prompt=cypher_prompt),
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        child_callbacks = manager.get_child()
        question = inputs[self.input_key]

        # Step 1: have the LLM write a query from the question and graph schema.
        llm_output = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema},
            callbacks=child_callbacks,
        )

        # The model may wrap the statement in backticks; keep only the query.
        cypher = extract_cypher(llm_output)

        manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        manager.on_text(cypher, color="green", end="\n", verbose=self.verbose)

        steps: List = [{"query": cypher}]

        # Step 2: execute the statement against Neptune.
        context = self.graph.query(cypher)

        if self.return_direct:
            # Caller asked for the raw database result; skip the answer LLM.
            answer = context
        else:
            manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )
            steps.append({"context": context})

            # Step 3: let the QA chain turn the query result into prose.
            qa_output = self.qa_chain(
                {"question": question, "context": context},
                callbacks=child_callbacks,
            )
            answer = qa_output[self.qa_chain.output_key]

        output: Dict[str, Any] = {self.output_key: answer}
        if self.return_intermediate_steps:
            output[INTERMEDIATE_STEPS_KEY] = steps
        return output
|
@ -196,3 +196,21 @@ Helpful Answer:"""
|
|||||||
SPARQL_QA_PROMPT = PromptTemplate(
|
SPARQL_QA_PROMPT = PromptTemplate(
|
||||||
input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
|
input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Extra rules injected into the Cypher generation prompt for Neptune. They tell
# the LLM to avoid the `NONE`/`ALL`/`ANY` predicate functions, `REDUCE` and
# `FOREACH`, and name the constructs to use instead.
# NOTE(review): presumably these constructs are unsupported by Neptune's
# openCypher implementation — confirm against the Neptune documentation.
NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.
\n"""

# Derive the Neptune template from the generic Cypher template by swapping its
# "Instructions:" marker for the block above (which itself starts with
# "Instructions:", so the heading is kept). `str.replace` substitutes every
# occurrence of the marker, and is a no-op if the marker is absent.
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
    "Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)

# Prompt object consumed by the Neptune QA chain; it is filled with the graph
# schema and the user's question.
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "question"],
    template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
|
||||||
|
@ -3,6 +3,7 @@ from langchain.graphs.hugegraph import HugeGraph
|
|||||||
from langchain.graphs.kuzu_graph import KuzuGraph
|
from langchain.graphs.kuzu_graph import KuzuGraph
|
||||||
from langchain.graphs.nebula_graph import NebulaGraph
|
from langchain.graphs.nebula_graph import NebulaGraph
|
||||||
from langchain.graphs.neo4j_graph import Neo4jGraph
|
from langchain.graphs.neo4j_graph import Neo4jGraph
|
||||||
|
from langchain.graphs.neptune_graph import NeptuneGraph
|
||||||
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
from langchain.graphs.networkx_graph import NetworkxEntityGraph
|
||||||
from langchain.graphs.rdf_graph import RdfGraph
|
from langchain.graphs.rdf_graph import RdfGraph
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ __all__ = [
|
|||||||
"NetworkxEntityGraph",
|
"NetworkxEntityGraph",
|
||||||
"Neo4jGraph",
|
"Neo4jGraph",
|
||||||
"NebulaGraph",
|
"NebulaGraph",
|
||||||
|
"NeptuneGraph",
|
||||||
"KuzuGraph",
|
"KuzuGraph",
|
||||||
"HugeGraph",
|
"HugeGraph",
|
||||||
"RdfGraph",
|
"RdfGraph",
|
||||||
|
199
langchain/graphs/neptune_graph.py
Normal file
199
langchain/graphs/neptune_graph.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneQueryException(Exception):
    """Error raised when a request to Neptune does not succeed.

    The exception can be built either from a plain message string or from a
    dict carrying optional ``"message"`` and ``"details"`` entries. Missing
    entries are reported as ``"unknown"``.
    """

    def __init__(self, exception: Union[str, Dict]):
        # Plain (non-dict) payload: it is the message, and no details exist.
        if not isinstance(exception, dict):
            self.message = exception
            self.details = "unknown"
            return

        self.message = exception.get("message", "unknown")
        self.details = exception.get("details", "unknown")

    def get_message(self) -> str:
        """Return the short description of the failure."""
        return self.message

    def get_details(self) -> Any:
        """Return the detail payload, or ``"unknown"`` when none was given."""
        return self.details
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneGraph:
    """Neptune wrapper for graph operations. This version
    does not support Sigv4 signing of requests.

    On construction the wrapper reads the graph summary and samples the data
    to build a textual schema (node properties, relationship properties and
    relationship triples) that is exposed through ``get_schema``.

    Example:
        .. code-block:: python

            graph = NeptuneGraph(
                host='<my-cluster>',
                port=8182
            )
    """

    def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None:
        """Create a new Neptune graph wrapper instance.

        Args:
            host: Host name of the Neptune endpoint.
            port: Port of the Neptune endpoint.
            use_https: Use ``https`` when true, ``http`` otherwise.

        Raises:
            ValueError: If the schema cannot be read from the database.
        """
        scheme = "https" if use_https else "http"
        base_url = f"{scheme}://{host}:{port}"
        self.summary_url = f"{base_url}/pg/statistics/summary?mode=detailed"
        self.query_url = f"{base_url}/openCypher"

        # Set schema
        try:
            self._refresh_schema()
        except NeptuneQueryException as e:
            # Chain the cause so the underlying Neptune error stays visible.
            raise ValueError("Could not get schema for Neptune database") from e

    @property
    def get_schema(self) -> str:
        """Returns the schema of the Neptune database"""
        return self.schema

    def query(
        self, query: str, params: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        """Query Neptune database.

        Args:
            query: openCypher statement to execute.
            params: Optional query parameters. When given, they are sent
                JSON-encoded in the ``parameters`` form field.

        Returns:
            The decoded JSON body of the response.

        Raises:
            NeptuneQueryException: If the HTTP response status is not OK.
        """
        payload = {"query": query}
        if params:
            payload["parameters"] = json.dumps(params)

        response = requests.post(url=self.query_url, data=payload)
        if not response.ok:
            raise NeptuneQueryException(
                {
                    "message": "The generated query failed to execute",
                    "details": response.content.decode(),
                }
            )
        return json.loads(response.content.decode())

    def _get_summary(self) -> Dict:
        """Fetch the ``graphSummary`` section of the Neptune summary API.

        Raises:
            NeptuneQueryException: If the request fails or the response does
                not contain ``payload.graphSummary``.
        """
        response = requests.get(url=self.summary_url)
        if not response.ok:
            raise NeptuneQueryException(
                {
                    "message": (
                        "Summary API is not available for this instance of Neptune, "
                        "ensure the engine version is >=1.2.1.0"
                    ),
                    "details": response.content.decode(),
                }
            )
        try:
            return response.json()["payload"]["graphSummary"]
        except Exception as e:
            raise NeptuneQueryException(
                {
                    "message": "Summary API did not return a valid response.",
                    "details": response.content.decode(),
                }
            ) from e

    def _get_labels(self) -> Tuple[List[str], List[str]]:
        """Get node and edge labels from the Neptune statistics summary"""
        summary = self._get_summary()
        n_labels = summary["nodeLabels"]
        e_labels = summary["edgeLabels"]
        return n_labels, e_labels

    def _get_triples(self, e_labels: List[str]) -> List[str]:
        """Describe which node labels each edge label connects.

        For every edge label, up to 3000 matching edges are sampled and at
        most 10 distinct (from-labels, edge type, to-labels) rows are read.
        Only the first label of each endpoint is used in the rendered triple.

        Returns:
            Strings of the form ``(:A)-[:E]->(:B)``.
        """
        triple_query = """
        MATCH (a)-[e:{e_label}]->(b)
        WITH a,e,b LIMIT 3000
        RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
        LIMIT 10
        """

        triple_template = "(:{a})-[:{e}]->(:{b})"
        triple_schema = []
        for label in e_labels:
            data = self.query(triple_query.format(e_label=label))
            triple_schema.extend(
                triple_template.format(a=d["from"][0], e=d["edge"], b=d["to"][0])
                for d in data["results"]
            )

        return triple_schema

    @staticmethod
    def _collect_property_types(
        rows: List[Dict[str, Any]], types: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Return the distinct (property, type) pairs found in query rows.

        Each row is expected to carry a ``"props"`` mapping. Pairs are
        de-duplicated and kept in first-seen order so that the generated
        schema text is stable between runs. Python types without an entry in
        ``types`` are reported as ``"UNKNOWN"`` instead of raising.
        """
        seen: Dict[Tuple[str, Any], None] = {}
        for row in rows:
            for name, value in row["props"].items():
                seen[(name, types.get(type(value).__name__, "UNKNOWN"))] = None
        return [{"property": name, "type": kind} for name, kind in seen]

    def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
        """Sample up to 100 nodes per label and list their property types.

        Returns:
            One ``{"properties": [...], "labels": label}`` dict per label.
        """
        node_properties_query = """
        MATCH (a:{n_label})
        RETURN properties(a) AS props
        LIMIT 100
        """
        node_properties = []
        for label in n_labels:
            rows = self.query(node_properties_query.format(n_label=label))["results"]
            node_properties.append(
                {
                    "properties": self._collect_property_types(rows, types),
                    "labels": label,
                }
            )

        return node_properties

    def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
        """Sample up to 100 edges per label and list their property types.

        Returns:
            One ``{"type": label, "properties": [...]}`` dict per label.
        """
        edge_properties_query = """
        MATCH ()-[e:{e_label}]->()
        RETURN properties(e) AS props
        LIMIT 100
        """
        edge_properties = []
        for label in e_labels:
            rows = self.query(edge_properties_query.format(e_label=label))["results"]
            edge_properties.append(
                {
                    "type": label,
                    "properties": self._collect_property_types(rows, types),
                }
            )

        return edge_properties

    def _refresh_schema(self) -> None:
        """
        Refreshes the Neptune graph schema information.
        """

        # Maps the Python type name of a decoded JSON value to the type name
        # shown in the schema text.
        types = {
            "str": "STRING",
            "float": "DOUBLE",
            "int": "INTEGER",
            "bool": "BOOLEAN",
            "list": "LIST",
            "dict": "MAP",
        }
        n_labels, e_labels = self._get_labels()
        triple_schema = self._get_triples(e_labels)
        node_properties = self._get_node_properties(n_labels, types)
        edge_properties = self._get_edge_properties(e_labels, types)

        self.schema = f"""
        Node properties are the following:
        {node_properties}
        Relationship properties are the following:
        {edge_properties}
        The relationships are the following:
        {triple_schema}
        """
|
Loading…
Reference in New Issue
Block a user