Updated Neptune graph to use boto (#10121)

## Description
This PR updates the `NeptuneGraph` class to start using the boto API for
connecting to the Neptune service. With boto integration, the graph
class now supports authenticating requests using Sigv4; this is
encapsulated with the boto API, and users only have to ensure they have
the correct AWS credentials setup in their workspace to work with the
graph class.

This PR also introduces a conditional prompt that uses a simpler prompt
when using the `Anthropic` model provider. A simpler prompt have seemed
to work better for generating cypher queries in our testing.

**Note**: This version will require boto3 version 1.28.38 or greater to
work.
This commit is contained in:
Piyush Jain 2023-09-19 16:03:08 -07:00 committed by GitHub
parent 33781ac4a2
commit 94cf71ecfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 147 additions and 41 deletions

View File

@ -9,8 +9,10 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ( from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT, CYPHER_QA_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_PROMPT, NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
) )
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain.graphs import NeptuneGraph from langchain.graphs import NeptuneGraph
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
@ -18,6 +20,37 @@ from langchain.pydantic_v1 import Field
INTERMEDIATE_STEPS_KEY = "intermediate_steps" INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def trim_query(query: str) -> str:
keywords = (
"CALL",
"CREATE",
"DELETE",
"DETACH",
"LIMIT",
"MATCH",
"MERGE",
"OPTIONAL",
"ORDER",
"REMOVE",
"RETURN",
"SET",
"SKIP",
"UNWIND",
"WITH",
"WHERE",
"//",
)
lines = query.split("\n")
new_query = ""
for line in lines:
if line.strip().upper().startswith(keywords):
new_query += line + "\n"
return new_query
def extract_cypher(text: str) -> str: def extract_cypher(text: str) -> str:
"""Extract Cypher code from text using Regex.""" """Extract Cypher code from text using Regex."""
# The pattern to find Cypher code enclosed in triple backticks # The pattern to find Cypher code enclosed in triple backticks
@ -29,6 +62,24 @@ def extract_cypher(text: str) -> str:
return matches[0] if matches else text return matches[0] if matches else text
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
return True
# Bedrock anthropic
if llm.model_id and "anthropic" in llm.model_id: # type: ignore
return True
return False
PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
class NeptuneOpenCypherQAChain(Chain): class NeptuneOpenCypherQAChain(Chain):
"""Chain for question-answering against a Neptune graph """Chain for question-answering against a Neptune graph
by generating openCypher statements. by generating openCypher statements.
@ -77,12 +128,14 @@ class NeptuneOpenCypherQAChain(Chain):
llm: BaseLanguageModel, llm: BaseLanguageModel,
*, *,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT, cypher_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any, **kwargs: Any,
) -> NeptuneOpenCypherQAChain: ) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM.""" """Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt) qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
_cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)
return cls( return cls(
qa_chain=qa_chain, qa_chain=qa_chain,
@ -108,6 +161,7 @@ class NeptuneOpenCypherQAChain(Chain):
# Extract Cypher code if it is wrapped in backticks # Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher) generated_cypher = extract_cypher(generated_cypher)
generated_cypher = trim_query(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text( _run_manager.on_text(

View File

@ -331,3 +331,15 @@ NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], input_variables=["schema", "question"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
) )
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)

View File

@ -1,7 +1,4 @@
import json from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Tuple, Union
import requests
class NeptuneQueryException(Exception): class NeptuneQueryException(Exception):
@ -23,8 +20,16 @@ class NeptuneQueryException(Exception):
class NeptuneGraph: class NeptuneGraph:
"""Neptune wrapper for graph operations. This version """Neptune wrapper for graph operations.
does not support Sigv4 signing of requests.
Args:
host: endpoint for the database instance
port: port number for the database instance, default is 8182
use_https: whether to use secure connection, default is True
client: optional boto3 Neptune client
credentials_profile_name: optional AWS profile name
region_name: optional AWS region, e.g., us-west-2
service: optional service name, default is neptunedata
Example: Example:
.. code-block:: python .. code-block:: python
@ -35,25 +40,67 @@ class NeptuneGraph:
) )
""" """
def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None: def __init__(
self,
host: str,
port: int = 8182,
use_https: bool = True,
client: Any = None,
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,
service: str = "neptunedata",
) -> None:
"""Create a new Neptune graph wrapper instance.""" """Create a new Neptune graph wrapper instance."""
if use_https: try:
self.summary_url = ( if client is not None:
f"https://{host}:{port}/pg/statistics/summary?mode=detailed" self.client = client
)
self.query_url = f"https://{host}:{port}/openCypher"
else: else:
self.summary_url = ( import boto3
f"http://{host}:{port}/pg/statistics/summary?mode=detailed"
) if credentials_profile_name is not None:
self.query_url = f"http://{host}:{port}/openCypher" session = boto3.Session(profile_name=credentials_profile_name)
else:
# use default credentials
session = boto3.Session()
client_params = {}
if region_name:
client_params["region_name"] = region_name
protocol = "https" if use_https else "http"
client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
self.client = session.client(service, **client_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
except Exception as e:
if type(e).__name__ == "UnknownServiceError":
raise ModuleNotFoundError(
"NeptuneGraph requires a boto3 version 1.28.38 or greater."
"Please install it with `pip install -U boto3`."
) from e
else:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
# Set schema
try: try:
self._refresh_schema() self._refresh_schema()
except NeptuneQueryException: except Exception as e:
raise ValueError("Could not get schema for Neptune database") raise NeptuneQueryException(
{
"message": "Could not get schema for Neptune database",
"detail": str(e),
}
)
@property @property
def get_schema(self) -> str: def get_schema(self) -> str:
@ -62,32 +109,24 @@ class NeptuneGraph:
def query(self, query: str, params: dict = {}) -> Dict[str, Any]: def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
"""Query Neptune database.""" """Query Neptune database."""
response = requests.post(url=self.query_url, data={"query": query}) return self.client.execute_open_cypher_query(openCypherQuery=query)
if response.ok:
results = json.loads(response.content.decode())
return results
else:
raise NeptuneQueryException(
{
"message": "The generated query failed to execute",
"details": response.content.decode(),
}
)
def _get_summary(self) -> Dict: def _get_summary(self) -> Dict:
response = requests.get(url=self.summary_url) try:
if not response.ok: response = self.client.get_propertygraph_summary()
except Exception as e:
raise NeptuneQueryException( raise NeptuneQueryException(
{ {
"message": ( "message": (
"Summary API is not available for this instance of Neptune," "Summary API is not available for this instance of Neptune,"
"ensure the engine version is >=1.2.1.0" "ensure the engine version is >=1.2.1.0"
), ),
"details": response.content.decode(), "details": str(e),
} }
) )
try: try:
summary = response.json()["payload"]["graphSummary"] summary = response["payload"]["graphSummary"]
except Exception: except Exception:
raise NeptuneQueryException( raise NeptuneQueryException(
{ {
@ -107,13 +146,13 @@ class NeptuneGraph:
def _get_triples(self, e_labels: List[str]) -> List[str]: def _get_triples(self, e_labels: List[str]) -> List[str]:
triple_query = """ triple_query = """
MATCH (a)-[e:{e_label}]->(b) MATCH (a)-[e:`{e_label}`]->(b)
WITH a,e,b LIMIT 3000 WITH a,e,b LIMIT 3000
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
LIMIT 10 LIMIT 10
""" """
triple_template = "(:{a})-[:{e}]->(:{b})" triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
triple_schema = [] triple_schema = []
for label in e_labels: for label in e_labels:
q = triple_query.format(e_label=label) q = triple_query.format(e_label=label)
@ -128,7 +167,7 @@ class NeptuneGraph:
def _get_node_properties(self, n_labels: List[str], types: Dict) -> List: def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
node_properties_query = """ node_properties_query = """
MATCH (a:{n_label}) MATCH (a:`{n_label}`)
RETURN properties(a) AS props RETURN properties(a) AS props
LIMIT 100 LIMIT 100
""" """
@ -151,7 +190,7 @@ class NeptuneGraph:
def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List: def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
edge_properties_query = """ edge_properties_query = """
MATCH ()-[e:{e_label}]->() MATCH ()-[e:`{e_label}`]->()
RETURN properties(e) AS props RETURN properties(e) AS props
LIMIT 100 LIMIT 100
""" """
@ -183,6 +222,7 @@ class NeptuneGraph:
"int": "INTEGER", "int": "INTEGER",
"list": "LIST", "list": "LIST",
"dict": "MAP", "dict": "MAP",
"bool": "BOOLEAN",
} }
n_labels, e_labels = self._get_labels() n_labels, e_labels = self._get_labels()
triple_schema = self._get_triples(e_labels) triple_schema = self._get_triples(e_labels)