IMPROVEMENT Neptune graph updates (#13491)

## Description
This PR adds an option to allow unsigned requests to the Neptune
database when using the `NeptuneGraph` class.

```python
graph = NeptuneGraph(
    host='<my-cluster>',
    port=8182,
    sign=False
)
```

Also, added is an option in the `NeptuneOpenCypherQAChain` to provide
additional domain instructions to the graph query generation prompt.
This will be injected in the prompt as-is, so you should include any
provider specific tags, for example `<instructions>` or `<INSTR>`.

```python
chain = NeptuneOpenCypherQAChain.from_llm(
    llm=llm,
    graph=graph,
    extra_instructions="""
    Follow these instructions to build the query:
    1. Countries contain airports, not the other way around
    2. Use the airport code for identifying airports
    """
)
```
pull/13543/head
Piyush Jain 10 months ago committed by GitHub
parent 5a28dc3210
commit d2335d0114
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -116,6 +116,8 @@ class NeptuneOpenCypherQAChain(Chain):
"""Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False
"""Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property
def input_keys(self) -> List[str]:
@ -141,6 +143,7 @@ class NeptuneOpenCypherQAChain(Chain):
*,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: Optional[BasePromptTemplate] = None,
extra_instructions: Optional[str] = None,
**kwargs: Any,
) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM."""
@ -152,6 +155,7 @@ class NeptuneOpenCypherQAChain(Chain):
return cls(
qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain,
extra_instructions=extra_instructions,
**kwargs,
)
@ -168,7 +172,12 @@ class NeptuneOpenCypherQAChain(Chain):
intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks
{
"question": question,
"schema": self.graph.get_schema,
"extra_instructions": self.extra_instructions or "",
},
callbacks=callbacks,
)
# Extract Cypher code if it is wrapped in backticks

@ -320,7 +320,7 @@ Instructions:
Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
@ -328,18 +328,18 @@ NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
)
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
)
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}".
Here is the property graph schema:
{schema}
\n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question"],
input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
)

@ -30,6 +30,7 @@ class NeptuneGraph:
credentials_profile_name: optional AWS profile name
region_name: optional AWS region, e.g., us-west-2
service: optional service name, default is neptunedata
sign: optional, whether to sign the request payload, default is True
Example:
.. code-block:: python
@ -60,6 +61,7 @@ class NeptuneGraph:
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,
service: str = "neptunedata",
sign: bool = True,
) -> None:
"""Create a new Neptune graph wrapper instance."""
@ -83,7 +85,17 @@ class NeptuneGraph:
client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
self.client = session.client(service, **client_params)
if sign:
self.client = session.client(service, **client_params)
else:
from botocore import UNSIGNED
from botocore.config import Config
self.client = session.client(
service,
**client_params,
config=Config(signature_version=UNSIGNED),
)
except ImportError:
raise ModuleNotFoundError(
@ -120,7 +132,15 @@ class NeptuneGraph:
def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
"""Query Neptune database."""
return self.client.execute_open_cypher_query(openCypherQuery=query)
try:
return self.client.execute_open_cypher_query(openCypherQuery=query)
except Exception as e:
raise NeptuneQueryException(
{
"message": "An error occurred while executing the query.",
"details": str(e),
}
)
def _get_summary(self) -> Dict:
try:

Loading…
Cancel
Save