pull/20853/head
Eugene Yurtsev 1 month ago
parent 9c91e333cc
commit 5b7ad94a95

@ -0,0 +1 @@
"""Question answering over a knowledge graph."""

@ -0,0 +1,241 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs.arangodb_graph import ArangoGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
AQL_FIX_PROMPT,
AQL_GENERATION_PROMPT,
AQL_QA_PROMPT,
)
from langchain.chains.llm import LLMChain
class ArangoGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating AQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: ArangoGraph = Field(exclude=True)
aql_generation_chain: LLMChain
aql_fix_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
# Specifies the maximum number of AQL Query Results to return
top_k: int = 10
# Specifies the set of AQL Query Examples that promote few-shot-learning
aql_examples: str = ""
# Specify whether to return the AQL Query in the output dictionary
return_aql_query: bool = False
# Specify whether to return the AQL JSON Result in the output dictionary
return_aql_result: bool = False
# Specify the maximum amount of AQL Generation attempts that should be made
max_aql_generation_attempts: int = 3
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
return [self.output_key]
@property
def _chain_type(self) -> str:
return "graph_aql_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
**kwargs: Any,
) -> ArangoGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)
return cls(
qa_chain=qa_chain,
aql_generation_chain=aql_generation_chain,
aql_fix_chain=aql_fix_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""
Generate an AQL statement from user input, use it retrieve a response
from an ArangoDB Database instance, and respond to the user input
in natural language.
Users can modify the following ArangoGraphQAChain Class Variables:
:var top_k: The maximum number of AQL Query Results to return
:type top_k: int
:var aql_examples: A set of AQL Query Examples that are passed to
the AQL Generation Prompt Template to promote few-shot-learning.
Defaults to an empty string.
:type aql_examples: str
:var return_aql_query: Whether to return the AQL Query in the
output dictionary. Defaults to False.
:type return_aql_query: bool
:var return_aql_result: Whether to return the AQL Query in the
output dictionary. Defaults to False
:type return_aql_result: bool
:var max_aql_generation_attempts: The maximum amount of AQL
Generation attempts to be made prior to raising the last
AQL Query Execution Error. Defaults to 3.
:type max_aql_generation_attempts: int
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
user_input = inputs[self.input_key]
#########################
# Generate AQL Query #
aql_generation_output = self.aql_generation_chain.run(
{
"adb_schema": self.graph.schema,
"aql_examples": self.aql_examples,
"user_input": user_input,
},
callbacks=callbacks,
)
#########################
aql_query = ""
aql_error = ""
aql_result = None
aql_generation_attempt = 1
while (
aql_result is None
and aql_generation_attempt < self.max_aql_generation_attempts + 1
):
#####################
# Extract AQL Query #
pattern = r"```(?i:aql)?(.*?)```"
matches = re.findall(pattern, aql_generation_output, re.DOTALL)
if not matches:
_run_manager.on_text(
"Invalid Response: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
aql_generation_output, color="red", end="\n", verbose=self.verbose
)
raise ValueError(f"Response is Invalid: {aql_generation_output}")
aql_query = matches[0]
#####################
_run_manager.on_text(
f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
)
_run_manager.on_text(
aql_query, color="green", end="\n", verbose=self.verbose
)
#####################
# Execute AQL Query #
from arango import AQLQueryExecuteError
try:
aql_result = self.graph.query(aql_query, self.top_k)
except AQLQueryExecuteError as e:
aql_error = e.error_message
_run_manager.on_text(
"AQL Query Execution Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
aql_error, color="yellow", end="\n\n", verbose=self.verbose
)
########################
# Retry AQL Generation #
aql_generation_output = self.aql_fix_chain.run(
{
"adb_schema": self.graph.schema,
"aql_query": aql_query,
"aql_error": aql_error,
},
callbacks=callbacks,
)
########################
#####################
aql_generation_attempt += 1
if aql_result is None:
m = f"""
Maximum amount of AQL Query Generation attempts reached.
Unable to execute the AQL Query due to the following error:
{aql_error}
"""
raise ValueError(m)
_run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(aql_result), color="green", end="\n", verbose=self.verbose
)
########################
# Interpret AQL Result #
result = self.qa_chain(
{
"adb_schema": self.graph.schema,
"user_input": user_input,
"aql_query": aql_query,
"aql_result": aql_result,
},
callbacks=callbacks,
)
########################
# Return results #
result = {self.output_key: result[self.qa_chain.output_key]}
if self.return_aql_query:
result["aql_query"] = aql_query
if self.return_aql_result:
result["aql_result"] = aql_result
return result

@ -0,0 +1,100 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT
from langchain.chains.llm import LLMChain
class GraphQAChain(Chain):
"""Chain for question-answering against a graph.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: NetworkxEntityGraph = Field(exclude=True)
entity_extraction_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
**kwargs: Any,
) -> GraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
entity_chain = LLMChain(llm=llm, prompt=entity_prompt)
return cls(
qa_chain=qa_chain,
entity_extraction_chain=entity_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Extract entities, look up info and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
entity_string = self.entity_extraction_chain.run(question)
_run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
_run_manager.on_text(
entity_string, color="green", end="\n", verbose=self.verbose
)
entities = get_entities(entity_string)
context = ""
all_triplets = []
for entity in entities:
all_triplets.extend(self.graph.get_entity_knowledge(entity))
context = "\n".join(all_triplets)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=_run_manager.get_child(),
)
return {self.output_key: result[self.qa_chain.output_key]}

@ -0,0 +1,292 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs.graph_store import GraphStore
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
def construct_schema(
structured_schema: Dict[str, Any],
include_types: List[str],
exclude_types: List[str],
) -> str:
"""Filter the schema based on included or excluded types"""
def filter_func(x: str) -> bool:
return x in include_types if include_types else x not in exclude_types
filtered_schema: Dict[str, Any] = {
"node_props": {
k: v
for k, v in structured_schema.get("node_props", {}).items()
if filter_func(k)
},
"rel_props": {
k: v
for k, v in structured_schema.get("rel_props", {}).items()
if filter_func(k)
},
"relationships": [
r
for r in structured_schema.get("relationships", [])
if all(filter_func(r[t]) for t in ["start", "end", "type"])
],
}
# Format node properties
formatted_node_props = []
for label, properties in filtered_schema["node_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in properties]
)
formatted_node_props.append(f"{label} {{{props_str}}}")
# Format relationship properties
formatted_rel_props = []
for rel_type, properties in filtered_schema["rel_props"].items():
props_str = ", ".join(
[f"{prop['property']}: {prop['type']}" for prop in properties]
)
formatted_rel_props.append(f"{rel_type} {{{props_str}}}")
# Format relationships
formatted_rels = [
f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
for el in filtered_schema["relationships"]
]
return "\n".join(
[
"Node properties are the following:",
",".join(formatted_node_props),
"Relationship properties are the following:",
",".join(formatted_rel_props),
"The relationships are the following:",
",".join(formatted_rels),
]
)
class GraphCypherQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: GraphStore = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
graph_schema: str
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
cypher_query_corrector: Optional[CypherQueryCorrector] = None
"""Optional cypher validation tool"""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,
llm: Optional[BaseLanguageModel] = None,
*,
qa_prompt: Optional[BasePromptTemplate] = None,
cypher_prompt: Optional[BasePromptTemplate] = None,
cypher_llm: Optional[BaseLanguageModel] = None,
qa_llm: Optional[BaseLanguageModel] = None,
exclude_types: List[str] = [],
include_types: List[str] = [],
validate_cypher: bool = False,
qa_llm_kwargs: Optional[Dict[str, Any]] = None,
cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> GraphCypherQAChain:
"""Initialize from LLM."""
if not cypher_llm and not llm:
raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
if not qa_llm and not llm:
raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
if cypher_llm and qa_llm and llm:
raise ValueError(
"You can specify up to two of 'cypher_llm', 'qa_llm'"
", and 'llm', but not all three simultaneously."
)
if cypher_prompt and cypher_llm_kwargs:
raise ValueError(
"Specifying cypher_prompt and cypher_llm_kwargs together is"
" not allowed. Please pass prompt via cypher_llm_kwargs."
)
if qa_prompt and qa_llm_kwargs:
raise ValueError(
"Specifying qa_prompt and qa_llm_kwargs together is"
" not allowed. Please pass prompt via qa_llm_kwargs."
)
use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
use_cypher_llm_kwargs = (
cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
)
if "prompt" not in use_qa_llm_kwargs:
use_qa_llm_kwargs["prompt"] = (
qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
)
if "prompt" not in use_cypher_llm_kwargs:
use_cypher_llm_kwargs["prompt"] = (
cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
)
qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs) # type: ignore[arg-type]
cypher_generation_chain = LLMChain(
llm=cypher_llm or llm, # type: ignore[arg-type]
**use_cypher_llm_kwargs, # type: ignore[arg-type]
)
if exclude_types and include_types:
raise ValueError(
"Either `exclude_types` or `include_types` "
"can be provided, but not both"
)
graph_schema = construct_schema(
kwargs["graph"].get_structured_schema, include_types, exclude_types
)
cypher_query_corrector = None
if validate_cypher:
corrector_schema = [
Schema(el["start"], el["type"], el["end"])
for el in kwargs["graph"].structured_schema.get("relationships")
]
cypher_query_corrector = CypherQueryCorrector(corrector_schema)
return cls(
graph_schema=graph_schema,
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
cypher_query_corrector=cypher_query_corrector,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
# Correct Cypher query if enabled
if self.cypher_query_corrector:
generated_cypher = self.cypher_query_corrector(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
# Retrieve and limit the number of results
# Generated Cypher be null if query corrector identifies invalid schema
if generated_cypher:
context = self.graph.query(generated_cypher)[: self.top_k]
else:
context = []
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -0,0 +1,260 @@
import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple
Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])
class CypherQueryCorrector:
"""
Used to correct relationship direction in generated Cypher statements.
This code is copied from the winner's submission to the Cypher competition:
https://github.com/sakusaku-rich/cypher-direction-competition
"""
property_pattern = re.compile(r"\{.+?\}")
node_pattern = re.compile(r"\(.+?\)")
path_pattern = re.compile(
r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
)
node_relation_node_pattern = re.compile(
r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
)
relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")
def __init__(self, schemas: List[Schema]):
"""
Args:
schemas: list of schemas
"""
self.schemas = schemas
def clean_node(self, node: str) -> str:
"""
Args:
node: node in string format
"""
node = re.sub(self.property_pattern, "", node)
node = node.replace("(", "")
node = node.replace(")", "")
node = node.strip()
return node
def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
"""
Args:
query: cypher query
"""
nodes = re.findall(self.node_pattern, query)
nodes = [self.clean_node(node) for node in nodes]
res: Dict[str, Any] = {}
for node in nodes:
parts = node.split(":")
if parts == "":
continue
variable = parts[0]
if variable not in res:
res[variable] = []
res[variable] += parts[1:]
return res
def extract_paths(self, query: str) -> "List[str]":
"""
Args:
query: cypher query
"""
paths = []
idx = 0
while matched := self.path_pattern.findall(query[idx:]):
matched = matched[0]
matched = [
m for i, m in enumerate(matched) if i not in [1, len(matched) - 1]
]
path = "".join(matched)
idx = query.find(path) + len(path) - len(matched[-1])
paths.append(path)
return paths
def judge_direction(self, relation: str) -> str:
"""
Args:
relation: relation in string format
"""
direction = "BIDIRECTIONAL"
if relation[0] == "<":
direction = "INCOMING"
if relation[-1] == ">":
direction = "OUTGOING"
return direction
def extract_node_variable(self, part: str) -> Optional[str]:
"""
Args:
part: node in string format
"""
part = part.lstrip("(").rstrip(")")
idx = part.find(":")
if idx != -1:
part = part[:idx]
return None if part == "" else part
def detect_labels(
self, str_node: str, node_variable_dict: Dict[str, Any]
) -> List[str]:
"""
Args:
str_node: node in string format
node_variable_dict: dictionary of node variables
"""
splitted_node = str_node.split(":")
variable = splitted_node[0]
labels = []
if variable in node_variable_dict:
labels = node_variable_dict[variable]
elif variable == "" and len(splitted_node) > 1:
labels = splitted_node[1:]
return labels
def verify_schema(
self,
from_node_labels: List[str],
relation_types: List[str],
to_node_labels: List[str],
) -> bool:
"""
Args:
from_node_labels: labels of the from node
relation_type: type of the relation
to_node_labels: labels of the to node
"""
valid_schemas = self.schemas
if from_node_labels != []:
from_node_labels = [label.strip("`") for label in from_node_labels]
valid_schemas = [
schema for schema in valid_schemas if schema[0] in from_node_labels
]
if to_node_labels != []:
to_node_labels = [label.strip("`") for label in to_node_labels]
valid_schemas = [
schema for schema in valid_schemas if schema[2] in to_node_labels
]
if relation_types != []:
relation_types = [type.strip("`") for type in relation_types]
valid_schemas = [
schema for schema in valid_schemas if schema[1] in relation_types
]
return valid_schemas != []
def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
"""
Args:
str_relation: relation in string format
"""
relation_direction = self.judge_direction(str_relation)
relation_type = self.relation_type_pattern.search(str_relation)
if relation_type is None or relation_type.group("relation_type") is None:
return relation_direction, []
relation_types = [
t.strip().strip("!")
for t in relation_type.group("relation_type").split("|")
]
return relation_direction, relation_types
def correct_query(self, query: str) -> str:
"""
Args:
query: cypher query
"""
node_variable_dict = self.detect_node_variables(query)
paths = self.extract_paths(query)
for path in paths:
original_path = path
start_idx = 0
while start_idx < len(path):
match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
if match_res is None:
break
start_idx += match_res.start()
match_dict = match_res.groupdict()
left_node_labels = self.detect_labels(
match_dict["left_node"], node_variable_dict
)
right_node_labels = self.detect_labels(
match_dict["right_node"], node_variable_dict
)
end_idx = (
start_idx
+ 4
+ len(match_dict["left_node"])
+ len(match_dict["relation"])
+ len(match_dict["right_node"])
)
original_partial_path = original_path[start_idx : end_idx + 1]
relation_direction, relation_types = self.detect_relation_types(
match_dict["relation"]
)
if relation_types != [] and "".join(relation_types).find("*") != -1:
start_idx += (
len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
)
continue
if relation_direction == "OUTGOING":
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
if not is_legal:
is_legal = self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if is_legal:
corrected_relation = "<" + match_dict["relation"][:-1]
corrected_partial_path = original_partial_path.replace(
match_dict["relation"], corrected_relation
)
query = query.replace(
original_partial_path, corrected_partial_path
)
else:
return ""
elif relation_direction == "INCOMING":
is_legal = self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if not is_legal:
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
if is_legal:
corrected_relation = match_dict["relation"][1:] + ">"
corrected_partial_path = original_partial_path.replace(
match_dict["relation"], corrected_relation
)
query = query.replace(
original_partial_path, corrected_partial_path
)
else:
return ""
else:
is_legal = self.verify_schema(
left_node_labels, relation_types, right_node_labels
)
is_legal |= self.verify_schema(
right_node_labels, relation_types, left_node_labels
)
if not is_legal:
return ""
start_idx += (
len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
)
return query
def __call__(self, query: str) -> str:
"""Correct the query to make it valid. If
Args:
query: cypher query
"""
return self.correct_query(query)

@ -0,0 +1,154 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs import FalkorDBGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_cypher(text: str) -> str:
"""
Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class FalkorDBQAChain(Chain):
"""Chain for question-answering against a graph by generating Cypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: FalkorDBGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
"""Number of results to return from the query"""
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@property
def _chain_type(self) -> str:
return "graph_cypher_chain"
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
**kwargs: Any,
) -> FalkorDBQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
# Retrieve and limit the number of results
context = self.graph.query(generated_cypher)[: self.top_k]
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -0,0 +1,221 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs import GremlinGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GRAPHDB_SPARQL_FIX_TEMPLATE,
GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def extract_gremlin(text: str) -> str:
"""Extract Gremlin code from a text.
Args:
text: Text to extract Gremlin code from.
Returns:
Gremlin code extracted from the text.
"""
text = text.replace("`", "")
if text.startswith("gremlin"):
text = text[len("gremlin") :]
return text.replace("\n", "")
class GremlinQAChain(Chain):
"""Chain for question-answering against a graph by generating gremlin statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: GremlinGraph = Field(exclude=True)
gremlin_generation_chain: LLMChain
qa_chain: LLMChain
gremlin_fix_chain: LLMChain
max_fix_retries: int = 3
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 100
return_direct: bool = False
return_intermediate_steps: bool = False
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
input_variables=["error_message", "generated_sparql", "schema"],
template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
"in Turtle format", ""
),
),
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
**kwargs: Any,
) -> GremlinQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
gremlinl_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
return cls(
qa_chain=qa_chain,
gremlin_generation_chain=gremlin_generation_chain,
gremlin_fix_chain=gremlinl_fix_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate gremlin statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
chain_response = self.gremlin_generation_chain.invoke(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
generated_gremlin = extract_gremlin(
chain_response[self.gremlin_generation_chain.output_key]
)
_run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_gremlin, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_gremlin})
if generated_gremlin:
context = self.execute_with_retry(
_run_manager, callbacks, generated_gremlin
)[: self.top_k]
else:
context = []
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain.invoke(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result
def execute_query(self, query: str) -> List[Any]:
try:
return self.graph.query(query)
except Exception as e:
if hasattr(e, "status_message"):
raise ValueError(e.status_message)
else:
raise ValueError(str(e))
def execute_with_retry(
self,
_run_manager: CallbackManagerForChainRun,
callbacks: CallbackManager,
generated_gremlin: str,
) -> List[Any]:
try:
return self.execute_query(generated_gremlin)
except Exception as e:
retries = 0
error_message = str(e)
self.log_invalid_query(_run_manager, generated_gremlin, error_message)
while retries < self.max_fix_retries:
try:
fix_chain_result = self.gremlin_fix_chain.invoke(
{
"error_message": error_message,
# we are borrowing template from sparql
"generated_sparql": generated_gremlin,
"schema": self.schema,
},
callbacks=callbacks,
)
fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
return self.execute_query(fixed_gremlin)
except Exception as e:
retries += 1
parse_exception = str(e)
self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)
raise ValueError("The generated Gremlin query is invalid.")
def log_invalid_query(
self,
_run_manager: CallbackManagerForChainRun,
generated_query: str,
error_message: str,
) -> None:
_run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="red", end="\n", verbose=self.verbose
)
_run_manager.on_text(
"Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
error_message, color="red", end="\n\n", verbose=self.verbose
)

@ -0,0 +1,106 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs.hugegraph import HugeGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain
class HugeGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating gremlin statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: HugeGraph = Field(exclude=True)
gremlin_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
**kwargs: Any,
) -> HugeGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
return cls(
qa_chain=qa_chain,
gremlin_generation_chain=gremlin_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate gremlin statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_gremlin = self.gremlin_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_gremlin, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_gremlin)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

@ -0,0 +1,131 @@
"""Question answering over a graph."""
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs.kuzu_graph import KuzuGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
def remove_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix) :]
return text
def extract_cypher(text: str) -> str:
"""Extract Cypher code from a text.
Args:
text: Text to extract Cypher code from.
Returns:
Cypher code extracted from the text.
"""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
class KuzuQAChain(Chain):
"""Question-answering against a graph by generating Cypher statements for Kùzu.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: KuzuGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
**kwargs: Any,
) -> KuzuQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract Cypher code if it is wrapped in triple backticks
# with the language marker "cypher"
generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher")
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_cypher)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

@ -0,0 +1,103 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs.nebula_graph import NebulaGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
class NebulaGraphQAChain(Chain):
"""Chain for question-answering against a graph by generating nGQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: NebulaGraph = Field(exclude=True)
ngql_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
**kwargs: Any,
) -> NebulaGraphQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)
return cls(
qa_chain=qa_chain,
ngql_generation_chain=ngql_generation_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""Generate nGQL statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
generated_ngql = self.ngql_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_ngql, color="green", end="\n", verbose=self.verbose
)
context = self.graph.query(generated_ngql)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
return {self.output_key: result[self.qa_chain.output_key]}

@ -0,0 +1,217 @@
from __future__ import annotations
import re
from typing import Any, Dict, List, Optional
from langchain_community.graphs import BaseNeptuneGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)
from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def trim_query(query: str) -> str:
"""Trim the query to only include Cypher keywords."""
keywords = (
"CALL",
"CREATE",
"DELETE",
"DETACH",
"LIMIT",
"MATCH",
"MERGE",
"OPTIONAL",
"ORDER",
"REMOVE",
"RETURN",
"SET",
"SKIP",
"UNWIND",
"WITH",
"WHERE",
"//",
)
lines = query.split("\n")
new_query = ""
for line in lines:
if line.strip().upper().startswith(keywords):
new_query += line + "\n"
return new_query
def extract_cypher(text: str) -> str:
"""Extract Cypher code from text using Regex."""
# The pattern to find Cypher code enclosed in triple backticks
pattern = r"```(.*?)```"
# Find all matches in the input text
matches = re.findall(pattern, text, re.DOTALL)
return matches[0] if matches else text
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
return True
# Bedrock anthropic
if hasattr(llm, "model_id") and "anthropic" in llm.model_id: # type: ignore
return True
return False
PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
class NeptuneOpenCypherQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating openCypher statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
Example:
.. code-block:: python
chain = NeptuneOpenCypherQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.run(query)
"""
graph: BaseNeptuneGraph = Field(exclude=True)
cypher_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: Optional[BasePromptTemplate] = None,
extra_instructions: Optional[str] = None,
**kwargs: Any,
) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
_cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
extra_instructions=extra_instructions,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
"""Generate Cypher statement, use it to look up in db and answer question."""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
question = inputs[self.input_key]
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{
"question": question,
"schema": self.graph.get_schema,
"extra_instructions": self.extra_instructions or "",
},
callbacks=callbacks,
)
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
generated_cypher = trim_query(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_cypher, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_cypher})
context = self.graph.query(generated_cypher)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"question": question, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -0,0 +1,196 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs import NeptuneRdfGraph
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Examples:
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.
The question is:
{prompt}"""
SPARQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)
def extract_sparql(query: str) -> str:
query = query.strip()
querytoks = query.split("```")
if len(querytoks) == 3:
query = querytoks[1]
if query.startswith("sparql"):
query = query[6:]
elif query.startswith("<sparql>") and query.endswith("</sparql>"):
query = query[8:-9]
return query
class NeptuneSparqlQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating SPARQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
Example:
.. code-block:: python
chain = NeptuneSparqlQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.invoke(query)
"""
graph: NeptuneRdfGraph = Field(exclude=True)
sparql_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
examples: Optional[str] = None,
**kwargs: Any,
) -> NeptuneSparqlQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
template_to_use = SPARQL_GENERATION_TEMPLATE
if examples:
template_to_use = template_to_use.replace(
"Examples:", "Examples: " + examples
)
sparql_prompt = PromptTemplate(
input_variables=["schema", "prompt"], template=template_to_use
)
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
return cls( # type: ignore[call-arg]
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
examples=examples,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate SPARQL query, use it to retrieve a response from the gdb and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
intermediate_steps: List = []
generated_sparql = self.sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract SPARQL
generated_sparql = extract_sparql(generated_sparql)
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_sparql, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_sparql})
context = self.graph.query(generated_sparql)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"prompt": prompt, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -0,0 +1,190 @@
"""Question answering over a graph."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Dict, List, Optional
if TYPE_CHECKING:
import rdflib
from langchain_community.graphs import OntotextGraphDBGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
GRAPHDB_QA_PROMPT,
GRAPHDB_SPARQL_FIX_PROMPT,
GRAPHDB_SPARQL_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain
class OntotextGraphDBQAChain(Chain):
"""Question-answering against Ontotext GraphDB
https://graphdb.ontotext.com/ by generating SPARQL queries.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: OntotextGraphDBGraph = Field(exclude=True)
sparql_generation_chain: LLMChain
sparql_fix_chain: LLMChain
max_fix_retries: int
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT,
sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT,
max_fix_retries: int = 5,
qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT,
**kwargs: Any,
) -> OntotextGraphDBQAChain:
"""Initialize from LLM."""
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt)
sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt)
max_fix_retries = max_fix_retries
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
return cls(
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
sparql_fix_chain=sparql_fix_chain,
max_fix_retries=max_fix_retries,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate a SPARQL query, use it to retrieve a response from GraphDB and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
ontology_schema = self.graph.get_schema
sparql_generation_chain_result = self.sparql_generation_chain.invoke(
{"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks
)
generated_sparql = sparql_generation_chain_result[
self.sparql_generation_chain.output_key
]
generated_sparql = self._get_prepared_sparql_query(
_run_manager, callbacks, generated_sparql, ontology_schema
)
query_results = self._execute_query(generated_sparql)
qa_chain_result = self.qa_chain.invoke(
{"prompt": prompt, "context": query_results}, callbacks=callbacks
)
result = qa_chain_result[self.qa_chain.output_key]
return {self.output_key: result}
def _get_prepared_sparql_query(
self,
_run_manager: CallbackManagerForChainRun,
callbacks: CallbackManager,
generated_sparql: str,
ontology_schema: str,
) -> str:
try:
return self._prepare_sparql_query(_run_manager, generated_sparql)
except Exception as e:
retries = 0
error_message = str(e)
self._log_invalid_sparql_query(
_run_manager, generated_sparql, error_message
)
while retries < self.max_fix_retries:
try:
sparql_fix_chain_result = self.sparql_fix_chain.invoke(
{
"error_message": error_message,
"generated_sparql": generated_sparql,
"schema": ontology_schema,
},
callbacks=callbacks,
)
generated_sparql = sparql_fix_chain_result[
self.sparql_fix_chain.output_key
]
return self._prepare_sparql_query(_run_manager, generated_sparql)
except Exception as e:
retries += 1
parse_exception = str(e)
self._log_invalid_sparql_query(
_run_manager, generated_sparql, parse_exception
)
raise ValueError("The generated SPARQL query is invalid.")
def _prepare_sparql_query(
self, _run_manager: CallbackManagerForChainRun, generated_sparql: str
) -> str:
from rdflib.plugins.sparql import prepareQuery
prepareQuery(generated_sparql)
self._log_prepared_sparql_query(_run_manager, generated_sparql)
return generated_sparql
def _log_prepared_sparql_query(
self, _run_manager: CallbackManagerForChainRun, generated_query: str
) -> None:
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="green", end="\n", verbose=self.verbose
)
def _log_invalid_sparql_query(
self,
_run_manager: CallbackManagerForChainRun,
generated_query: str,
error_message: str,
) -> None:
_run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_query, color="red", end="\n", verbose=self.verbose
)
_run_manager.on_text(
"SPARQL Query Parse Error: ", end="\n", verbose=self.verbose
)
_run_manager.on_text(
error_message, color="red", end="\n\n", verbose=self.verbose
)
def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
try:
return self.graph.query(query)
except Exception:
raise ValueError("Failed to execute the generated SPARQL query.")

@ -0,0 +1,415 @@
# flake8: noqa
from langchain_core.prompts.prompt import PromptTemplate
_DEFAULT_ENTITY_EXTRACTION_TEMPLATE = """Extract all entities from the following text. As a guideline, a proper noun is generally capitalized. You should definitely extract all names and places.
Return the output as a single comma-separated list, or NONE if there is nothing of note to return.
EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff.
Output: Langchain
END OF EXAMPLE
EXAMPLE
i'm trying to improve Langchain's interfaces, the UX, its integrations with various products the user might want ... a lot of stuff. I'm working with Sam.
Output: Langchain, Sam
END OF EXAMPLE
Begin!
{input}
Output:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate(
input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE
)
_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
GRAPH_QA_PROMPT = PromptTemplate(
template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"]
)
CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.
Instructions:
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
{schema}
Note: Do not include any explanations or apologies in your responses.
Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
Do not include any text except the generated Cypher statement.
The question is:
{question}"""
CYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE
)
NEBULAGRAPH_EXTRA_INSTRUCTIONS = """
Instructions:
First, generate cypher then convert it to NebulaGraph Cypher dialect(rather than standard):
1. it requires explicit label specification only when referring to node properties: v.`Foo`.name
2. note explicit label specification is not needed for edge properties, so it's e.name instead of e.`Bar`.name
3. it uses double equals sign for comparison: `==` rather than `=`
For instance:
```diff
< MATCH (p:person)-[e:directed]->(m:movie) WHERE m.name = 'The Godfather II'
< RETURN p.name, e.year, m.name;
---
> MATCH (p:`person`)-[e:directed]->(m:`movie`) WHERE m.`movie`.`name` == 'The Godfather II'
> RETURN p.`person`.`name`, e.year, m.`movie`.`name`;
```\n"""
NGQL_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Generate Cypher", "Generate NebulaGraph Cypher"
).replace("Instructions:", NEBULAGRAPH_EXTRA_INSTRUCTIONS)
NGQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=NGQL_GENERATION_TEMPLATE
)
KUZU_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the Kùzu dialect of Cypher with the following rules in mind:
1. Do not use a `WHERE EXISTS` clause to check the existence of a property.
2. Do not omit the relationship pattern. Always use `()-[]->()` instead of `()->()`.
3. Do not include any notes or comments even if the statement does not produce the expected result.
```\n"""
KUZU_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Generate Cypher", "Generate Kùzu Cypher"
).replace("Instructions:", KUZU_EXTRA_INSTRUCTIONS)
KUZU_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=KUZU_GENERATION_TEMPLATE
)
GREMLIN_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace("Cypher", "Gremlin")
GREMLIN_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], template=GREMLIN_GENERATION_TEMPLATE
)
CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers.
The information part contains the provided information that you must use to construct an answer.
The provided information is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make the answer sound as a response to the question. Do not mention that you based the result on the given information.
Here is an example:
Question: Which managers own Neo4j stocks?
Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC]
Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks.
Follow this example when generating answers.
If the provided information is empty, say that you don't know the answer.
Information:
{context}
Question: {question}
Helpful Answer:"""
CYPHER_QA_PROMPT = PromptTemplate(
input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE
)
SPARQL_INTENT_TEMPLATE = """Task: Identify the intent of a prompt and return the appropriate SPARQL query type.
You are an assistant that distinguishes different types of prompts and returns the corresponding SPARQL query types.
Consider only the following query types:
* SELECT: this query type corresponds to questions
* UPDATE: this query type corresponds to all requests for deleting, inserting, or changing triples
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to identify a SPARQL query type.
Do not include any unnecessary whitespaces or any text except the query type, i.e., either return 'SELECT' or 'UPDATE'.
The prompt is:
{prompt}
Helpful Answer:"""
SPARQL_INTENT_PROMPT = PromptTemplate(
input_variables=["prompt"], template=SPARQL_INTENT_TEMPLATE
)
SPARQL_GENERATION_SELECT_TEMPLATE = """Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.
The question is:
{prompt}"""
SPARQL_GENERATION_SELECT_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_SELECT_TEMPLATE
)
SPARQL_GENERATION_UPDATE_TEMPLATE = """Task: Generate a SPARQL UPDATE statement for updating a graph database.
For instance, to add 'jane.doe@foo.bar' as a new email address for Jane Doe, the following query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
INSERT {{
?person foaf:mbox <mailto:jane.doe@foo.bar> .
}}
WHERE {{
?person foaf:name "Jane Doe" .
}}
```
Instructions:
Make the query as short as possible and avoid adding unnecessary triples.
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than for you to construct a SPARQL query.
Return only the generated SPARQL query, nothing else.
The information to be inserted is:
{prompt}"""
SPARQL_GENERATION_UPDATE_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_UPDATE_TEMPLATE
)
SPARQL_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Information:
{context}
Question: {prompt}
Helpful Answer:"""
SPARQL_QA_PROMPT = PromptTemplate(
input_variables=["context", "prompt"], template=SPARQL_QA_TEMPLATE
)
GRAPHDB_SPARQL_GENERATION_TEMPLATE = """
Write a SPARQL SELECT query for querying a graph database.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
Use only the classes and properties provided in the schema to construct the SPARQL query.
Do not use any classes or properties that are not explicitly provided in the SPARQL query.
Include all necessary prefixes.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The question delimited by triple backticks is:
```
{prompt}
```
"""
GRAPHDB_SPARQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"],
template=GRAPHDB_SPARQL_GENERATION_TEMPLATE,
)
GRAPHDB_SPARQL_FIX_TEMPLATE = """
This following SPARQL query delimited by triple backticks
```
{generated_sparql}
```
is not valid.
The error delimited by triple backticks is
```
{error_message}
```
Give me a correct version of the SPARQL query.
Do not change the logic of the query.
Do not include any explanations or apologies in your responses.
Do not wrap the query in backticks.
Do not include any text except the SPARQL query generated.
The ontology schema delimited by triple backticks in Turtle format is:
```
{schema}
```
"""
GRAPHDB_SPARQL_FIX_PROMPT = PromptTemplate(
input_variables=["error_message", "generated_sparql", "schema"],
template=GRAPHDB_SPARQL_FIX_TEMPLATE,
)
GRAPHDB_QA_TEMPLATE = """Task: Generate a natural language response from the results of a SPARQL query.
You are an assistant that creates well-written and human understandable answers.
The information part contains the information provided, which you can use to construct an answer.
The information provided is authoritative, you must never doubt it or try to use your internal knowledge to correct it.
Make your response sound like the information is coming from an AI assistant, but don't add any information.
Don't use internal knowledge to answer the question, just say you don't know if no information is available.
Information:
{context}
Question: {prompt}
Helpful Answer:"""
GRAPHDB_QA_PROMPT = PromptTemplate(
input_variables=["context", "prompt"], template=GRAPHDB_QA_TEMPLATE
)
AQL_GENERATION_TEMPLATE = """Task: Generate an ArangoDB Query Language (AQL) query from a User Input.
You are an ArangoDB Query Language (AQL) expert responsible for translating a `User Input` into an ArangoDB Query Language (AQL) query.
You are given an `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
You may also be given a set of `AQL Query Examples` to help you create the `AQL Query`. If provided, the `AQL Query Examples` should be used as a reference, similar to how `ArangoDB Schema` should be used.
Things you should do:
- Think step by step.
- Rely on `ArangoDB Schema` and `AQL Query Examples` (if provided) to generate the query.
- Begin the `AQL Query` by the `WITH` AQL keyword to specify all of the ArangoDB Collections required.
- Return the `AQL Query` wrapped in 3 backticks (```).
- Use only the provided relationship types and properties in the `ArangoDB Schema` and any `AQL Query Examples` queries.
- Only answer to requests related to generating an AQL Query.
- If a request is unrelated to generating AQL Query, say that you cannot help the user.
Things you should not do:
- Do not use any properties/relationships that can't be inferred from the `ArangoDB Schema` or the `AQL Query Examples`.
- Do not include any text except the generated AQL Query.
- Do not provide explanations or apologies in your responses.
- Do not generate an AQL Query that removes or deletes any data.
Under no circumstance should you generate an AQL Query that deletes any data whatsoever.
ArangoDB Schema:
{adb_schema}
AQL Query Examples (Optional):
{aql_examples}
User Input:
{user_input}
AQL Query:
"""
AQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["adb_schema", "aql_examples", "user_input"],
template=AQL_GENERATION_TEMPLATE,
)
AQL_FIX_TEMPLATE = """Task: Address the ArangoDB Query Language (AQL) error message of an ArangoDB Query Language query.
You are an ArangoDB Query Language (AQL) expert responsible for correcting the provided `AQL Query` based on the provided `AQL Error`.
The `AQL Error` explains why the `AQL Query` could not be executed in the database.
The `AQL Error` may also contain the position of the error relative to the total number of lines of the `AQL Query`.
For example, 'error X at position 2:5' denotes that the error X occurs on line 2, column 5 of the `AQL Query`.
You are also given the `ArangoDB Schema`. It is a JSON Object containing:
1. `Graph Schema`: Lists all Graphs within the ArangoDB Database Instance, along with their Edge Relationships.
2. `Collection Schema`: Lists all Collections within the ArangoDB Database Instance, along with their document/edge properties and a document/edge example.
You will output the `Corrected AQL Query` wrapped in 3 backticks (```). Do not include any text except the Corrected AQL Query.
Remember to think step by step.
ArangoDB Schema:
{adb_schema}
AQL Query:
{aql_query}
AQL Error:
{aql_error}
Corrected AQL Query:
"""
AQL_FIX_PROMPT = PromptTemplate(
input_variables=[
"adb_schema",
"aql_query",
"aql_error",
],
template=AQL_FIX_TEMPLATE,
)
AQL_QA_TEMPLATE = """Task: Generate a natural language `Summary` from the results of an ArangoDB Query Language query.
You are an ArangoDB Query Language (AQL) expert responsible for creating a well-written `Summary` from the `User Input` and associated `AQL Result`.
A user has executed an ArangoDB Query Language query, which has returned the AQL Result in JSON format.
You are responsible for creating an `Summary` based on the AQL Result.
You are given the following information:
- `ArangoDB Schema`: contains a schema representation of the user's ArangoDB Database.
- `User Input`: the original question/request of the user, which has been translated into an AQL Query.
- `AQL Query`: the AQL equivalent of the `User Input`, translated by another AI Model. Should you deem it to be incorrect, suggest a different AQL Query.
- `AQL Result`: the JSON output returned by executing the `AQL Query` within the ArangoDB Database.
Remember to think step by step.
Your `Summary` should sound like it is a response to the `User Input`.
Your `Summary` should not include any mention of the `AQL Query` or the `AQL Result`.
ArangoDB Schema:
{adb_schema}
User Input:
{user_input}
AQL Query:
{aql_query}
AQL Result:
{aql_result}
"""
AQL_QA_PROMPT = PromptTemplate(
input_variables=["adb_schema", "user_input", "aql_query", "aql_result"],
template=AQL_QA_TEMPLATE,
)
NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS = """
Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
"Instructions:", NEPTUNE_OPENCYPHER_EXTRA_INSTRUCTIONS
)
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)

@ -0,0 +1,152 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs.rdf_graph import RdfGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
SPARQL_GENERATION_SELECT_PROMPT,
SPARQL_GENERATION_UPDATE_PROMPT,
SPARQL_INTENT_PROMPT,
SPARQL_QA_PROMPT,
)
from langchain.chains.llm import LLMChain
class GraphSparqlQAChain(Chain):
"""Question-answering against an RDF or OWL graph by generating SPARQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
graph: RdfGraph = Field(exclude=True)
sparql_generation_select_chain: LLMChain
sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain
qa_chain: LLMChain
return_sparql_query: bool = False
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
sparql_query_key: str = "sparql_query" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT,
sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT,
sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT,
**kwargs: Any,
) -> GraphSparqlQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt)
sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt)
sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt)
return cls(
qa_chain=qa_chain,
sparql_generation_select_chain=sparql_generation_select_chain,
sparql_generation_update_chain=sparql_generation_update_chain,
sparql_intent_chain=sparql_intent_chain,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate SPARQL query, use it to retrieve a response from the gdb and answer
the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
_intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks)
intent = _intent.strip()
if "SELECT" in intent and "UPDATE" not in intent:
sparql_generation_chain = self.sparql_generation_select_chain
intent = "SELECT"
elif "UPDATE" in intent and "SELECT" not in intent:
sparql_generation_chain = self.sparql_generation_update_chain
intent = "UPDATE"
else:
raise ValueError(
"I am sorry, but this prompt seems to fit none of the currently "
"supported SPARQL query types, i.e., SELECT and UPDATE."
)
_run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose)
_run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose)
generated_sparql = sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
)
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_sparql, color="green", end="\n", verbose=self.verbose
)
if intent == "SELECT":
context = self.graph.query(generated_sparql)
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
result = self.qa_chain(
{"prompt": prompt, "context": context},
callbacks=callbacks,
)
res = result[self.qa_chain.output_key]
elif intent == "UPDATE":
self.graph.update(generated_sparql)
res = "Successfully inserted triples into the graph."
else:
raise ValueError("Unsupported SPARQL query type.")
chain_result: Dict[str, Any] = {self.output_key: res}
if self.return_sparql_query:
chain_result[self.sparql_query_key] = generated_sparql
return chain_result
Loading…
Cancel
Save