From 1bbb64d956210fe57d234be98af30436d518d4d9 Mon Sep 17 00:00:00 2001 From: mhavey Date: Tue, 13 Feb 2024 00:30:20 -0500 Subject: [PATCH] community[minor], langchian[minor]: Add Neptune Rdf graph and chain (#16650) **Description**: This PR adds a chain for Amazon Neptune graph database RDF format. It complements the existing Neptune Cypher chain. The PR also includes a Neptune RDF graph class to connect to, introspect, and query a Neptune RDF graph database from the chain. A sample notebook is provided under docs that demonstrates the overall effect: invoking the chain to make natural language queries against Neptune using an LLM. **Issue**: This is a new feature **Dependencies**: The RDF graph class depends on the AWS boto3 library if using IAM authentication to connect to the Neptune database. --------- Co-authored-by: Piyush Jain Co-authored-by: Bagatur --- .../use_cases/graph/neptune_sparql_qa.ipynb | 337 ++++++++++++++++++ .../langchain_community/graphs/__init__.py | 2 + .../graphs/neptune_rdf_graph.py | 256 +++++++++++++ .../tests/unit_tests/graphs/test_imports.py | 1 + .../unit_tests/graphs/test_neptune_graph.py | 5 +- libs/langchain/langchain/chains/__init__.py | 2 + .../chains/graph_qa/neptune_sparql.py | 196 ++++++++++ .../tests/unit_tests/chains/test_imports.py | 1 + 8 files changed, 799 insertions(+), 1 deletion(-) create mode 100644 docs/docs/use_cases/graph/neptune_sparql_qa.ipynb create mode 100644 libs/community/langchain_community/graphs/neptune_rdf_graph.py create mode 100644 libs/langchain/langchain/chains/graph_qa/neptune_sparql.py diff --git a/docs/docs/use_cases/graph/neptune_sparql_qa.ipynb b/docs/docs/use_cases/graph/neptune_sparql_qa.ipynb new file mode 100644 index 0000000000..4c464e840d --- /dev/null +++ b/docs/docs/use_cases/graph/neptune_sparql_qa.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Neptune SPARQL QA Chain\n", + "\n", + "This notebook shows use of LLM to query RDF graph in Amazon Neptune. This code uses a `NeptuneRdfGraph` class that connects with the Neptune database and loads it's schema. The `NeptuneSparqlQAChain` is used to connect the graph and LLM to ask natural language questions.\n", + "\n", + "Requirements for running this notebook:\n", + "- Neptune 1.2.x cluster accessible from this notebook\n", + "- Kernel with Python 3.9 or higher\n", + "- For Bedrock access, ensure IAM role has this policy\n", + "\n", + "```json\n", + "{\n", + " \"Action\": [\n", + " \"bedrock:ListFoundationModels\",\n", + " \"bedrock:InvokeModel\"\n", + " ],\n", + " \"Resource\": \"*\",\n", + " \"Effect\": \"Allow\"\n", + "}\n", + "```\n", + "\n", + "- S3 bucket for staging sample data, bucket should be in same account/region as Neptune." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Seed W3C organizational data\n", + "W3C org ontology plus some instances. \n", + "\n", + "You will need an S3 bucket in the same region and account. Set STAGE_BUCKET to name of that bucket." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "STAGE_BUCKET = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash -s \"$STAGE_BUCKET\"\n", + "\n", + "rm -rf data\n", + "mkdir -p data\n", + "cd data\n", + "echo getting org ontology and sample org instances\n", + "wget http://www.w3.org/ns/org.ttl \n", + "wget https://raw.githubusercontent.com/aws-samples/amazon-neptune-ontology-example-blog/main/data/example_org.ttl \n", + "\n", + "echo Copying org ttl to S3\n", + "aws s3 cp org.ttl s3://$1/org.ttl\n", + "aws s3 cp example_org.ttl s3://$1/example_org.ttl\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Bulk-load the org ttl - both ontology and instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load -s s3://{STAGE_BUCKET} -f turtle --store-to loadres --run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_status {loadres['payload']['loadId']} --errors --details" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Chain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "EXAMPLES = \"\"\"\n", + "\n", + "\n", + "Find organizations.\n", + "\n", + "\n", + "\n", + "PREFIX rdf: \n", + "PREFIX rdfs: \n", + "PREFIX org: \n", + "\n", + "select ?org ?orgName where {{\n", + " ?org rdfs:label ?orgName .\n", + "}} \n", + "\n", + "\n", + "\n", + "Find sites of an organization\n", + "\n", + "\n", + "\n", + "PREFIX rdf: \n", + "PREFIX rdfs: \n", + "PREFIX org: \n", + "\n", + "select ?org ?orgName ?siteName where {{\n", + " ?org rdfs:label ?orgName .\n", + " ?org org:hasSite/rdfs:label ?siteName . \n", + "}} \n", + "\n", + "\n", + "\n", + "Find suborganizations of an organization\n", + "\n", + "\n", + "\n", + "PREFIX rdf: \n", + "PREFIX rdfs: \n", + "PREFIX org: \n", + "\n", + "select ?org ?orgName ?subName where {{\n", + " ?org rdfs:label ?orgName .\n", + " ?org org:hasSubOrganization/rdfs:label ?subName .\n", + "}} \n", + "\n", + "\n", + "\n", + "Find organizational units of an organization\n", + "\n", + "\n", + "\n", + "PREFIX rdf: \n", + "PREFIX rdfs: \n", + "PREFIX org: \n", + "\n", + "select ?org ?orgName ?unitName where {{\n", + " ?org rdfs:label ?orgName .\n", + " ?org org:hasUnit/rdfs:label ?unitName . \n", + "}} \n", + "\n", + "\n", + "\n", + "Find members of an organization. Also find their manager, or the member they report to.\n", + "\n", + "\n", + "\n", + "PREFIX org: \n", + "PREFIX foaf: \n", + "\n", + "select * where {{\n", + " ?person rdf:type foaf:Person .\n", + " ?person org:memberOf ?org .\n", + " OPTIONAL {{ ?person foaf:firstName ?firstName . }}\n", + " OPTIONAL {{ ?person foaf:family_name ?lastName . }}\n", + " OPTIONAL {{ ?person org:reportsTo ??manager }} .\n", + "}}\n", + "\n", + "\n", + "\n", + "\n", + "Find change events, such as mergers and acquisitions, of an organization\n", + "\n", + "\n", + "\n", + "PREFIX org: \n", + "\n", + "select ?event ?prop ?obj where {{\n", + " ?org rdfs:label ?orgName .\n", + " ?event rdf:type org:ChangeEvent .\n", + " ?event org:originalOrganization ?origOrg .\n", + " ?event org:resultingOrganization ?resultingOrg .\n", + "}}\n", + "\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3\n", + "from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain\n", + "from langchain.chat_models import BedrockChat\n", + "from langchain_community.graphs import NeptuneRdfGraph\n", + "\n", + "host = \"\"\n", + "port = \"\"\n", + "region = \"us-east-1\" # specify region\n", + "\n", + "graph = NeptuneRdfGraph(\n", + " host=host, port=port, use_iam_auth=True, region_name=region, hide_comments=True\n", + ")\n", + "\n", + "schema_elements = graph.get_schema_elements\n", + "# Optionally, you can update the schema_elements, and\n", + "# load the schema from the pruned elements.\n", + "graph.load_from_schema_elements(schema_elements)\n", + "\n", + "bedrock_client = boto3.client(\"bedrock-runtime\")\n", + "llm = BedrockChat(model_id=\"anthropic.claude-v2\", client=bedrock_client)\n", + "\n", + "chain = NeptuneSparqlQAChain.from_llm(\n", + " llm=llm,\n", + " graph=graph,\n", + " examples=EXAMPLES,\n", + " verbose=True,\n", + " top_K=10,\n", + " return_intermediate_steps=True,\n", + " return_direct=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ask questions\n", + "Depends on the data we ingested above" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"How many organizations are in the graph\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"Are there any mergers or acquisitions\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"Find organizations\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"Find sites of MegaSystems or MegaFinancial\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"Find a member who is manager of one or more members.\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\"\"\"Find five members and who their manager is.\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chain.invoke(\n", + " \"\"\"Find org units or suborganizations of The Mega Group. What are the sites of those units?\"\"\"\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/graphs/__init__.py b/libs/community/langchain_community/graphs/__init__.py index bd15f6465d..c1fc640c4b 100644 --- a/libs/community/langchain_community/graphs/__init__.py +++ b/libs/community/langchain_community/graphs/__init__.py @@ -8,6 +8,7 @@ from langchain_community.graphs.memgraph_graph import MemgraphGraph from langchain_community.graphs.nebula_graph import NebulaGraph from langchain_community.graphs.neo4j_graph import Neo4jGraph from langchain_community.graphs.neptune_graph import NeptuneGraph +from langchain_community.graphs.neptune_rdf_graph import NeptuneRdfGraph from langchain_community.graphs.networkx_graph import NetworkxEntityGraph from langchain_community.graphs.ontotext_graphdb_graph import OntotextGraphDBGraph from langchain_community.graphs.rdf_graph import RdfGraph @@ -19,6 +20,7 @@ __all__ = [ "Neo4jGraph", "NebulaGraph", "NeptuneGraph", + "NeptuneRdfGraph", "KuzuGraph", "HugeGraph", "RdfGraph", diff --git a/libs/community/langchain_community/graphs/neptune_rdf_graph.py b/libs/community/langchain_community/graphs/neptune_rdf_graph.py new file mode 100644 index 0000000000..b9e0074a36 --- /dev/null +++ b/libs/community/langchain_community/graphs/neptune_rdf_graph.py @@ -0,0 +1,256 @@ +import json +from types import SimpleNamespace +from typing import Any, Dict, Optional, Sequence + +import requests + +CLASS_QUERY = """ +SELECT DISTINCT ?elem ?com +WHERE { + ?instance a ?elem . + OPTIONAL { ?instance rdf:type/rdfs:subClassOf* ?elem } . + #FILTER (isIRI(?elem)) . + OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")} +} +""" + +REL_QUERY = """ +SELECT DISTINCT ?elem ?com +WHERE { + ?subj ?elem ?obj . + OPTIONAL { + ?elem rdf:type/rdfs:subPropertyOf* ?proptype . + VALUES ?proptype { rdf:Property owl:DatatypeProperty owl:ObjectProperty } . + } . + OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")} +} +""" + +DTPROP_QUERY = """ +SELECT DISTINCT ?elem ?com +WHERE { + ?subj ?elem ?obj . + OPTIONAL { + ?elem rdf:type/rdfs:subPropertyOf* ?proptype . + ?proptype a owl:DatatypeProperty . + } . + OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")} +} +""" + +OPROP_QUERY = """ +SELECT DISTINCT ?elem ?com +WHERE { + ?subj ?elem ?obj . + OPTIONAL { + ?elem rdf:type/rdfs:subPropertyOf* ?proptype . + ?proptype a owl:ObjectProperty . + } . + OPTIONAL { ?elem rdfs:comment ?com filter (lang(?com) = "en")} +} +""" + +ELEM_TYPES = { + "classes": CLASS_QUERY, + "rels": REL_QUERY, + "dtprops": DTPROP_QUERY, + "oprops": OPROP_QUERY, +} + + +class NeptuneRdfGraph: + """Neptune wrapper for RDF graph operations. + + Args: + host: SPARQL endpoint host for Neptune + port: SPARQL endpoint port for Neptune. Defaults 8182. + use_iam_auth: boolean indicating IAM auth is enabled in Neptune cluster + region_name: AWS region required if use_iam_auth is True, e.g., us-west-2 + hide_comments: whether to include ontology comments in schema for prompt + + Example: + .. code-block:: python + + graph = NeptuneRdfGraph( + host=', + port=, + use_iam_auth=False + ) + schema = graph.get_schema() + + OR + graph = NeptuneRdfGraph( + host=', + port=, + use_iam_auth=False + ) + schema_elem = graph.get_schema_elements() + ... change schema_elements ... + graph.load_schema(schema_elem) + schema = graph.get_schema() + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + def __init__( + self, + host: str, + port: int = 8182, + use_iam_auth: bool = False, + region_name: Optional[str] = None, + hide_comments: bool = False, + ) -> None: + self.use_iam_auth = use_iam_auth + self.region_name = region_name + self.hide_comments = hide_comments + self.query_endpoint = f"https://{host}:{port}/sparql" + + if self.use_iam_auth: + try: + import boto3 + + self.session = boto3.Session() + except ImportError: + raise ImportError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." + ) + else: + self.session = None + + # Set schema + self.schema = "" + self.schema_elements: Dict[str, Any] = {} + self._refresh_schema() + + @property + def get_schema(self) -> str: + """ + Returns the schema of the graph database. + """ + return self.schema + + @property + def get_schema_elements(self) -> Dict[str, Any]: + return self.schema_elements + + def query( + self, + query: str, + ) -> Dict[str, Any]: + """ + Run Neptune query. + """ + request_data = {"query": query} + data = request_data + request_hdr = None + + if self.use_iam_auth: + credentials = self.session.get_credentials() + credentials = credentials.get_frozen_credentials() + access_key = credentials.access_key + secret_key = credentials.secret_key + service = "neptune-db" + session_token = credentials.token + params = None + creds = SimpleNamespace( + access_key=access_key, + secret_key=secret_key, + token=session_token, + region=self.region_name, + ) + from botocore.awsrequest import AWSRequest + + request = AWSRequest( + method="POST", url=self.query_endpoint, data=data, params=params + ) + from botocore.auth import SigV4Auth + + SigV4Auth(creds, service, self.region_name).add_auth(request) + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + request_hdr = request.headers + else: + request_hdr = {} + request_hdr["Content-Type"] = "application/x-www-form-urlencoded" + + queryres = requests.request( + method="POST", url=self.query_endpoint, headers=request_hdr, data=data + ) + json_resp = json.loads(queryres.text) + return json_resp + + def load_schema(self, schema_elements: Dict[str, Any]) -> None: + """ + Generates and sets schema from schema_elements. Helpful in + cases where introspected schema needs pruning. + """ + + elem_str = {} + for elem in ELEM_TYPES: + res_list = [] + for elem_rec in self.schema_elements[elem]: + uri = elem_rec["uri"] + local = elem_rec["local"] + res_str = f"<{uri}> ({local})" + if self.hide_comments is False: + res_str = res_str + f", {elem_rec['comment']}" + res_list.append(res_str) + elem_str[elem] = ", ".join(res_list) + + self.schema = ( + "In the following, each IRI is followed by the local name and " + "optionally its description in parentheses. \n" + "The graph supports the following node types:\n" + f"{elem_str['classes']}" + "The graph supports the following relationships:\n" + f"{elem_str['rels']}" + "The graph supports the following OWL object properties, " + f"{elem_str['dtprops']}" + "The graph supports the following OWL data properties, " + f"{elem_str['oprops']}" + ) + + def _get_local_name(self, iri: str) -> Sequence[str]: + """ + Split IRI into prefix and local + """ + if "#" in iri: + tokens = iri.split("#") + return [f"{tokens[0]}#", tokens[-1]] + elif "/" in iri: + tokens = iri.split("/") + return [f"{'/'.join(tokens[0:len(tokens)-1])}/", tokens[-1]] + else: + raise ValueError(f"Unexpected IRI '{iri}', contains neither '#' nor '/'.") + + def _refresh_schema(self) -> None: + """ + Query Neptune to introspect schema. + """ + self.schema_elements["distinct_prefixes"] = {} + + for elem in ELEM_TYPES: + items = self.query(ELEM_TYPES[elem]) + reslist = [] + for r in items["results"]["bindings"]: + uri = r["elem"]["value"] + tokens = self._get_local_name(uri) + elem_record = {"uri": uri, "local": tokens[1]} + if not self.hide_comments: + elem_record["comment"] = r["com"]["value"] if "com" in r else "" + reslist.append(elem_record) + if tokens[0] not in self.schema_elements["distinct_prefixes"]: + self.schema_elements["distinct_prefixes"][tokens[0]] = "y" + + self.schema_elements[elem] = reslist + + self.load_schema(self.schema_elements) diff --git a/libs/community/tests/unit_tests/graphs/test_imports.py b/libs/community/tests/unit_tests/graphs/test_imports.py index 653d7d540b..202ecefa24 100644 --- a/libs/community/tests/unit_tests/graphs/test_imports.py +++ b/libs/community/tests/unit_tests/graphs/test_imports.py @@ -6,6 +6,7 @@ EXPECTED_ALL = [ "Neo4jGraph", "NebulaGraph", "NeptuneGraph", + "NeptuneRdfGraph", "KuzuGraph", "HugeGraph", "RdfGraph", diff --git a/libs/community/tests/unit_tests/graphs/test_neptune_graph.py b/libs/community/tests/unit_tests/graphs/test_neptune_graph.py index e3d986f2eb..6e714a4166 100644 --- a/libs/community/tests/unit_tests/graphs/test_neptune_graph.py +++ b/libs/community/tests/unit_tests/graphs/test_neptune_graph.py @@ -1,2 +1,5 @@ def test_import() -> None: - from langchain_community.graphs import NeptuneGraph # noqa: F401 + from langchain_community.graphs import ( + NeptuneGraph, # noqa: F401 + NeptuneRdfGraph, # noqa: F401 + ) diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index 2b7ba6ac25..b20d3fecb4 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -41,6 +41,7 @@ from langchain.chains.graph_qa.hugegraph import HugeGraphQAChain from langchain.chains.graph_qa.kuzu import KuzuQAChain from langchain.chains.graph_qa.nebulagraph import NebulaGraphQAChain from langchain.chains.graph_qa.neptune_cypher import NeptuneOpenCypherQAChain +from langchain.chains.graph_qa.neptune_sparql import NeptuneSparqlQAChain from langchain.chains.graph_qa.ontotext_graphdb import OntotextGraphDBQAChain from langchain.chains.graph_qa.sparql import GraphSparqlQAChain from langchain.chains.history_aware_retriever import create_history_aware_retriever @@ -116,6 +117,7 @@ __all__ = [ "NatBotChain", "NebulaGraphQAChain", "NeptuneOpenCypherQAChain", + "NeptuneSparqlQAChain", "OpenAIModerationChain", "OpenAPIEndpointChain", "QAGenerationChain", diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py b/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py new file mode 100644 index 0000000000..08a1cc249b --- /dev/null +++ b/libs/langchain/langchain/chains/graph_qa/neptune_sparql.py @@ -0,0 +1,196 @@ +""" +Question answering over an RDF or OWL graph using SPARQL. +""" +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from langchain_community.graphs import NeptuneRdfGraph +from langchain_core.language_models import BaseLanguageModel +from langchain_core.prompts.base import BasePromptTemplate +from langchain_core.prompts.prompt import PromptTemplate +from langchain_core.pydantic_v1 import Field + +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.chains.base import Chain +from langchain.chains.graph_qa.prompts import SPARQL_QA_PROMPT +from langchain.chains.llm import LLMChain + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + +SPARQL_GENERATION_TEMPLATE = """ +Task: Generate a SPARQL SELECT statement for querying a graph database. +For instance, to find all email addresses of John Doe, the following +query in backticks would be suitable: +``` +PREFIX foaf: +SELECT ?email +WHERE {{ + ?person foaf:name "John Doe" . + ?person foaf:mbox ?email . +}} +``` +Instructions: +Use only the node types and properties provided in the schema. +Do not use any node types and properties that are not explicitly provided. +Include all necessary prefixes. + +Examples: + +Schema: +{schema} +Note: Be as concise as possible. +Do not include any explanations or apologies in your responses. +Do not respond to any questions that ask for anything else than +for you to construct a SPARQL query. +Do not include any text except the SPARQL query generated. + +The question is: +{prompt}""" + +SPARQL_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "prompt"], template=SPARQL_GENERATION_TEMPLATE +) + + +def extract_sparql(query: str) -> str: + query = query.strip() + querytoks = query.split("```") + if len(querytoks) == 3: + query = querytoks[1] + + if query.startswith("sparql"): + query = query[6:] + elif query.startswith("") and query.endswith(""): + query = query[8:-9] + return query + + +class NeptuneSparqlQAChain(Chain): + """Chain for question-answering against a Neptune graph + by generating SPARQL statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + + Example: + .. code-block:: python + + chain = NeptuneSparqlQAChain.from_llm( + llm=llm, + graph=graph + ) + response = chain.invoke(query) + """ + + graph: NeptuneRdfGraph = Field(exclude=True) + sparql_generation_chain: LLMChain + qa_chain: LLMChain + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Whether or not to return the result of querying the graph directly.""" + extra_instructions: Optional[str] = None + """Extra instructions by the appended to the query generation prompt.""" + + @property + def input_keys(self) -> List[str]: + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + _output_keys = [self.output_key] + return _output_keys + + @classmethod + def from_llm( + cls, + llm: BaseLanguageModel, + *, + qa_prompt: BasePromptTemplate = SPARQL_QA_PROMPT, + sparql_prompt: BasePromptTemplate = SPARQL_GENERATION_PROMPT, + examples: Optional[str] = None, + **kwargs: Any, + ) -> NeptuneSparqlQAChain: + """Initialize from LLM.""" + qa_chain = LLMChain(llm=llm, prompt=qa_prompt) + template_to_use = SPARQL_GENERATION_TEMPLATE + if examples: + template_to_use = template_to_use.replace( + "Examples:", "Examples: " + examples + ) + sparql_prompt = PromptTemplate( + input_variables=["schema", "prompt"], template=template_to_use + ) + sparql_generation_chain = LLMChain(llm=llm, prompt=sparql_prompt) + + return cls( + qa_chain=qa_chain, + sparql_generation_chain=sparql_generation_chain, + examples=examples, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, str]: + """ + Generate SPARQL query, use it to retrieve a response from the gdb and answer + the question. + """ + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + prompt = inputs[self.input_key] + + intermediate_steps: List = [] + + generated_sparql = self.sparql_generation_chain.run( + {"prompt": prompt, "schema": self.graph.get_schema}, callbacks=callbacks + ) + + # Extract SPARQL + generated_sparql = extract_sparql(generated_sparql) + + _run_manager.on_text("Generated SPARQL:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_sparql, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_sparql}) + + context = self.graph.query(generated_sparql) + + if self.return_direct: + final_result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + + result = self.qa_chain( + {"prompt": prompt, "context": context}, + callbacks=callbacks, + ) + final_result = result[self.qa_chain.output_key] + + chain_result: Dict[str, Any] = {self.output_key: final_result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/langchain/tests/unit_tests/chains/test_imports.py b/libs/langchain/tests/unit_tests/chains/test_imports.py index 8317dd62ea..cf76a851b8 100644 --- a/libs/langchain/tests/unit_tests/chains/test_imports.py +++ b/libs/langchain/tests/unit_tests/chains/test_imports.py @@ -33,6 +33,7 @@ EXPECTED_ALL = [ "NatBotChain", "NebulaGraphQAChain", "NeptuneOpenCypherQAChain", + "NeptuneSparqlQAChain", "OpenAIModerationChain", "OpenAPIEndpointChain", "QAGenerationChain",