community[minor], langchain[minor]: Add Neptune RDF graph and chain (#16650)

**Description**: This PR adds a chain for querying Amazon Neptune graph
databases that use the RDF format, complementing the existing Neptune Cypher
chain. It also adds a Neptune RDF graph class that the chain uses to connect
to, introspect, and query a Neptune RDF graph database. A sample notebook
under docs demonstrates the end-to-end flow: invoking the chain to run
natural language queries against Neptune using an LLM.

**Issue**: This is a new feature
 
**Dependencies**: The RDF graph class requires the AWS boto3 library when IAM
authentication is used to connect to the Neptune database.
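
**Example usage**: a minimal sketch of how the new classes fit together (the
host, port, and region values are placeholders, and IAM auth is assumed to be
enabled on the cluster):

```python
import boto3

from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain
from langchain.chat_models import BedrockChat
from langchain_community.graphs import NeptuneRdfGraph

# Connect to the Neptune SPARQL endpoint and introspect the RDF schema.
graph = NeptuneRdfGraph(
    host="<neptune-host>",  # placeholder
    port=8182,
    use_iam_auth=True,
    region_name="us-east-1",  # placeholder
)

# The LLM generates SPARQL from natural language and summarizes the results.
llm = BedrockChat(
    model_id="anthropic.claude-v2",
    client=boto3.client("bedrock-runtime"),
)

chain = NeptuneSparqlQAChain.from_llm(llm=llm, graph=graph, verbose=True)
chain.invoke("How many organizations are in the graph")
```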

---------

Co-authored-by: Piyush Jain <piyushjain@duck.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>

@ -0,0 +1,337 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neptune SPARQL QA Chain\n",
"\n",
"This notebook shows use of LLM to query RDF graph in Amazon Neptune. This code uses a `NeptuneRdfGraph` class that connects with the Neptune database and loads it's schema. The `NeptuneSparqlQAChain` is used to connect the graph and LLM to ask natural language questions.\n",
"\n",
"Requirements for running this notebook:\n",
"- Neptune 1.2.x cluster accessible from this notebook\n",
"- Kernel with Python 3.9 or higher\n",
"- For Bedrock access, ensure IAM role has this policy\n",
"\n",
"```json\n",
"{\n",
" \"Action\": [\n",
" \"bedrock:ListFoundationModels\",\n",
" \"bedrock:InvokeModel\"\n",
" ],\n",
" \"Resource\": \"*\",\n",
" \"Effect\": \"Allow\"\n",
"}\n",
"```\n",
"\n",
"- S3 bucket for staging sample data, bucket should be in same account/region as Neptune."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Seed W3C organizational data\n",
"W3C org ontology plus some instances. \n",
"\n",
"You will need an S3 bucket in the same region and account. Set STAGE_BUCKET to name of that bucket."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"STAGE_BUCKET = \"<bucket-name>\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash -s \"$STAGE_BUCKET\"\n",
"\n",
"rm -rf data\n",
"mkdir -p data\n",
"cd data\n",
"echo getting org ontology and sample org instances\n",
"wget http://www.w3.org/ns/org.ttl \n",
"wget https://raw.githubusercontent.com/aws-samples/amazon-neptune-ontology-example-blog/main/data/example_org.ttl \n",
"\n",
"echo Copying org ttl to S3\n",
"aws s3 cp org.ttl s3://$1/org.ttl\n",
"aws s3 cp example_org.ttl s3://$1/example_org.ttl\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Bulk-load the org ttl - both ontology and instances"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load -s s3://{STAGE_BUCKET} -f turtle --store-to loadres --run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_status {loadres['payload']['loadId']} --errors --details"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Chain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"EXAMPLES = \"\"\"\n",
"\n",
"<question>\n",
"Find organizations.\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> \n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> \n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"\n",
"select ?org ?orgName where {{\n",
" ?org rdfs:label ?orgName .\n",
"}} \n",
"</sparql>\n",
"\n",
"<question>\n",
"Find sites of an organization\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> \n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> \n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"\n",
"select ?org ?orgName ?siteName where {{\n",
" ?org rdfs:label ?orgName .\n",
" ?org org:hasSite/rdfs:label ?siteName . \n",
"}} \n",
"</sparql>\n",
"\n",
"<question>\n",
"Find suborganizations of an organization\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> \n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> \n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"\n",
"select ?org ?orgName ?subName where {{\n",
" ?org rdfs:label ?orgName .\n",
" ?org org:hasSubOrganization/rdfs:label ?subName .\n",
"}} \n",
"</sparql>\n",
"\n",
"<question>\n",
"Find organizational units of an organization\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> \n",
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> \n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"\n",
"select ?org ?orgName ?unitName where {{\n",
" ?org rdfs:label ?orgName .\n",
" ?org org:hasUnit/rdfs:label ?unitName . \n",
"}} \n",
"</sparql>\n",
"\n",
"<question>\n",
"Find members of an organization. Also find their manager, or the member they report to.\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"PREFIX foaf: <http://xmlns.com/foaf/0.1/> \n",
"\n",
"select * where {{\n",
" ?person rdf:type foaf:Person .\n",
" ?person org:memberOf ?org .\n",
" OPTIONAL {{ ?person foaf:firstName ?firstName . }}\n",
" OPTIONAL {{ ?person foaf:family_name ?lastName . }}\n",
" OPTIONAL {{ ?person org:reportsTo ??manager }} .\n",
"}}\n",
"</sparql>\n",
"\n",
"\n",
"<question>\n",
"Find change events, such as mergers and acquisitions, of an organization\n",
"</question>\n",
"\n",
"<sparql>\n",
"PREFIX org: <http://www.w3.org/ns/org#> \n",
"\n",
"select ?event ?prop ?obj where {{\n",
" ?org rdfs:label ?orgName .\n",
" ?event rdf:type org:ChangeEvent .\n",
" ?event org:originalOrganization ?origOrg .\n",
" ?event org:resultingOrganization ?resultingOrg .\n",
"}}\n",
"</sparql>\n",
"\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import boto3\n",
"from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain\n",
"from langchain.chat_models import BedrockChat\n",
"from langchain_community.graphs import NeptuneRdfGraph\n",
"\n",
"host = \"<neptune-host>\"\n",
"port = \"<neptune-port>\"\n",
"region = \"us-east-1\" # specify region\n",
"\n",
"graph = NeptuneRdfGraph(\n",
" host=host, port=port, use_iam_auth=True, region_name=region, hide_comments=True\n",
")\n",
"\n",
"schema_elements = graph.get_schema_elements\n",
"# Optionally, you can update the schema_elements, and\n",
"# load the schema from the pruned elements.\n",
"graph.load_from_schema_elements(schema_elements)\n",
"\n",
"bedrock_client = boto3.client(\"bedrock-runtime\")\n",
"llm = BedrockChat(model_id=\"anthropic.claude-v2\", client=bedrock_client)\n",
"\n",
"chain = NeptuneSparqlQAChain.from_llm(\n",
" llm=llm,\n",
" graph=graph,\n",
" examples=EXAMPLES,\n",
" verbose=True,\n",
" top_K=10,\n",
" return_intermediate_steps=True,\n",
" return_direct=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Ask questions\n",
"Depends on the data we ingested above"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"How many organizations are in the graph\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"Are there any mergers or acquisitions\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"Find organizations\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"Find sites of MegaSystems or MegaFinancial\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"Find a member who is manager of one or more members.\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\"\"\"Find five members and who their manager is.\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\n",
" \"\"\"Find org units or suborganizations of The Mega Group. What are the sites of those units?\"\"\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -8,6 +8,7 @@ from langchain_community.graphs.memgraph_graph import MemgraphGraph
from langchain_community.graphs.nebula_graph import NebulaGraph
from langchain_community.graphs.neo4j_graph import Neo4jGraph
from langchain_community.graphs.neptune_graph import NeptuneGraph
from langchain_community.graphs.neptune_rdf_graph import NeptuneRdfGraph
from langchain_community.graphs.networkx_graph import NetworkxEntityGraph
from langchain_community.graphs.ontotext_graphdb_graph import OntotextGraphDBGraph
from langchain_community.graphs.rdf_graph import RdfGraph
@ -19,6 +20,7 @@ __all__ = [
"Neo4jGraph",
"NebulaGraph",
"NeptuneGraph",
"NeptuneRdfGraph",
"KuzuGraph",
"HugeGraph",
"RdfGraph",

@ -0,0 +1,256 @@
import json
from types import SimpleNamespace
from typing import Any, Dict, Optional, Sequence
import requests
CLASS_QUERY = """
SELECT DISTINCT ?elem ?com
WHERE {
?instance a ?elem .
OPTIONAL { ?instance rdf:type/rdfs:subClassOf* ?elem } .
#FILTER (isIRI(?elem)) .
OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")}
}
"""
REL_QUERY = """
SELECT DISTINCT ?elem ?com
WHERE {
?subj ?elem ?obj .
OPTIONAL {
?elem rdf:type/rdfs:subPropertyOf* ?proptype .
VALUES ?proptype { rdf:Property owl:DatatypeProperty owl:ObjectProperty } .
} .
OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")}
}
"""
DTPROP_QUERY = """
SELECT DISTINCT ?elem ?com
WHERE {
?subj ?elem ?obj .
OPTIONAL {
?elem rdf:type/rdfs:subPropertyOf* ?proptype .
?proptype a owl:DatatypeProperty .
} .
OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")}
}
"""
OPROP_QUERY = """
SELECT DISTINCT ?elem ?com
WHERE {
?subj ?elem ?obj .
OPTIONAL {
?elem rdf:type/rdfs:subPropertyOf* ?proptype .
?proptype a owl:ObjectProperty .
} .
OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")}
}
"""
ELEM_TYPES = {
"classes": CLASS_QUERY,
"rels": REL_QUERY,
"dtprops": DTPROP_QUERY,
"oprops": OPROP_QUERY,
}
class NeptuneRdfGraph:
"""Neptune wrapper for RDF graph operations.
Args:
host: SPARQL endpoint host for Neptune
port: SPARQL endpoint port for Neptune. Defaults to 8182.
use_iam_auth: boolean indicating IAM auth is enabled in Neptune cluster
region_name: AWS region required if use_iam_auth is True, e.g., us-west-2
hide_comments: whether to exclude ontology comments from the schema used in the prompt
Example:
.. code-block:: python
graph = NeptuneRdfGraph(
host='<SPARQL host>',
port=<SPARQL port>,
use_iam_auth=False
)
schema = graph.get_schema
OR
graph = NeptuneRdfGraph(
host='<SPARQL host>',
port=<SPARQL port>,
use_iam_auth=False
)
schema_elem = graph.get_schema_elements
... change schema_elements ...
graph.load_schema(schema_elem)
schema = graph.get_schema
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
"""
def __init__(
self,
host: str,
port: int = 8182,
use_iam_auth: bool = False,
region_name: Optional[str] = None,
hide_comments: bool = False,
) -> None:
self.use_iam_auth = use_iam_auth
self.region_name = region_name
self.hide_comments = hide_comments
self.query_endpoint = f"https://{host}:{port}/sparql"
if self.use_iam_auth:
try:
import boto3
self.session = boto3.Session()
except ImportError:
raise ImportError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
else:
self.session = None
# Set schema
self.schema = ""
self.schema_elements: Dict[str, Any] = {}
self._refresh_schema()
@property
def get_schema(self) -> str:
"""
Returns the schema of the graph database.
"""
return self.schema
@property
def get_schema_elements(self) -> Dict[str, Any]:
return self.schema_elements
def query(
self,
query: str,
) -> Dict[str, Any]:
"""
Run Neptune query.
"""
request_data = {"query": query}
data = request_data
request_hdr = None
if self.use_iam_auth:
credentials = self.session.get_credentials()
credentials = credentials.get_frozen_credentials()
access_key = credentials.access_key
secret_key = credentials.secret_key
service = "neptune-db"
session_token = credentials.token
params = None
creds = SimpleNamespace(
access_key=access_key,
secret_key=secret_key,
token=session_token,
region=self.region_name,
)
from botocore.awsrequest import AWSRequest
request = AWSRequest(
method="POST", url=self.query_endpoint, data=data, params=params
)
from botocore.auth import SigV4Auth
SigV4Auth(creds, service, self.region_name).add_auth(request)
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
request_hdr = request.headers
else:
request_hdr = {}
request_hdr["Content-Type"] = "application/x-www-form-urlencoded"
queryres = requests.request(
method="POST", url=self.query_endpoint, headers=request_hdr, data=data
)
json_resp = json.loads(queryres.text)
return json_resp
def load_schema(self, schema_elements: Dict[str, Any]) -> None:
"""
Generates and sets schema from schema_elements. Helpful in
cases where introspected schema needs pruning.
"""
elem_str = {}
for elem in ELEM_TYPES:
res_list = []
for elem_rec in schema_elements[elem]:
uri = elem_rec["uri"]
local = elem_rec["local"]
res_str = f"<{uri}> ({local})"
if self.hide_comments is False:
res_str = res_str + f", {elem_rec['comment']}"
res_list.append(res_str)
elem_str[elem] = ", ".join(res_list)
self.schema = (
"In the following, each IRI is followed by the local name and "
"optionally its description in parentheses. \n"
"The graph supports the following node types:\n"
f"{elem_str['classes']}\n"
"The graph supports the following relationships:\n"
f"{elem_str['rels']}\n"
"The graph supports the following OWL object properties:\n"
f"{elem_str['oprops']}\n"
"The graph supports the following OWL data properties:\n"
f"{elem_str['dtprops']}"
)
def _get_local_name(self, iri: str) -> Sequence[str]:
"""
Split IRI into prefix and local
"""
if "#" in iri:
tokens = iri.split("#")
return [f"{tokens[0]}#", tokens[-1]]
elif "/" in iri:
tokens = iri.split("/")
return [f"{'/'.join(tokens[0:len(tokens)-1])}/", tokens[-1]]
else:
raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.")
def _refresh_schema(self) -> None:
"""
Query Neptune to introspect schema.
"""
self.schema_elements["distinct_prefixes"] = {}
for elem in ELEM_TYPES:
items = self.query(ELEM_TYPES[elem])
reslist = []
for r in items["results"]["bindings"]:
uri = r["elem"]["value"]
tokens = self._get_local_name(uri)
elem_record = {"uri": uri, "local": tokens[1]}
if not self.hide_comments:
elem_record["comment"] = r["com"]["value"] if "com" in r else ""
reslist.append(elem_record)
if tokens[0] not in self.schema_elements["distinct_prefixes"]:
self.schema_elements["distinct_prefixes"][tokens[0]] = "y"
self.schema_elements[elem] = reslist
self.load_schema(self.schema_elements)
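
For reference, the introspect-prune-reload flow described in the class docstring can look roughly like the sketch below; the endpoint values and the namespace filter are illustrative assumptions, not part of the library:

```python
import copy

from langchain_community.graphs import NeptuneRdfGraph

# Connect and introspect (placeholder endpoint values).
graph = NeptuneRdfGraph(host="<neptune-host>", port=8182, use_iam_auth=False)

# get_schema_elements maps "classes", "rels", "dtprops", and "oprops" to lists
# of {"uri", "local", "comment"} records, plus a "distinct_prefixes" dict.
elements = copy.deepcopy(graph.get_schema_elements)

# Hypothetical pruning rule: keep only classes from the W3C org namespace.
elements["classes"] = [
    c for c in elements["classes"]
    if c["uri"].startswith("http://www.w3.org/ns/org#")
]

# Regenerate the schema string that is injected into the SPARQL prompt.
graph.load_schema(elements)
print(graph.get_schema)
```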

@ -6,6 +6,7 @@ EXPECTED_ALL = [
"Neo4jGraph",
"NebulaGraph",
"NeptuneGraph",
"NeptuneRdfGraph",
"KuzuGraph",
"HugeGraph",
"RdfGraph",

@ -1,2 +1,5 @@
def test_import() -> None:
from langchain_community.graphs import NeptuneGraph # noqa: F401
from langchain_community.graphs import (
NeptuneGraph, # noqa: F401
NeptuneRdfGraph, # noqa: F401
)

@ -41,6 +41,7 @@ from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain
from langchain.chains.graph_qa.kuzu import KuzuQAChain
from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain
from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain
from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain
from langchain.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain
from langchain.chains.graph_qa.sparql import GraphSparqlQAChain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
@ -116,6 +117,7 @@ __all__ = [
"NatBotChain",
"NebulaGraphQAChain",
"NeptuneOpenCypherQAChain",
"NeptuneSparqlQAChain",
"OpenAIModerationChain",
"OpenAPIEndpointChain",
"QAGenerationChain",

@ -0,0 +1,196 @@
"""
Question answering over an RDF or OWL graph using SPARQL.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from langchain_community.graphs import NeptuneRdfGraph
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT
from langchain.chains.llm import LLMChain
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
SPARQL_GENERATION_TEMPLATE = """
Task: Generate a SPARQL SELECT statement for querying a graph database.
For instance, to find all email addresses of John Doe, the following
query in backticks would be suitable:
```
PREFIX foaf: <http://xmlns.com/foaf/0.1/>
SELECT ?email
WHERE {{
?person foaf:name "John Doe" .
?person foaf:mbox ?email .
}}
```
Instructions:
Use only the node types and properties provided in the schema.
Do not use any node types and properties that are not explicitly provided.
Include all necessary prefixes.
Examples:
Schema:
{schema}
Note: Be as concise as possible.
Do not include any explanations or apologies in your responses.
Do not respond to any questions that ask for anything else than
for you to construct a SPARQL query.
Do not include any text except the SPARQL query generated.
The question is:
{prompt}"""
SPARQL_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE
)
def extract_sparql(query: str) -> str:
query = query.strip()
querytoks = query.split("```")
if len(querytoks) == 3:
query = querytoks[1]
if query.startswith("sparql"):
query = query[6:]
elif query.startswith("<sparql>") and query.endswith("</sparql>"):
query = query[8:-9]
return query
class NeptuneSparqlQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating SPARQL statements.
*Security note*: Make sure that the database connection uses credentials
that are narrowly-scoped to only include necessary permissions.
Failure to do so may result in data corruption or loss, since the calling
code may attempt commands that would result in deletion, mutation
of data if appropriately prompted or reading sensitive data if such
data is present in the database.
The best way to guard against such negative outcomes is to (as appropriate)
limit the permissions granted to the credentials used with this tool.
See https://python.langchain.com/docs/security for more information.
Example:
.. code-block:: python
chain = NeptuneSparqlQAChain.from_llm(
llm=llm,
graph=graph
)
response = chain.invoke(query)
"""
graph: NeptuneRdfGraph = Field(exclude=True)
sparql_generation_chain: LLMChain
qa_chain: LLMChain
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
top_k: int = 10
return_intermediate_steps: bool = False
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
return [self.input_key]
@property
def output_keys(self) -> List[str]:
_output_keys = [self.output_key]
return _output_keys
@classmethod
def from_llm(
cls,
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT,
sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT,
examples: Optional[str] = None,
**kwargs: Any,
) -> NeptuneSparqlQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
template_to_use = SPARQL_GENERATION_TEMPLATE
if examples:
template_to_use = template_to_use.replace(
"Examples:", "Examples: " + examples
)
sparql_prompt = PromptTemplate(
input_variables=["schema", "prompt"], template=template_to_use
)
sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt)
return cls(
qa_chain=qa_chain,
sparql_generation_chain=sparql_generation_chain,
examples=examples,
**kwargs,
)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, str]:
"""
Generate a SPARQL query, use it to retrieve a response from the graph
database, and answer the question.
"""
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
prompt = inputs[self.input_key]
intermediate_steps: List = []
generated_sparql = self.sparql_generation_chain.run(
{"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks
)
# Extract SPARQL
generated_sparql = extract_sparql(generated_sparql)
_run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose)
_run_manager.on_text(
generated_sparql, color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"query": generated_sparql})
context = self.graph.query(generated_sparql)
if self.return_direct:
final_result = context
else:
_run_manager.on_text("Full Context:", end="\n", verbose=self.verbose)
_run_manager.on_text(
str(context), color="green", end="\n", verbose=self.verbose
)
intermediate_steps.append({"context": context})
result = self.qa_chain(
{"prompt": prompt, "context": context},
callbacks=callbacks,
)
final_result = result[self.qa_chain.output_key]
chain_result: Dict[str, Any] = {self.output_key: final_result}
if self.return_intermediate_steps:
chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps
return chain_result

@ -33,6 +33,7 @@ EXPECTED_ALL = [
"NatBotChain",
"NebulaGraphQAChain",
"NeptuneOpenCypherQAChain",
"NeptuneSparqlQAChain",
"OpenAIModerationChain",
"OpenAPIEndpointChain",
"QAGenerationChain",
