mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Updated Neptune graph to use boto (#10121)
## Description This PR updates the `NeptuneGraph` class to start using the boto API for connecting to the Neptune service. With boto integration, the graph class now supports authenticating requests using Sigv4; this is encapsulated with the boto API, and users only have to ensure they have the correct AWS credentials setup in their workspace to work with the graph class. This PR also introduces a conditional prompt that uses a simpler prompt when using the `Anthropic` model provider. A simpler prompt have seemed to work better for generating cypher queries in our testing. **Note**: This version will require boto3 version 1.28.38 or greater to work.
This commit is contained in:
parent
33781ac4a2
commit
94cf71ecfa
@ -9,8 +9,10 @@ from langchain.chains.base import Chain
|
|||||||
from langchain.chains.graph_qa.prompts import (
|
from langchain.chains.graph_qa.prompts import (
|
||||||
CYPHER_QA_PROMPT,
|
CYPHER_QA_PROMPT,
|
||||||
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
||||||
|
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT,
|
||||||
)
|
)
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
|
from langchain.chains.prompt_selector import ConditionalPromptSelector
|
||||||
from langchain.graphs import NeptuneGraph
|
from langchain.graphs import NeptuneGraph
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
from langchain.pydantic_v1 import Field
|
from langchain.pydantic_v1 import Field
|
||||||
@ -18,6 +20,37 @@ from langchain.pydantic_v1 import Field
|
|||||||
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
INTERMEDIATE_STEPS_KEY = "intermediate_steps"
|
||||||
|
|
||||||
|
|
||||||
|
def trim_query(query: str) -> str:
|
||||||
|
keywords = (
|
||||||
|
"CALL",
|
||||||
|
"CREATE",
|
||||||
|
"DELETE",
|
||||||
|
"DETACH",
|
||||||
|
"LIMIT",
|
||||||
|
"MATCH",
|
||||||
|
"MERGE",
|
||||||
|
"OPTIONAL",
|
||||||
|
"ORDER",
|
||||||
|
"REMOVE",
|
||||||
|
"RETURN",
|
||||||
|
"SET",
|
||||||
|
"SKIP",
|
||||||
|
"UNWIND",
|
||||||
|
"WITH",
|
||||||
|
"WHERE",
|
||||||
|
"//",
|
||||||
|
)
|
||||||
|
|
||||||
|
lines = query.split("\n")
|
||||||
|
new_query = ""
|
||||||
|
|
||||||
|
for line in lines:
|
||||||
|
if line.strip().upper().startswith(keywords):
|
||||||
|
new_query += line + "\n"
|
||||||
|
|
||||||
|
return new_query
|
||||||
|
|
||||||
|
|
||||||
def extract_cypher(text: str) -> str:
|
def extract_cypher(text: str) -> str:
|
||||||
"""Extract Cypher code from text using Regex."""
|
"""Extract Cypher code from text using Regex."""
|
||||||
# The pattern to find Cypher code enclosed in triple backticks
|
# The pattern to find Cypher code enclosed in triple backticks
|
||||||
@ -29,6 +62,24 @@ def extract_cypher(text: str) -> str:
|
|||||||
return matches[0] if matches else text
|
return matches[0] if matches else text
|
||||||
|
|
||||||
|
|
||||||
|
def use_simple_prompt(llm: BaseLanguageModel) -> bool:
|
||||||
|
"""Decides whether to use the simple prompt"""
|
||||||
|
if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Bedrock anthropic
|
||||||
|
if llm.model_id and "anthropic" in llm.model_id: # type: ignore
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
PROMPT_SELECTOR = ConditionalPromptSelector(
|
||||||
|
default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
||||||
|
conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NeptuneOpenCypherQAChain(Chain):
|
class NeptuneOpenCypherQAChain(Chain):
|
||||||
"""Chain for question-answering against a Neptune graph
|
"""Chain for question-answering against a Neptune graph
|
||||||
by generating openCypher statements.
|
by generating openCypher statements.
|
||||||
@ -77,12 +128,14 @@ class NeptuneOpenCypherQAChain(Chain):
|
|||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
*,
|
*,
|
||||||
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
|
||||||
cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT,
|
cypher_prompt: Optional[BasePromptTemplate] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> NeptuneOpenCypherQAChain:
|
) -> NeptuneOpenCypherQAChain:
|
||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
qa_chain = LLMChain(llm=llm, prompt=qa_prompt)
|
||||||
cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt)
|
|
||||||
|
_cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm)
|
||||||
|
cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
qa_chain=qa_chain,
|
qa_chain=qa_chain,
|
||||||
@ -108,6 +161,7 @@ class NeptuneOpenCypherQAChain(Chain):
|
|||||||
|
|
||||||
# Extract Cypher code if it is wrapped in backticks
|
# Extract Cypher code if it is wrapped in backticks
|
||||||
generated_cypher = extract_cypher(generated_cypher)
|
generated_cypher = extract_cypher(generated_cypher)
|
||||||
|
generated_cypher = trim_query(generated_cypher)
|
||||||
|
|
||||||
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
|
_run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose)
|
||||||
_run_manager.on_text(
|
_run_manager.on_text(
|
||||||
|
@ -331,3 +331,15 @@ NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
|
|||||||
input_variables=["schema", "question"],
|
input_variables=["schema", "question"],
|
||||||
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
|
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
|
||||||
|
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.
|
||||||
|
Question: "{question}".
|
||||||
|
Here is the property graph schema:
|
||||||
|
{schema}
|
||||||
|
\n"""
|
||||||
|
|
||||||
|
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
|
||||||
|
input_variables=["schema", "question"],
|
||||||
|
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
|
||||||
|
)
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
import json
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
class NeptuneQueryException(Exception):
|
class NeptuneQueryException(Exception):
|
||||||
@ -23,8 +20,16 @@ class NeptuneQueryException(Exception):
|
|||||||
|
|
||||||
|
|
||||||
class NeptuneGraph:
|
class NeptuneGraph:
|
||||||
"""Neptune wrapper for graph operations. This version
|
"""Neptune wrapper for graph operations.
|
||||||
does not support Sigv4 signing of requests.
|
|
||||||
|
Args:
|
||||||
|
host: endpoint for the database instance
|
||||||
|
port: port number for the database instance, default is 8182
|
||||||
|
use_https: whether to use secure connection, default is True
|
||||||
|
client: optional boto3 Neptune client
|
||||||
|
credentials_profile_name: optional AWS profile name
|
||||||
|
region_name: optional AWS region, e.g., us-west-2
|
||||||
|
service: optional service name, default is neptunedata
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -35,25 +40,67 @@ class NeptuneGraph:
|
|||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
host: str,
|
||||||
|
port: int = 8182,
|
||||||
|
use_https: bool = True,
|
||||||
|
client: Any = None,
|
||||||
|
credentials_profile_name: Optional[str] = None,
|
||||||
|
region_name: Optional[str] = None,
|
||||||
|
service: str = "neptunedata",
|
||||||
|
) -> None:
|
||||||
"""Create a new Neptune graph wrapper instance."""
|
"""Create a new Neptune graph wrapper instance."""
|
||||||
|
|
||||||
if use_https:
|
try:
|
||||||
self.summary_url = (
|
if client is not None:
|
||||||
f"https://{host}:{port}/pg/statistics/summary?mode=detailed"
|
self.client = client
|
||||||
)
|
|
||||||
self.query_url = f"https://{host}:{port}/openCypher"
|
|
||||||
else:
|
else:
|
||||||
self.summary_url = (
|
import boto3
|
||||||
f"http://{host}:{port}/pg/statistics/summary?mode=detailed"
|
|
||||||
)
|
if credentials_profile_name is not None:
|
||||||
self.query_url = f"http://{host}:{port}/openCypher"
|
session = boto3.Session(profile_name=credentials_profile_name)
|
||||||
|
else:
|
||||||
|
# use default credentials
|
||||||
|
session = boto3.Session()
|
||||||
|
|
||||||
|
client_params = {}
|
||||||
|
if region_name:
|
||||||
|
client_params["region_name"] = region_name
|
||||||
|
|
||||||
|
protocol = "https" if use_https else "http"
|
||||||
|
|
||||||
|
client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
|
||||||
|
|
||||||
|
self.client = session.client(service, **client_params)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"Could not import boto3 python package. "
|
||||||
|
"Please install it with `pip install boto3`."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
if type(e).__name__ == "UnknownServiceError":
|
||||||
|
raise ModuleNotFoundError(
|
||||||
|
"NeptuneGraph requires a boto3 version 1.28.38 or greater."
|
||||||
|
"Please install it with `pip install -U boto3`."
|
||||||
|
) from e
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not load credentials to authenticate with AWS client. "
|
||||||
|
"Please check that credentials in the specified "
|
||||||
|
"profile name are valid."
|
||||||
|
) from e
|
||||||
|
|
||||||
# Set schema
|
|
||||||
try:
|
try:
|
||||||
self._refresh_schema()
|
self._refresh_schema()
|
||||||
except NeptuneQueryException:
|
except Exception as e:
|
||||||
raise ValueError("Could not get schema for Neptune database")
|
raise NeptuneQueryException(
|
||||||
|
{
|
||||||
|
"message": "Could not get schema for Neptune database",
|
||||||
|
"detail": str(e),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def get_schema(self) -> str:
|
def get_schema(self) -> str:
|
||||||
@ -62,32 +109,24 @@ class NeptuneGraph:
|
|||||||
|
|
||||||
def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
|
def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
|
||||||
"""Query Neptune database."""
|
"""Query Neptune database."""
|
||||||
response = requests.post(url=self.query_url, data={"query": query})
|
return self.client.execute_open_cypher_query(openCypherQuery=query)
|
||||||
if response.ok:
|
|
||||||
results = json.loads(response.content.decode())
|
|
||||||
return results
|
|
||||||
else:
|
|
||||||
raise NeptuneQueryException(
|
|
||||||
{
|
|
||||||
"message": "The generated query failed to execute",
|
|
||||||
"details": response.content.decode(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_summary(self) -> Dict:
|
def _get_summary(self) -> Dict:
|
||||||
response = requests.get(url=self.summary_url)
|
try:
|
||||||
if not response.ok:
|
response = self.client.get_propertygraph_summary()
|
||||||
|
except Exception as e:
|
||||||
raise NeptuneQueryException(
|
raise NeptuneQueryException(
|
||||||
{
|
{
|
||||||
"message": (
|
"message": (
|
||||||
"Summary API is not available for this instance of Neptune,"
|
"Summary API is not available for this instance of Neptune,"
|
||||||
"ensure the engine version is >=1.2.1.0"
|
"ensure the engine version is >=1.2.1.0"
|
||||||
),
|
),
|
||||||
"details": response.content.decode(),
|
"details": str(e),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
summary = response.json()["payload"]["graphSummary"]
|
summary = response["payload"]["graphSummary"]
|
||||||
except Exception:
|
except Exception:
|
||||||
raise NeptuneQueryException(
|
raise NeptuneQueryException(
|
||||||
{
|
{
|
||||||
@ -107,13 +146,13 @@ class NeptuneGraph:
|
|||||||
|
|
||||||
def _get_triples(self, e_labels: List[str]) -> List[str]:
|
def _get_triples(self, e_labels: List[str]) -> List[str]:
|
||||||
triple_query = """
|
triple_query = """
|
||||||
MATCH (a)-[e:{e_label}]->(b)
|
MATCH (a)-[e:`{e_label}`]->(b)
|
||||||
WITH a,e,b LIMIT 3000
|
WITH a,e,b LIMIT 3000
|
||||||
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
|
RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to
|
||||||
LIMIT 10
|
LIMIT 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
triple_template = "(:{a})-[:{e}]->(:{b})"
|
triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)"
|
||||||
triple_schema = []
|
triple_schema = []
|
||||||
for label in e_labels:
|
for label in e_labels:
|
||||||
q = triple_query.format(e_label=label)
|
q = triple_query.format(e_label=label)
|
||||||
@ -128,7 +167,7 @@ class NeptuneGraph:
|
|||||||
|
|
||||||
def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
|
def _get_node_properties(self, n_labels: List[str], types: Dict) -> List:
|
||||||
node_properties_query = """
|
node_properties_query = """
|
||||||
MATCH (a:{n_label})
|
MATCH (a:`{n_label}`)
|
||||||
RETURN properties(a) AS props
|
RETURN properties(a) AS props
|
||||||
LIMIT 100
|
LIMIT 100
|
||||||
"""
|
"""
|
||||||
@ -151,7 +190,7 @@ class NeptuneGraph:
|
|||||||
|
|
||||||
def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
|
def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List:
|
||||||
edge_properties_query = """
|
edge_properties_query = """
|
||||||
MATCH ()-[e:{e_label}]->()
|
MATCH ()-[e:`{e_label}`]->()
|
||||||
RETURN properties(e) AS props
|
RETURN properties(e) AS props
|
||||||
LIMIT 100
|
LIMIT 100
|
||||||
"""
|
"""
|
||||||
@ -183,6 +222,7 @@ class NeptuneGraph:
|
|||||||
"int": "INTEGER",
|
"int": "INTEGER",
|
||||||
"list": "LIST",
|
"list": "LIST",
|
||||||
"dict": "MAP",
|
"dict": "MAP",
|
||||||
|
"bool": "BOOLEAN",
|
||||||
}
|
}
|
||||||
n_labels, e_labels = self._get_labels()
|
n_labels, e_labels = self._get_labels()
|
||||||
triple_schema = self._get_triples(e_labels)
|
triple_schema = self._get_triples(e_labels)
|
||||||
|
Loading…
Reference in New Issue
Block a user