Updated Neptune graph to use boto (#10121)

## Description
This PR updates the `NeptuneGraph` class to start using the boto API for
connecting to the Neptune service. With boto integration, the graph
class now supports authenticating requests using Sigv4; this is
encapsulated with the boto API, and users only have to ensure they have
the correct AWS credentials setup in their workspace to work with the
graph class.

This PR also introduces a conditional prompt that uses a simpler prompt
when using the `Anthropic` model provider. A simpler prompt have seemed
to work better for generating cypher queries in our testing.

**Note**: This version will require boto3 version 1.28.38 or greater to
work.
pull/10824/head
Piyush Jain 11 months ago committed by GitHub
parent 33781ac4a2
commit 94cf71ecfa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,8 +9,10 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import (
CYPHER_QA_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
)
from langchain.chains.llm import LLMChain
from langchain.chains.prompt_selector import ConditionalPromptSelector
from langchain.graphs import NeptuneGraph
from langchain.prompts.base import BasePromptTemplate
from langchain.pydantic_v1 import Field
@ -18,6 +20,37 @@ from langchain.pydantic_v1 import Field
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
def trim_query(query: str) -> str:
keywords = (
"CALL",
"CREATE",
"DELETE",
"DETACH",
"LIMIT",
"MATCH",
"MERGE",
"OPTIONAL",
"ORDER",
"REMOVE",
"RETURN",
"SET",
"SKIP",
"UNWIND",
"WITH",
"WHERE",
"//",
)
lines = query.split("\n")
new_query = ""
for line in lines:
if line.strip().upper().startswith(keywords):
new_query += line + "\n"
return new_query
def extract_cypher(text: str) -> str:
"""Extract Cypher code from text using Regex."""
# The pattern to find Cypher code enclosed in triple backticks
@ -29,6 +62,24 @@ def extract_cypher(text: str) -> str:
return matches[0] if matches else text
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
"""Decides whether to use the simple prompt"""
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
return True
# Bedrock anthropic
if llm.model_id and "anthropic" in llm.model_id: # type: ignore
return True
return False
PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
)
class NeptuneOpenCypherQAChain(Chain):
"""Chain for question-answering against a Neptune graph
by generating openCypher statements.
@ -77,12 +128,14 @@ class NeptuneOpenCypherQAChain(Chain):
llm: BaseLanguageModel,
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
cypher_prompt: Optional[BasePromptTemplate] = None,
**kwargs: Any,
) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM."""
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
_cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)
return cls(
qa_chain=qa_chain,
@ -108,6 +161,7 @@ class NeptuneOpenCypherQAChain(Chain):
# Extract Cypher code if it is wrapped in backticks
generated_cypher = extract_cypher(generated_cypher)
generated_cypher = trim_query(generated_cypher)
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
_run_manager.on_text(

@ -331,3 +331,15 @@ NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)

@ -1,7 +1,4 @@
import json
from typing import Any, Dict, List, Tuple, Union
import requests
from typing import Any, Dict, List, Optional, Tuple, Union
class NeptuneQueryException(Exception):
@ -23,8 +20,16 @@ class NeptuneQueryException(Exception):
class NeptuneGraph:
"""Neptune wrapper for graph operations. This version
does not support Sigv4 signing of requests.
"""Neptune wrapper for graph operations.
Args:
host: endpoint for the database instance
port: port number for the database instance, default is 8182
use_https: whether to use secure connection, default is True
client: optional boto3 Neptune client
credentials_profile_name: optional AWS profile name
region_name: optional AWS region, e.g., us-west-2
service: optional service name, default is neptunedata
Example:
.. code-block:: python
@ -35,25 +40,67 @@ class NeptuneGraph:
)
"""
def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None:
def __init__(
self,
host: str,
port: int = 8182,
use_https: bool = True,
client: Any = None,
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,
service: str = "neptunedata",
) -> None:
"""Create a new Neptune graph wrapper instance."""
if use_https:
self.summary_url = (
f"https://{host}:{port}/pg/statistics/summary?mode=detailed"
)
self.query_url = f"https://{host}:{port}/openCypher"
else:
self.summary_url = (
f"http://{host}:{port}/pg/statistics/summary?mode=detailed"
try:
if client is not None:
self.client = client
else:
import boto3
if credentials_profile_name is not None:
session = boto3.Session(profile_name=credentials_profile_name)
else:
# use default credentials
session = boto3.Session()
client_params = {}
if region_name:
client_params["region_name"] = region_name
protocol = "https" if use_https else "http"
client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
self.client = session.client(service, **client_params)
except ImportError:
raise ModuleNotFoundError(
"Could not import boto3 python package. "
"Please install it with `pip install boto3`."
)
self.query_url = f"http://{host}:{port}/openCypher"
except Exception as e:
if type(e).__name__ == "UnknownServiceError":
raise ModuleNotFoundError(
"NeptuneGraph requires a boto3 version 1.28.38 or greater."
"Please install it with `pip install -U boto3`."
) from e
else:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
"Please check that credentials in the specified "
"profile name are valid."
) from e
# Set schema
try:
self._refresh_schema()
except NeptuneQueryException:
raise ValueError("Could not get schema for Neptune database")
except Exception as e:
raise NeptuneQueryException(
{
"message": "Could not get schema for Neptune database",
"detail": str(e),
}
)
@property
def get_schema(self) -> str:
@ -62,32 +109,24 @@ class NeptuneGraph:
def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
"""Query Neptune database."""
response = requests.post(url=self.query_url, data={"query": query})
if response.ok:
results = json.loads(response.content.decode())
return results
else:
raise NeptuneQueryException(
{
"message": "The generated query failed to execute",
"details": response.content.decode(),
}
)
return self.client.execute_open_cypher_query(openCypherQuery=query)
def _get_summary(self) -> Dict:
response = requests.get(url=self.summary_url)
if not response.ok:
try:
response = self.client.get_propertygraph_summary()
except Exception as e:
raise NeptuneQueryException(
{
"message": (
"Summary API is not available for this instance of Neptune,"
"ensure the engine version is >=1.2.1.0"
),
"details": response.content.decode(),
"details": str(e),
}
)
try:
summary = response.json()["payload"]["graphSummary"]
summary = response["payload"]["graphSummary"]
except Exception:
raise NeptuneQueryException(
{
@ -107,13 +146,13 @@ class NeptuneGraph:
def _get_triples(self, e_labels: List[str]) -> List[str]:
triple_query = """
MATCH (a)-[e:{e_label}]->(b)
MATCH (a)-[e:`{e_label}`]->(b)
WITH a,e,b LIMIT 3000
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
LIMIT 10
"""
triple_template = "(:{a})-[:{e}]->(:{b})"
triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
triple_schema = []
for label in e_labels:
q = triple_query.format(e_label=label)
@ -128,7 +167,7 @@ class NeptuneGraph:
def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
node_properties_query = """
MATCH (a:{n_label})
MATCH (a:`{n_label}`)
RETURN properties(a) AS props
LIMIT 100
"""
@ -151,7 +190,7 @@ class NeptuneGraph:
def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
edge_properties_query = """
MATCH ()-[e:{e_label}]->()
MATCH ()-[e:`{e_label}`]->()
RETURN properties(e) AS props
LIMIT 100
"""
@ -183,6 +222,7 @@ class NeptuneGraph:
"int": "INTEGER",
"list": "LIST",
"dict": "MAP",
"bool": "BOOLEAN",
}
n_labels, e_labels = self._get_labels()
triple_schema = self._get_triples(e_labels)

Loading…
Cancel
Save