mirror of https://github.com/hwchase17/langchain
pull/20853/head
parent 5b7ad94a95
commit 8dfa6014a9

Each hunk below replaces a full graph-QA chain implementation in langchain with a thin re-export shim pointing at the new home of the code in langchain_community.
@@ -1,241 +1,3 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
__all__ = ["ArangoGraphQAChain"]

Removed content:

"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_community.graphs.arangodb_graph import ArangoGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    AQL_FIX_PROMPT,
    AQL_GENERATION_PROMPT,
    AQL_QA_PROMPT,
)
from langchain.chains.llm import LLMChain


class ArangoGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating AQL statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: ArangoGraph = Field(exclude=True)
    aql_generation_chain: LLMChain
    aql_fix_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    # Specifies the maximum number of AQL Query Results to return
    top_k: int = 10

    # Specifies the set of AQL Query Examples that promote few-shot-learning
    aql_examples: str = ""

    # Specify whether to return the AQL Query in the output dictionary
    return_aql_query: bool = False

    # Specify whether to return the AQL JSON Result in the output dictionary
    return_aql_result: bool = False

    # Specify the maximum amount of AQL Generation attempts that should be made
    max_aql_generation_attempts: int = 3

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        return "graph_aql_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = AQL_QA_PROMPT,
        aql_generation_prompt: BasePromptTemplate = AQL_GENERATION_PROMPT,
        aql_fix_prompt: BasePromptTemplate = AQL_FIX_PROMPT,
        **kwargs: Any,
    ) -> ArangoGraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        aql_generation_chain = LLMChain(llm=llm, prompt=aql_generation_prompt)
        aql_fix_chain = LLMChain(llm=llm, prompt=aql_fix_prompt)

        return cls(
            qa_chain=qa_chain,
            aql_generation_chain=aql_generation_chain,
            aql_fix_chain=aql_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """
        Generate an AQL statement from user input, use it to retrieve a response
        from an ArangoDB Database instance, and respond to the user input
        in natural language.

        Users can modify the following ArangoGraphQAChain Class Variables:

        :var top_k: The maximum number of AQL Query Results to return
        :type top_k: int

        :var aql_examples: A set of AQL Query Examples that are passed to
            the AQL Generation Prompt Template to promote few-shot-learning.
            Defaults to an empty string.
        :type aql_examples: str

        :var return_aql_query: Whether to return the AQL Query in the
            output dictionary. Defaults to False.
        :type return_aql_query: bool

        :var return_aql_result: Whether to return the AQL Result in the
            output dictionary. Defaults to False.
        :type return_aql_result: bool

        :var max_aql_generation_attempts: The maximum amount of AQL
            Generation attempts to be made prior to raising the last
            AQL Query Execution Error. Defaults to 3.
        :type max_aql_generation_attempts: int
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        user_input = inputs[self.input_key]

        #########################
        # Generate AQL Query #
        aql_generation_output = self.aql_generation_chain.run(
            {
                "adb_schema": self.graph.schema,
                "aql_examples": self.aql_examples,
                "user_input": user_input,
            },
            callbacks=callbacks,
        )
        #########################

        aql_query = ""
        aql_error = ""
        aql_result = None
        aql_generation_attempt = 1

        while (
            aql_result is None
            and aql_generation_attempt < self.max_aql_generation_attempts + 1
        ):
            #####################
            # Extract AQL Query #
            pattern = r"```(?i:aql)?(.*?)```"
            matches = re.findall(pattern, aql_generation_output, re.DOTALL)
            if not matches:
                _run_manager.on_text(
                    "Invalid Response: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_generation_output, color="red", end="\n", verbose=self.verbose
                )
                raise ValueError(f"Response is Invalid: {aql_generation_output}")

            aql_query = matches[0]
            #####################

            _run_manager.on_text(
                f"AQL Query ({aql_generation_attempt}):", verbose=self.verbose
            )
            _run_manager.on_text(
                aql_query, color="green", end="\n", verbose=self.verbose
            )

            #####################
            # Execute AQL Query #
            from arango import AQLQueryExecuteError

            try:
                aql_result = self.graph.query(aql_query, self.top_k)
            except AQLQueryExecuteError as e:
                aql_error = e.error_message

                _run_manager.on_text(
                    "AQL Query Execution Error: ", end="\n", verbose=self.verbose
                )
                _run_manager.on_text(
                    aql_error, color="yellow", end="\n\n", verbose=self.verbose
                )

                ########################
                # Retry AQL Generation #
                aql_generation_output = self.aql_fix_chain.run(
                    {
                        "adb_schema": self.graph.schema,
                        "aql_query": aql_query,
                        "aql_error": aql_error,
                    },
                    callbacks=callbacks,
                )
                ########################

            #####################

            aql_generation_attempt += 1

        if aql_result is None:
            m = f"""
                Maximum amount of AQL Query Generation attempts reached.
                Unable to execute the AQL Query due to the following error:
                {aql_error}
            """
            raise ValueError(m)

        _run_manager.on_text("AQL Result:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(aql_result), color="green", end="\n", verbose=self.verbose
        )

        ########################
        # Interpret AQL Result #
        result = self.qa_chain(
            {
                "adb_schema": self.graph.schema,
                "user_input": user_input,
                "aql_query": aql_query,
                "aql_result": aql_result,
            },
            callbacks=callbacks,
        )
        ########################

        # Return results #
        result = {self.output_key: result[self.qa_chain.output_key]}

        if self.return_aql_query:
            result["aql_query"] = aql_query

        if self.return_aql_result:
            result["aql_result"] = aql_result

        return result
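For orientation, a minimal usage sketch of the chain above (not part of the commit). The ArangoDB connection values, the sample question, and the choice of ChatOpenAI are illustrative assumptions:

from arango import ArangoClient
from langchain_community.chains.graph_qa.arangodb import ArangoGraphQAChain
from langchain_community.graphs.arangodb_graph import ArangoGraph
from langchain_openai import ChatOpenAI

# Placeholder connection to a local ArangoDB instance
db = ArangoClient(hosts="http://localhost:8529").db(
    "_system", username="root", password="passwd"
)
graph = ArangoGraph(db)

chain = ArangoGraphQAChain.from_llm(
    ChatOpenAI(temperature=0),
    graph=graph,
    return_aql_query=True,  # also surface the generated AQL alongside the answer
)
result = chain.invoke({"query": "Who starred in Pulp Fiction?"})
print(result["result"])
print(result["aql_query"])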
@@ -1,100 +1,3 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.base import GraphQAChain
__all__ = ["GraphQAChain"]

Removed content:

"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT
from langchain.chains.llm import LLMChain


class GraphQAChain(Chain):
    """Chain for question-answering against a graph.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: NetworkxEntityGraph = Field(exclude=True)
    entity_extraction_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT,
        entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT,
        **kwargs: Any,
    ) -> GraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        entity_chain = LLMChain(llm=llm, prompt=entity_prompt)

        return cls(
            qa_chain=qa_chain,
            entity_extraction_chain=entity_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Extract entities, look up info and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        question = inputs[self.input_key]

        entity_string = self.entity_extraction_chain.run(question)

        _run_manager.on_text("Entities Extracted:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            entity_string, color="green", end="\n", verbose=self.verbose
        )
        entities = get_entities(entity_string)
        context = ""
        all_triplets = []
        for entity in entities:
            all_triplets.extend(self.graph.get_entity_knowledge(entity))
        context = "\n".join(all_triplets)
        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(context, color="green", end="\n", verbose=self.verbose)
        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=_run_manager.get_child(),
        )
        return {self.output_key: result[self.qa_chain.output_key]}
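For orientation, a minimal usage sketch (not part of the commit). The sample triple and the ChatOpenAI model are illustrative assumptions:

from langchain_community.chains.graph_qa.base import GraphQAChain
from langchain_community.graphs.networkx_graph import (
    KnowledgeTriple,
    NetworkxEntityGraph,
)
from langchain_openai import ChatOpenAI

# Build a tiny in-memory knowledge graph
graph = NetworkxEntityGraph()
graph.add_triple(KnowledgeTriple("Intel", "is located in", "Costa Rica"))

chain = GraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Where is Intel located?"})["result"])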
@@ -1,292 +1,8 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.cypher import (
    GraphCypherQAChain,
    construct_schema,
    extract_cypher,
    filter_func,
)
__all__ = ["GraphCypherQAChain", "construct_schema", "extract_cypher", "filter_func"]

Removed content (the class itself continues after the illustration below):

"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_community.graphs.graph_store import GraphStore
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


def construct_schema(
    structured_schema: Dict[str, Any],
    include_types: List[str],
    exclude_types: List[str],
) -> str:
    """Filter the schema based on included or excluded types"""

    def filter_func(x: str) -> bool:
        return x in include_types if include_types else x not in exclude_types

    filtered_schema: Dict[str, Any] = {
        "node_props": {
            k: v
            for k, v in structured_schema.get("node_props", {}).items()
            if filter_func(k)
        },
        "rel_props": {
            k: v
            for k, v in structured_schema.get("rel_props", {}).items()
            if filter_func(k)
        },
        "relationships": [
            r
            for r in structured_schema.get("relationships", [])
            if all(filter_func(r[t]) for t in ["start", "end", "type"])
        ],
    }

    # Format node properties
    formatted_node_props = []
    for label, properties in filtered_schema["node_props"].items():
        props_str = ", ".join(
            [f"{prop['property']}: {prop['type']}" for prop in properties]
        )
        formatted_node_props.append(f"{label} {{{props_str}}}")

    # Format relationship properties
    formatted_rel_props = []
    for rel_type, properties in filtered_schema["rel_props"].items():
        props_str = ", ".join(
            [f"{prop['property']}: {prop['type']}" for prop in properties]
        )
        formatted_rel_props.append(f"{rel_type} {{{props_str}}}")

    # Format relationships
    formatted_rels = [
        f"(:{el['start']})-[:{el['type']}]->(:{el['end']})"
        for el in filtered_schema["relationships"]
    ]

    return "\n".join(
        [
            "Node properties are the following:",
            ",".join(formatted_node_props),
            "Relationship properties are the following:",
            ",".join(formatted_rel_props),
            "The relationships are the following:",
            ",".join(formatted_rels),
        ]
    )
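A quick illustration of the two helpers above (not part of the file; the hand-written schema dictionary is in the shape construct_schema expects):

# extract_cypher pulls the first triple-backtick block (including any
# language tag) and falls back to the raw text when there are no fences.
text = "Here is the query:\n```\nMATCH (p:Person) RETURN p.name\n```"
assert extract_cypher(text).strip() == "MATCH (p:Person) RETURN p.name"

# construct_schema flattens a structured schema into prompt-ready text.
structured = {
    "node_props": {"Person": [{"property": "name", "type": "STRING"}]},
    "rel_props": {},
    "relationships": [{"start": "Person", "type": "ACTED_IN", "end": "Movie"}],
}
print(construct_schema(structured, include_types=[], exclude_types=[]))
# Node properties are the following:
# Person {name: STRING}
# ...
# (:Person)-[:ACTED_IN]->(:Movie)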
class GraphCypherQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GraphStore = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    graph_schema: str
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    cypher_query_corrector: Optional[CypherQueryCorrector] = None
    """Optional cypher validation tool"""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: Optional[BaseLanguageModel] = None,
        *,
        qa_prompt: Optional[BasePromptTemplate] = None,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        cypher_llm: Optional[BaseLanguageModel] = None,
        qa_llm: Optional[BaseLanguageModel] = None,
        exclude_types: List[str] = [],
        include_types: List[str] = [],
        validate_cypher: bool = False,
        qa_llm_kwargs: Optional[Dict[str, Any]] = None,
        cypher_llm_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> GraphCypherQAChain:
        """Initialize from LLM."""

        if not cypher_llm and not llm:
            raise ValueError("Either `llm` or `cypher_llm` parameters must be provided")
        if not qa_llm and not llm:
            raise ValueError("Either `llm` or `qa_llm` parameters must be provided")
        if cypher_llm and qa_llm and llm:
            raise ValueError(
                "You can specify up to two of 'cypher_llm', 'qa_llm'"
                ", and 'llm', but not all three simultaneously."
            )
        if cypher_prompt and cypher_llm_kwargs:
            raise ValueError(
                "Specifying cypher_prompt and cypher_llm_kwargs together is"
                " not allowed. Please pass prompt via cypher_llm_kwargs."
            )
        if qa_prompt and qa_llm_kwargs:
            raise ValueError(
                "Specifying qa_prompt and qa_llm_kwargs together is"
                " not allowed. Please pass prompt via qa_llm_kwargs."
            )
        use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {}
        use_cypher_llm_kwargs = (
            cypher_llm_kwargs if cypher_llm_kwargs is not None else {}
        )
        if "prompt" not in use_qa_llm_kwargs:
            use_qa_llm_kwargs["prompt"] = (
                qa_prompt if qa_prompt is not None else CYPHER_QA_PROMPT
            )
        if "prompt" not in use_cypher_llm_kwargs:
            use_cypher_llm_kwargs["prompt"] = (
                cypher_prompt if cypher_prompt is not None else CYPHER_GENERATION_PROMPT
            )

        qa_chain = LLMChain(llm=qa_llm or llm, **use_qa_llm_kwargs)  # type: ignore[arg-type]

        cypher_generation_chain = LLMChain(
            llm=cypher_llm or llm,  # type: ignore[arg-type]
            **use_cypher_llm_kwargs,  # type: ignore[arg-type]
        )

        if exclude_types and include_types:
            raise ValueError(
                "Either `exclude_types` or `include_types` "
                "can be provided, but not both"
            )

        graph_schema = construct_schema(
            kwargs["graph"].get_structured_schema, include_types, exclude_types
        )

        cypher_query_corrector = None
        if validate_cypher:
            corrector_schema = [
                Schema(el["start"], el["type"], el["end"])
                for el in kwargs["graph"].structured_schema.get("relationships")
            ]
            cypher_query_corrector = CypherQueryCorrector(corrector_schema)

        return cls(
            graph_schema=graph_schema,
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            cypher_query_corrector=cypher_query_corrector,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph_schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        # Correct Cypher query if enabled
        if self.cypher_query_corrector:
            generated_cypher = self.cypher_query_corrector(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        # Generated Cypher can be null if the query corrector identifies an invalid schema
        if generated_cypher:
            context = self.graph.query(generated_cypher)[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
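For orientation, a minimal usage sketch of the chain above (not part of the commit). The Neo4j connection values and ChatOpenAI model are illustrative assumptions:

from langchain_community.chains.graph_qa.cypher import GraphCypherQAChain
from langchain_community.graphs import Neo4jGraph
from langchain_openai import ChatOpenAI

graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")

chain = GraphCypherQAChain.from_llm(
    ChatOpenAI(temperature=0),
    graph=graph,
    validate_cypher=True,            # enables the CypherQueryCorrector (next hunk)
    return_intermediate_steps=True,  # expose the generated Cypher and the context
)
out = chain.invoke({"query": "Who played in Top Gun?"})
print(out["result"])
print(out["intermediate_steps"])  # [{'query': ...}, {'context': ...}]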
@@ -1,260 +1,3 @@

New file content:

from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector

__all__ = ["CypherQueryCorrector"]

Removed content:

import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple

Schema = namedtuple("Schema", ["left_node", "relation", "right_node"])


class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition
    """

    property_pattern = re.compile(r"\{.+?\}")
    node_pattern = re.compile(r"\(.+?\)")
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: List[Schema]):
        """
        Args:
            schemas: list of schemas
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Args:
            node: node in string format
        """
        node = re.sub(self.property_pattern, "", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        node = node.strip()
        return node

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Args:
            query: cypher query
        """
        nodes = re.findall(self.node_pattern, query)
        nodes = [self.clean_node(node) for node in nodes]
        res: Dict[str, Any] = {}
        for node in nodes:
            parts = node.split(":")
            if parts == "":
                continue
            variable = parts[0]
            if variable not in res:
                res[variable] = []
            res[variable] += parts[1:]
        return res

    def extract_paths(self, query: str) -> "List[str]":
        """
        Args:
            query: cypher query
        """
        paths = []
        idx = 0
        while matched := self.path_pattern.findall(query[idx:]):
            matched = matched[0]
            matched = [
                m for i, m in enumerate(matched) if i not in [1, len(matched) - 1]
            ]
            path = "".join(matched)
            idx = query.find(path) + len(path) - len(matched[-1])
            paths.append(path)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Args:
            relation: relation in string format
        """
        direction = "BIDIRECTIONAL"
        if relation[0] == "<":
            direction = "INCOMING"
        if relation[-1] == ">":
            direction = "OUTGOING"
        return direction

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Args:
            part: node in string format
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Args:
            str_node: node in string format
            node_variable_dict: dictionary of node variables
        """
        splitted_node = str_node.split(":")
        variable = splitted_node[0]
        labels = []
        if variable in node_variable_dict:
            labels = node_variable_dict[variable]
        elif variable == "" and len(splitted_node) > 1:
            labels = splitted_node[1:]
        return labels

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node
        """
        valid_schemas = self.schemas
        if from_node_labels != []:
            from_node_labels = [label.strip("`") for label in from_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[0] in from_node_labels
            ]
        if to_node_labels != []:
            to_node_labels = [label.strip("`") for label in to_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[2] in to_node_labels
            ]
        if relation_types != []:
            relation_types = [type.strip("`") for type in relation_types]
            valid_schemas = [
                schema for schema in valid_schemas if schema[1] in relation_types
            ]
        return valid_schemas != []

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Args:
            str_relation: relation in string format
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Args:
            query: cypher query
        """
        node_variable_dict = self.detect_node_variables(query)
        paths = self.extract_paths(query)
        for path in paths:
            original_path = path
            start_idx = 0
            while start_idx < len(path):
                match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
                if match_res is None:
                    break
                start_idx += match_res.start()
                match_dict = match_res.groupdict()
                left_node_labels = self.detect_labels(
                    match_dict["left_node"], node_variable_dict
                )
                right_node_labels = self.detect_labels(
                    match_dict["right_node"], node_variable_dict
                )
                end_idx = (
                    start_idx
                    + 4
                    + len(match_dict["left_node"])
                    + len(match_dict["relation"])
                    + len(match_dict["right_node"])
                )
                original_partial_path = original_path[start_idx : end_idx + 1]
                relation_direction, relation_types = self.detect_relation_types(
                    match_dict["relation"]
                )

                if relation_types != [] and "".join(relation_types).find("*") != -1:
                    start_idx += (
                        len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                    )
                    continue

                if relation_direction == "OUTGOING":
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    if not is_legal:
                        is_legal = self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        )
                        if is_legal:
                            corrected_relation = "<" + match_dict["relation"][:-1]
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                elif relation_direction == "INCOMING":
                    is_legal = self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        is_legal = self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        )
                        if is_legal:
                            corrected_relation = match_dict["relation"][1:] + ">"
                            corrected_partial_path = original_partial_path.replace(
                                match_dict["relation"], corrected_relation
                            )
                            query = query.replace(
                                original_partial_path, corrected_partial_path
                            )
                        else:
                            return ""
                else:
                    is_legal = self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    is_legal |= self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                    if not is_legal:
                        return ""

                start_idx += (
                    len(match_dict["left_node"]) + len(match_dict["relation"]) + 2
                )
        return query

    def __call__(self, query: str) -> str:
        """Correct the query to make it valid.

        Args:
            query: cypher query
        """
        return self.correct_query(query)
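A short illustration of the corrector's behavior (not part of the commit; the one-relationship schema is a made-up example):

from langchain_community.chains.graph_qa.cypher_utils import (
    CypherQueryCorrector,
    Schema,
)

corrector = CypherQueryCorrector([Schema("Person", "ACTED_IN", "Movie")])

# The schema only allows (:Person)-[:ACTED_IN]->(:Movie), so the reversed
# pattern below gets flipped rather than rejected.
print(corrector("MATCH (m:Movie)-[:ACTED_IN]->(p:Person) RETURN p.name"))
# MATCH (m:Movie)<-[:ACTED_IN]-(p:Person) RETURN p.name

# A relationship that exists in no direction empties the query instead:
print(repr(corrector("MATCH (m:Movie)-[:DIRECTED]->(p:Person) RETURN p.name")))
# ''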
@@ -1,154 +1,3 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.falkordb import FalkorDBQAChain, extract_cypher
__all__ = ["FalkorDBQAChain", "extract_cypher"]

Removed content:

"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_community.graphs import FalkorDBGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text


class FalkorDBQAChain(Chain):
    """Chain for question-answering against a graph by generating Cypher statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: FalkorDBGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    """Number of results to return from the query"""
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @property
    def _chain_type(self) -> str:
        return "graph_cypher_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> FalkorDBQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {"question": question, "schema": self.graph.schema}, callbacks=callbacks
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        # Retrieve and limit the number of results
        context = self.graph.query(generated_cypher)[: self.top_k]

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
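For orientation, a minimal usage sketch (not part of the commit). The FalkorDB connection values, database name, and ChatOpenAI model are illustrative assumptions:

from langchain_community.chains.graph_qa.falkordb import FalkorDBQAChain
from langchain_community.graphs import FalkorDBGraph
from langchain_openai import ChatOpenAI

# Placeholder: a FalkorDB server listening on the default Redis port
graph = FalkorDBGraph(database="movies", host="localhost", port=6379)
chain = FalkorDBQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Who directed The Matrix?"})["result"])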
@@ -1,221 +1,3 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.gremlin import GremlinQAChain, extract_gremlin
__all__ = ["GremlinQAChain", "extract_gremlin"]

Removed content (the class itself continues after the illustration below):

"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs import GremlinGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    GRAPHDB_SPARQL_FIX_TEMPLATE,
    GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def extract_gremlin(text: str) -> str:
    """Extract Gremlin code from a text.

    Args:
        text: Text to extract Gremlin code from.

    Returns:
        Gremlin code extracted from the text.
    """
    text = text.replace("`", "")
    if text.startswith("gremlin"):
        text = text[len("gremlin") :]
    return text.replace("\n", "")
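A quick illustration of the helper above (not part of the file):

# Fences and the optional "gremlin" language tag are stripped, and the
# query is collapsed onto a single line:
print(extract_gremlin("```gremlin\ng.V().hasLabel('person').count()\n```"))
# g.V().hasLabel('person').count()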
class GremlinQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: GremlinGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    gremlin_fix_chain: LLMChain
    max_fix_retries: int = 3
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 100
    return_direct: bool = False
    return_intermediate_steps: bool = False

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        gremlin_fix_prompt: BasePromptTemplate = PromptTemplate(
            input_variables=["error_message", "generated_sparql", "schema"],
            template=GRAPHDB_SPARQL_FIX_TEMPLATE.replace("SPARQL", "Gremlin").replace(
                "in Turtle format", ""
            ),
        ),
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> GremlinQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)
        gremlin_fix_chain = LLMChain(llm=llm, prompt=gremlin_fix_prompt)
        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            gremlin_fix_chain=gremlin_fix_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        chain_response = self.gremlin_generation_chain.invoke(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        generated_gremlin = extract_gremlin(
            chain_response[self.gremlin_generation_chain.output_key]
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_gremlin})

        if generated_gremlin:
            context = self.execute_with_retry(
                _run_manager, callbacks, generated_gremlin
            )[: self.top_k]
        else:
            context = []

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain.invoke(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result

    def execute_query(self, query: str) -> List[Any]:
        try:
            return self.graph.query(query)
        except Exception as e:
            if hasattr(e, "status_message"):
                raise ValueError(e.status_message)
            else:
                raise ValueError(str(e))

    def execute_with_retry(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_gremlin: str,
    ) -> List[Any]:
        try:
            return self.execute_query(generated_gremlin)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self.log_invalid_query(_run_manager, generated_gremlin, error_message)

            while retries < self.max_fix_retries:
                try:
                    fix_chain_result = self.gremlin_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            # we are borrowing the template from sparql
                            "generated_sparql": generated_gremlin,
                            "schema": self.schema,
                        },
                        callbacks=callbacks,
                    )
                    fixed_gremlin = fix_chain_result[self.gremlin_fix_chain.output_key]
                    return self.execute_query(fixed_gremlin)
                except Exception as e:
                    retries += 1
                    parse_exception = str(e)
                    self.log_invalid_query(_run_manager, fixed_gremlin, parse_exception)

        raise ValueError("The generated Gremlin query is invalid.")

    def log_invalid_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        _run_manager.on_text("Invalid Gremlin query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "Gremlin Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )
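For orientation, a minimal usage sketch (not part of the commit). The GremlinGraph constructor values below are placeholders in the Cosmos DB Gremlin endpoint format, and the model choice is an illustrative assumption:

from langchain_community.chains.graph_qa.gremlin import GremlinQAChain
from langchain_community.graphs import GremlinGraph
from langchain_openai import ChatOpenAI

# Placeholder endpoint and credentials
graph = GremlinGraph(
    url="wss://<account>.gremlin.cosmos.azure.com:443/",
    username="/dbs/<database>/colls/<graph>",
    password="<access-key>",
)
chain = GremlinQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "How many people are in the graph?"})["result"])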
@@ -1,106 +1,3 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.hugegraph import HugeGraphQAChain
__all__ = ["HugeGraphQAChain"]

Removed content:

"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs.hugegraph import HugeGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    GREMLIN_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain


class HugeGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating gremlin statements.

    *Security note*: Make sure that the database connection uses credentials
        that are narrowly-scoped to only include necessary permissions.
        Failure to do so may result in data corruption or loss, since the calling
        code may attempt commands that would result in deletion, mutation
        of data if appropriately prompted or reading sensitive data if such
        data is present in the database.
        The best way to guard against such negative outcomes is to (as appropriate)
        limit the permissions granted to the credentials used with this tool.

        See https://python.langchain.com/docs/security for more information.
    """

    graph: HugeGraph = Field(exclude=True)
    gremlin_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        gremlin_prompt: BasePromptTemplate = GREMLIN_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> HugeGraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        gremlin_generation_chain = LLMChain(llm=llm, prompt=gremlin_prompt)

        return cls(
            qa_chain=qa_chain,
            gremlin_generation_chain=gremlin_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate gremlin statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        generated_gremlin = self.gremlin_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated gremlin:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_gremlin, color="green", end="\n", verbose=self.verbose
        )
        context = self.graph.query(generated_gremlin)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}
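For orientation, a minimal usage sketch (not part of the commit). The HugeGraph server settings and ChatOpenAI model are illustrative assumptions:

from langchain_community.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain_community.graphs import HugeGraph
from langchain_openai import ChatOpenAI

# Placeholder connection to a HugeGraph REST endpoint
graph = HugeGraph(
    username="admin",
    password="xxx",
    address="localhost",
    port=8080,
    graph="hugegraph",
)
chain = HugeGraphQAChain.from_llm(ChatOpenAI(temperature=0), graph=graph, verbose=True)
print(chain.invoke({"query": "Who is Alice?"})["result"])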
@@ -1,131 +1,7 @@

New file content:

"""Question answering over a graph."""
from langchain_community.chains.graph_qa.kuzu import (
    KuzuQAChain,
    extract_cypher,
    remove_prefix,
)
__all__ = ["KuzuQAChain", "extract_cypher", "remove_prefix"]

Removed content (the class itself continues after the illustration below):

"""Question answering over a graph."""
from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_community.graphs.kuzu_graph import KuzuGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
from langchain.chains.llm import LLMChain


def remove_prefix(text: str, prefix: str) -> str:
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text


def extract_cypher(text: str) -> str:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        Cypher code extracted from the text.
    """
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text
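A quick illustration of the two helpers above (not part of the file):

# The captured fence text keeps the "cypher" language tag, which is why
# the chain pipes extract_cypher through remove_prefix:
raw = "```cypher\nMATCH (p:Person) RETURN p.name\n```"
print(remove_prefix(extract_cypher(raw), "cypher").strip())
# MATCH (p:Person) RETURN p.name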
class KuzuQAChain(Chain):
|
|
||||||
"""Question-answering against a graph by generating Cypher statements for Kùzu.
|
|
||||||
|
|
||||||
*Security note*: Make sure that the database connection uses credentials
|
|
||||||
that are narrowly-scoped to only include necessary permissions.
|
|
||||||
Failure to do so may result in data corruption or loss, since the calling
|
|
||||||
code may attempt commands that would result in deletion, mutation
|
|
||||||
of data if appropriately prompted or reading sensitive data if such
|
|
||||||
data is present in the database.
|
|
||||||
The best way to guard against such negative outcomes is to (as appropriate)
|
|
||||||
limit the permissions granted to the credentials used with this tool.
|
|
||||||
|
|
||||||
See https://python.langchain.com/docs/security for more information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
graph: KuzuGraph = Field(exclude=True)
|
|
||||||
cypher_generation_chain: LLMChain
|
|
||||||
qa_chain: LLMChain
|
|
||||||
input_key: str = "query" #: :meta private:
|
|
||||||
output_key: str = "result" #: :meta private:
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_keys(self) -> List[str]:
|
|
||||||
"""Return the input keys.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
return [self.input_key]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_keys(self) -> List[str]:
|
|
||||||
"""Return the output keys.
|
|
||||||
|
|
||||||
:meta private:
|
|
||||||
"""
|
|
||||||
_output_keys = [self.output_key]
|
|
||||||
return _output_keys
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_llm(
|
|
||||||
cls,
|
|
||||||
llm: BaseLanguageModel,
|
|
||||||
*,
|
|
||||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
|
||||||
cypher_prompt: BasePromptTemplate = KUZU_GENERATION_PROMPT,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> KuzuQAChain:
|
|
||||||
"""Initialize from LLM."""
|
|
||||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
|
||||||
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
qa_chain=qa_chain,
|
|
||||||
cypher_generation_chain=cypher_generation_chain,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _call(
|
|
||||||
self,
|
|
||||||
inputs: Dict[str, Any],
|
|
||||||
run_manager: Optional[CallbackManagerForChainRun] = None,
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""Generate Cypher statement, use it to look up in db and answer question."""
|
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
|
||||||
callbacks = _run_manager.get_child()
|
|
||||||
question = inputs[self.input_key]
|
|
||||||
|
|
||||||
generated_cypher = self.cypher_generation_chain.run(
|
|
||||||
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
|
|
||||||
)
|
|
||||||
# Extract Cypher code if it is wrapped in triple backticks
|
|
||||||
# with the language marker "cypher"
|
|
||||||
generated_cypher = remove_prefix(extract_cypher(generated_cypher), "cypher")
|
|
||||||
|
|
||||||
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
|
|
||||||
_run_manager.on_text(
|
|
||||||
generated_cypher, color="green", end="\n", verbose=self.verbose
|
|
||||||
)
|
|
||||||
context = self.graph.query(generated_cypher)
|
|
||||||
|
|
||||||
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
|
|
||||||
_run_manager.on_text(
|
|
||||||
str(context), color="green", end="\n", verbose=self.verbose
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self.qa_chain(
|
|
||||||
{"question": question, "context": context},
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
return {self.output_key: result[self.qa_chain.output_key]}
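
A hedged usage sketch for the class above; llm is a placeholder for any BaseLanguageModel, and the KuzuGraph construction follows the wrapper imported at the top of this module (the database path is illustrative):

import kuzu
from langchain_community.graphs.kuzu_graph import KuzuGraph

db = kuzu.Database("test_db")  # illustrative on-disk database path
graph = KuzuGraph(db)
chain = KuzuQAChain.from_llm(llm=llm, graph=graph, verbose=True)
answer = chain.run("Who played in The Godfather: Part II?")
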
@ -1,103 +1,3 @@
from langchain_community.chains.graph_qa.nebulagraph import NebulaGraphQAChain

__all__ = ["NebulaGraphQAChain"]

"""Question answering over a graph."""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs.nebula_graph import NebulaGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
from langchain.chains.llm import LLMChain


class NebulaGraphQAChain(Chain):
    """Chain for question-answering against a graph by generating nGQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: NebulaGraph = Field(exclude=True)
    ngql_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        ngql_prompt: BasePromptTemplate = NGQL_GENERATION_PROMPT,
        **kwargs: Any,
    ) -> NebulaGraphQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        ngql_generation_chain = LLMChain(llm=llm, prompt=ngql_prompt)

        return cls(
            qa_chain=qa_chain,
            ngql_generation_chain=ngql_generation_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Generate nGQL statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        generated_ngql = self.ngql_generation_chain.run(
            {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated nGQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_ngql, color="green", end="\n", verbose=self.verbose
        )
        context = self.graph.query(generated_ngql)

        _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            str(context), color="green", end="\n", verbose=self.verbose
        )

        result = self.qa_chain(
            {"question": question, "context": context},
            callbacks=callbacks,
        )
        return {self.output_key: result[self.qa_chain.output_key]}
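
A hedged usage sketch for the chain above; the connection arguments are illustrative defaults for a local NebulaGraph instance, and llm is a placeholder model:

from langchain_community.graphs.nebula_graph import NebulaGraph

graph = NebulaGraph(
    space="langchain",
    username="root",        # assumed default credentials
    password="nebula",
    address="127.0.0.1",
    port=9669,
)
chain = NebulaGraphQAChain.from_llm(llm=llm, graph=graph, verbose=True)
answer = chain.run("Who played in The Godfather II?")
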
@ -1,217 +1,13 @@
from langchain_community.chains.graph_qa.neptune_cypher import (
    NeptuneOpenCypherQAChain,
    extract_cypher,
    trim_query,
    use_simple_prompt,
)

from __future__ import annotations

import re
from typing import Any, Dict, List, Optional

from langchain_community.graphs import BaseNeptuneGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    CYPHER_QA_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)
from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector

INTERMEDIATE_STEPS_KEY = "intermediate_steps"


def trim_query(query: str) -> str:
    """Trim the query to only include Cypher keywords."""
    keywords = (
        "CALL",
        "CREATE",
        "DELETE",
        "DETACH",
        "LIMIT",
        "MATCH",
        "MERGE",
        "OPTIONAL",
        "ORDER",
        "REMOVE",
        "RETURN",
        "SET",
        "SKIP",
        "UNWIND",
        "WITH",
        "WHERE",
        "//",
    )

    lines = query.split("\n")
    new_query = ""

    for line in lines:
        if line.strip().upper().startswith(keywords):
            new_query += line + "\n"

    return new_query


def extract_cypher(text: str) -> str:
    """Extract Cypher code from text using Regex."""
    # The pattern to find Cypher code enclosed in triple backticks
    pattern = r"```(.*?)```"

    # Find all matches in the input text
    matches = re.findall(pattern, text, re.DOTALL)

    return matches[0] if matches else text
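
For illustration, trim_query drops conversational filler and keeps only lines that start with a Cypher keyword (a sketch using only the definition above):

raw = "Sure! Here is the query:\nMATCH (n) RETURN n LIMIT 5\nHope that helps."
print(trim_query(raw))   # "MATCH (n) RETURN n LIMIT 5\n"
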
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
    """Decide whether to use the simple prompt."""
    if llm._llm_type and "anthropic" in llm._llm_type:  # type: ignore
        return True

    # Bedrock anthropic
    if hasattr(llm, "model_id") and "anthropic" in llm.model_id:  # type: ignore
        return True

    return False


PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
    conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
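
The selector above routes Anthropic-backed models (detected via _llm_type, or via model_id for Bedrock) to the simpler generation prompt; every other model gets the default. A one-line sketch, with llm as a placeholder model:

prompt = PROMPT_SELECTOR.get_prompt(llm)  # simple prompt iff use_simple_prompt(llm)
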
class NeptuneOpenCypherQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating openCypher statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

            chain = NeptuneOpenCypherQAChain.from_llm(
                llm=llm,
                graph=graph
            )
            response = chain.run(query)
    """

    graph: BaseNeptuneGraph = Field(exclude=True)
    cypher_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions appended to the query generation prompt."""

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
        cypher_prompt: Optional[BasePromptTemplate] = None,
        extra_instructions: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneOpenCypherQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)

        _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
        cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)

        return cls(
            qa_chain=qa_chain,
            cypher_generation_chain=cypher_generation_chain,
            extra_instructions=extra_instructions,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Generate Cypher statement, use it to look up in db and answer question."""
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        question = inputs[self.input_key]

        intermediate_steps: List = []

        generated_cypher = self.cypher_generation_chain.run(
            {
                "question": question,
                "schema": self.graph.get_schema,
                "extra_instructions": self.extra_instructions or "",
            },
            callbacks=callbacks,
        )

        # Extract Cypher code if it is wrapped in backticks
        generated_cypher = extract_cypher(generated_cypher)
        generated_cypher = trim_query(generated_cypher)

        _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_cypher, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_cypher})

        context = self.graph.query(generated_cypher)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"question": question, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
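
Beyond the docstring example, a hedged sketch of the optional knobs above: return_intermediate_steps exposes the generated query and raw context alongside the answer, and extra_instructions is appended to the generation prompt (llm and graph are placeholders):

chain = NeptuneOpenCypherQAChain.from_llm(
    llm=llm,
    graph=graph,  # a Neptune graph wrapper implementing BaseNeptuneGraph
    return_intermediate_steps=True,
    extra_instructions="Always add a LIMIT clause.",  # illustrative
)
out = chain.invoke({"query": "How many nodes are in the graph?"})
print(out["result"], out["intermediate_steps"])
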
__all__ = [
    "NeptuneOpenCypherQAChain",
    "extract_cypher",
    "trim_query",
    "use_simple_prompt",
]
@ -1,196 +1,6 @@
from langchain_community.chains.graph_qa.neptune_sparql import (
    NeptuneSparqlQAChain,
    extract_sparql,
)

__all__ = ["NeptuneSparqlQAChain", "extract_sparql"]

"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs import NeptuneRdfGraph
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT
from langchain.chains.llm import LLMChain

INTERMEDIATE_STEPS_KEY = "intermediate_steps"

SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
    ?person foaf:name "John Doe" .
    ?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.

Examples:

Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.

The question is:
{prompt}"""

SPARQL_GENERATION_PROMPT = PromptTemplate(
    input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)
def extract_sparql(query: str) -> str:
    query = query.strip()
    querytoks = query.split("```")
    if len(querytoks) == 3:
        query = querytoks[1]

        if query.startswith("sparql"):
            query = query[6:]
    elif query.startswith("<sparql>") and query.endswith("</sparql>"):
        query = query[8:-9]
    return query
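
For illustration, extract_sparql handles the two wrappings the function checks for (a sketch using only the definition above):

extract_sparql('```sparql\nSELECT ?s WHERE { ?s ?p ?o }\n```')
# -> '\nSELECT ?s WHERE { ?s ?p ?o }\n'
extract_sparql('<sparql>SELECT ?s WHERE { ?s ?p ?o }</sparql>')
# -> 'SELECT ?s WHERE { ?s ?p ?o }'
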
class NeptuneSparqlQAChain(Chain):
    """Chain for question-answering against a Neptune graph
    by generating SPARQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.

    Example:
        .. code-block:: python

            chain = NeptuneSparqlQAChain.from_llm(
                llm=llm,
                graph=graph
            )
            response = chain.invoke(query)
    """

    graph: NeptuneRdfGraph = Field(exclude=True)
    sparql_generation_chain: LLMChain
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    top_k: int = 10
    return_intermediate_steps: bool = False
    """Whether or not to return the intermediate steps along with the final answer."""
    return_direct: bool = False
    """Whether or not to return the result of querying the graph directly."""
    extra_instructions: Optional[str] = None
    """Extra instructions appended to the query generation prompt."""

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
        sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
        examples: Optional[str] = None,
        **kwargs: Any,
    ) -> NeptuneSparqlQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        template_to_use = SPARQL_GENERATION_TEMPLATE
        if examples:
            template_to_use = template_to_use.replace(
                "Examples:", "Examples: " + examples
            )
            sparql_prompt = PromptTemplate(
                input_variables=["schema", "prompt"], template=template_to_use
            )
        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)

        return cls(  # type: ignore[call-arg]
            qa_chain=qa_chain,
            sparql_generation_chain=sparql_generation_chain,
            examples=examples,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """
        Generate SPARQL query, use it to retrieve a response from the gdb and answer
        the question.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]

        intermediate_steps: List = []

        generated_sparql = self.sparql_generation_chain.run(
            {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        # Extract SPARQL
        generated_sparql = extract_sparql(generated_sparql)

        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_sparql, color="green", end="\n", verbose=self.verbose
        )

        intermediate_steps.append({"query": generated_sparql})

        context = self.graph.query(generated_sparql)

        if self.return_direct:
            final_result = context
        else:
            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )

            intermediate_steps.append({"context": context})

            result = self.qa_chain(
                {"prompt": prompt, "context": context},
                callbacks=callbacks,
            )
            final_result = result[self.qa_chain.output_key]

        chain_result: Dict[str, Any] = {self.output_key: final_result}
        if self.return_intermediate_steps:
            chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps

        return chain_result
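
A hedged sketch of the examples hook above: few-shot text passed to from_llm is spliced into SPARQL_GENERATION_TEMPLATE right after the "Examples:" marker (llm and graph are placeholders; the example text is illustrative):

chain = NeptuneSparqlQAChain.from_llm(
    llm=llm,
    graph=graph,  # a NeptuneRdfGraph instance
    examples="Q: List all classes.\nA: SELECT DISTINCT ?c WHERE { ?s a ?c }",
    return_intermediate_steps=True,
)
out = chain.invoke({"query": "How many classes does the ontology define?"})
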
@ -1,190 +1,3 @@
from langchain_community.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain

__all__ = ["OntotextGraphDBQAChain"]

"""Question answering over a graph."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

if TYPE_CHECKING:
    import rdflib

from langchain_community.graphs import OntotextGraphDBGraph
from langchain_core.callbacks.manager import CallbackManager, CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    GRAPHDB_QA_PROMPT,
    GRAPHDB_SPARQL_FIX_PROMPT,
    GRAPHDB_SPARQL_GENERATION_PROMPT,
)
from langchain.chains.llm import LLMChain


class OntotextGraphDBQAChain(Chain):
    """Question-answering against Ontotext GraphDB
    https://graphdb.ontotext.com/ by generating SPARQL queries.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: OntotextGraphDBGraph = Field(exclude=True)
    sparql_generation_chain: LLMChain
    sparql_fix_chain: LLMChain
    max_fix_retries: int
    qa_chain: LLMChain
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        sparql_generation_prompt: BasePromptTemplate = GRAPHDB_SPARQL_GENERATION_PROMPT,
        sparql_fix_prompt: BasePromptTemplate = GRAPHDB_SPARQL_FIX_PROMPT,
        max_fix_retries: int = 5,
        qa_prompt: BasePromptTemplate = GRAPHDB_QA_PROMPT,
        **kwargs: Any,
    ) -> OntotextGraphDBQAChain:
        """Initialize from LLM."""
        sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_generation_prompt)
        sparql_fix_chain = LLMChain(llm=llm, prompt=sparql_fix_prompt)
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        return cls(
            qa_chain=qa_chain,
            sparql_generation_chain=sparql_generation_chain,
            sparql_fix_chain=sparql_fix_chain,
            max_fix_retries=max_fix_retries,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """
        Generate a SPARQL query, use it to retrieve a response from GraphDB and answer
        the question.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]
        ontology_schema = self.graph.get_schema

        sparql_generation_chain_result = self.sparql_generation_chain.invoke(
            {"prompt": prompt, "schema": ontology_schema}, callbacks=callbacks
        )
        generated_sparql = sparql_generation_chain_result[
            self.sparql_generation_chain.output_key
        ]

        generated_sparql = self._get_prepared_sparql_query(
            _run_manager, callbacks, generated_sparql, ontology_schema
        )
        query_results = self._execute_query(generated_sparql)

        qa_chain_result = self.qa_chain.invoke(
            {"prompt": prompt, "context": query_results}, callbacks=callbacks
        )
        result = qa_chain_result[self.qa_chain.output_key]
        return {self.output_key: result}
    def _get_prepared_sparql_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        callbacks: CallbackManager,
        generated_sparql: str,
        ontology_schema: str,
    ) -> str:
        try:
            return self._prepare_sparql_query(_run_manager, generated_sparql)
        except Exception as e:
            retries = 0
            error_message = str(e)
            self._log_invalid_sparql_query(
                _run_manager, generated_sparql, error_message
            )

            while retries < self.max_fix_retries:
                try:
                    sparql_fix_chain_result = self.sparql_fix_chain.invoke(
                        {
                            "error_message": error_message,
                            "generated_sparql": generated_sparql,
                            "schema": ontology_schema,
                        },
                        callbacks=callbacks,
                    )
                    generated_sparql = sparql_fix_chain_result[
                        self.sparql_fix_chain.output_key
                    ]
                    return self._prepare_sparql_query(_run_manager, generated_sparql)
                except Exception as e:
                    retries += 1
                    parse_exception = str(e)
                    self._log_invalid_sparql_query(
                        _run_manager, generated_sparql, parse_exception
                    )

            raise ValueError("The generated SPARQL query is invalid.")

    def _prepare_sparql_query(
        self, _run_manager: CallbackManagerForChainRun, generated_sparql: str
    ) -> str:
        from rdflib.plugins.sparql import prepareQuery

        prepareQuery(generated_sparql)
        self._log_prepared_sparql_query(_run_manager, generated_sparql)
        return generated_sparql

    def _log_prepared_sparql_query(
        self, _run_manager: CallbackManagerForChainRun, generated_query: str
    ) -> None:
        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="green", end="\n", verbose=self.verbose
        )

    def _log_invalid_sparql_query(
        self,
        _run_manager: CallbackManagerForChainRun,
        generated_query: str,
        error_message: str,
    ) -> None:
        _run_manager.on_text("Invalid SPARQL query: ", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_query, color="red", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            "SPARQL Query Parse Error: ", end="\n", verbose=self.verbose
        )
        _run_manager.on_text(
            error_message, color="red", end="\n\n", verbose=self.verbose
        )

    def _execute_query(self, query: str) -> List[rdflib.query.ResultRow]:
        try:
            return self.graph.query(query)
        except Exception:
            raise ValueError("Failed to execute the generated SPARQL query.")
@ -1,152 +1,3 @@
from langchain_community.chains.graph_qa.sparql import GraphSparqlQAChain

__all__ = ["GraphSparqlQAChain"]

"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations

from typing import Any, Dict, List, Optional

from langchain_community.graphs.rdf_graph import RdfGraph
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
    SPARQL_GENERATION_SELECT_PROMPT,
    SPARQL_GENERATION_UPDATE_PROMPT,
    SPARQL_INTENT_PROMPT,
    SPARQL_QA_PROMPT,
)
from langchain.chains.llm import LLMChain


class GraphSparqlQAChain(Chain):
    """Question-answering against an RDF or OWL graph by generating SPARQL statements.

    *Security note*: Make sure that the database connection uses credentials
    that are narrowly-scoped to only include necessary permissions.
    Failure to do so may result in data corruption or loss, since the calling
    code may attempt commands that would result in deletion, mutation
    of data if appropriately prompted or reading sensitive data if such
    data is present in the database.
    The best way to guard against such negative outcomes is to (as appropriate)
    limit the permissions granted to the credentials used with this tool.

    See https://python.langchain.com/docs/security for more information.
    """

    graph: RdfGraph = Field(exclude=True)
    sparql_generation_select_chain: LLMChain
    sparql_generation_update_chain: LLMChain
    sparql_intent_chain: LLMChain
    qa_chain: LLMChain
    return_sparql_query: bool = False
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:
    sparql_query_key: str = "sparql_query"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Return the input keys.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys.

        :meta private:
        """
        _output_keys = [self.output_key]
        return _output_keys

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        *,
        qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
        sparql_select_prompt: BasePromptTemplate = SPARQL_GENERATION_SELECT_PROMPT,
        sparql_update_prompt: BasePromptTemplate = SPARQL_GENERATION_UPDATE_PROMPT,
        sparql_intent_prompt: BasePromptTemplate = SPARQL_INTENT_PROMPT,
        **kwargs: Any,
    ) -> GraphSparqlQAChain:
        """Initialize from LLM."""
        qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
        sparql_generation_select_chain = LLMChain(llm=llm, prompt=sparql_select_prompt)
        sparql_generation_update_chain = LLMChain(llm=llm, prompt=sparql_update_prompt)
        sparql_intent_chain = LLMChain(llm=llm, prompt=sparql_intent_prompt)

        return cls(
            qa_chain=qa_chain,
            sparql_generation_select_chain=sparql_generation_select_chain,
            sparql_generation_update_chain=sparql_generation_update_chain,
            sparql_intent_chain=sparql_intent_chain,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """
        Generate SPARQL query, use it to retrieve a response from the gdb and answer
        the question.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        prompt = inputs[self.input_key]

        _intent = self.sparql_intent_chain.run({"prompt": prompt}, callbacks=callbacks)
        intent = _intent.strip()

        if "SELECT" in intent and "UPDATE" not in intent:
            sparql_generation_chain = self.sparql_generation_select_chain
            intent = "SELECT"
        elif "UPDATE" in intent and "SELECT" not in intent:
            sparql_generation_chain = self.sparql_generation_update_chain
            intent = "UPDATE"
        else:
            raise ValueError(
                "I am sorry, but this prompt seems to fit none of the currently "
                "supported SPARQL query types, i.e., SELECT and UPDATE."
            )

        _run_manager.on_text("Identified intent:", end="\n", verbose=self.verbose)
        _run_manager.on_text(intent, color="green", end="\n", verbose=self.verbose)

        generated_sparql = sparql_generation_chain.run(
            {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
        )

        _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
        _run_manager.on_text(
            generated_sparql, color="green", end="\n", verbose=self.verbose
        )

        if intent == "SELECT":
            context = self.graph.query(generated_sparql)

            _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
            _run_manager.on_text(
                str(context), color="green", end="\n", verbose=self.verbose
            )
            result = self.qa_chain(
                {"prompt": prompt, "context": context},
                callbacks=callbacks,
            )
            res = result[self.qa_chain.output_key]
        elif intent == "UPDATE":
            self.graph.update(generated_sparql)
            res = "Successfully inserted triples into the graph."
        else:
            raise ValueError("Unsupported SPARQL query type.")

        chain_result: Dict[str, Any] = {self.output_key: res}
        if self.return_sparql_query:
            chain_result[self.sparql_query_key] = generated_sparql
        return chain_result
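
A hedged usage sketch for the chain above; return_sparql_query surfaces the generated query under the "sparql_query" key, while the RDF source and llm are placeholders:

from langchain_community.graphs.rdf_graph import RdfGraph

graph = RdfGraph(
    source_file="http://www.w3.org/People/Berners-Lee/card",  # illustrative source
    standard="rdf",
)
chain = GraphSparqlQAChain.from_llm(llm=llm, graph=graph, return_sparql_query=True)
out = chain.invoke({"query": "What is Tim Berners-Lee's work homepage?"})
print(out["result"], "\n", out["sparql_query"])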