IMPROVEMENT Neptune graph updates (#13491)

## Description
This PR adds an option to allow unsigned requests to the Neptune
database when using the `NeptuneGraph` class.

```python
graph = NeptuneGraph(
    host='<my-cluster>',
    port=8182,
    sign=False
)
```

Also, added is an option in the `NeptuneOpenCypherQAChain` to provide
additional domain instructions to the graph query generation prompt.
This will be injected in the prompt as-is, so you should include any
provider specific tags, for example `<instructions>` or `<INSTR>`.

```python
chain = NeptuneOpenCypherQAChain.from_llm(
    llm=llm,
    graph=graph,
    extra_instructions="""
    Follow these instructions to build the query:
    1. Countries contain airports, not the other way around
    2. Use the airport code for identifying airports
    """
)
```
This commit is contained in:
Piyush Jain 2023-11-17 13:49:31 -08:00 committed by GitHub
parent 5a28dc3210
commit d2335d0114
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 7 deletions

View File

@ -116,6 +116,8 @@ class NeptuneOpenCypherQAChain(Chain):
"""Whether or not to return the intermediate steps along with the final answer.""" """Whether or not to return the intermediate steps along with the final answer."""
return_direct: bool = False return_direct: bool = False
"""Whether or not to return the result of querying the graph directly.""" """Whether or not to return the result of querying the graph directly."""
extra_instructions: Optional[str] = None
"""Extra instructions by the appended to the query generation prompt."""
@property @property
def input_keys(self) -> List[str]: def input_keys(self) -> List[str]:
@ -141,6 +143,7 @@ class NeptuneOpenCypherQAChain(Chain):
*, *,
qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT,
cypher_prompt: Optional[BasePromptTemplate] = None, cypher_prompt: Optional[BasePromptTemplate] = None,
extra_instructions: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> NeptuneOpenCypherQAChain: ) -> NeptuneOpenCypherQAChain:
"""Initialize from LLM.""" """Initialize from LLM."""
@ -152,6 +155,7 @@ class NeptuneOpenCypherQAChain(Chain):
return cls( return cls(
qa_chain=qa_chain, qa_chain=qa_chain,
cypher_generation_chain=cypher_generation_chain, cypher_generation_chain=cypher_generation_chain,
extra_instructions=extra_instructions,
**kwargs, **kwargs,
) )
@ -168,7 +172,12 @@ class NeptuneOpenCypherQAChain(Chain):
intermediate_steps: List = [] intermediate_steps: List = []
generated_cypher = self.cypher_generation_chain.run( generated_cypher = self.cypher_generation_chain.run(
{"question": question, "schema": self.graph.get_schema}, callbacks=callbacks {
"question": question,
"schema": self.graph.get_schema,
"extra_instructions": self.extra_instructions or "",
},
callbacks=callbacks,
) )
# Extract Cypher code if it is wrapped in backticks # Extract Cypher code if it is wrapped in backticks

View File

@ -320,7 +320,7 @@ Instructions:
Generate the query in openCypher format and follow these rules: Generate the query in openCypher format and follow these rules:
Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions. Do not use `NONE`, `ALL` or `ANY` predicate functions, rather use list comprehensions.
Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results. Do not use `REDUCE` function. Rather use a combination of list comprehension and the `UNWIND` clause to achieve similar results.
Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results. Do not use `FOREACH` clause. Rather use a combination of `WITH` and `UNWIND` clauses to achieve similar results.{extra_instructions}
\n""" \n"""
NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace( NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
@ -328,18 +328,18 @@ NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE = CYPHER_GENERATION_TEMPLATE.replace(
) )
NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate(
input_variables=["schema", "question"], input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE,
) )
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """
Write an openCypher query to answer the following question. Do not explain the answer. Only return the query. Write an openCypher query to answer the following question. Do not explain the answer. Only return the query.{extra_instructions}
Question: "{question}". Question: "{question}".
Here is the property graph schema: Here is the property graph schema:
{schema} {schema}
\n""" \n"""
NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate(
input_variables=["schema", "question"], input_variables=["schema", "question", "extra_instructions"],
template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE,
) )

View File

@ -30,6 +30,7 @@ class NeptuneGraph:
credentials_profile_name: optional AWS profile name credentials_profile_name: optional AWS profile name
region_name: optional AWS region, e.g., us-west-2 region_name: optional AWS region, e.g., us-west-2
service: optional service name, default is neptunedata service: optional service name, default is neptunedata
sign: optional, whether to sign the request payload, default is True
Example: Example:
.. code-block:: python .. code-block:: python
@ -60,6 +61,7 @@ class NeptuneGraph:
credentials_profile_name: Optional[str] = None, credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None, region_name: Optional[str] = None,
service: str = "neptunedata", service: str = "neptunedata",
sign: bool = True,
) -> None: ) -> None:
"""Create a new Neptune graph wrapper instance.""" """Create a new Neptune graph wrapper instance."""
@ -83,7 +85,17 @@ class NeptuneGraph:
client_params["endpoint_url"] = f"{protocol}://{host}:{port}" client_params["endpoint_url"] = f"{protocol}://{host}:{port}"
self.client = session.client(service, **client_params) if sign:
self.client = session.client(service, **client_params)
else:
from botocore import UNSIGNED
from botocore.config import Config
self.client = session.client(
service,
**client_params,
config=Config(signature_version=UNSIGNED),
)
except ImportError: except ImportError:
raise ModuleNotFoundError( raise ModuleNotFoundError(
@ -120,7 +132,15 @@ class NeptuneGraph:
def query(self, query: str, params: dict = {}) -> Dict[str, Any]: def query(self, query: str, params: dict = {}) -> Dict[str, Any]:
"""Query Neptune database.""" """Query Neptune database."""
return self.client.execute_open_cypher_query(openCypherQuery=query) try:
return self.client.execute_open_cypher_query(openCypherQuery=query)
except Exception as e:
raise NeptuneQueryException(
{
"message": "An error occurred while executing the query.",
"details": str(e),
}
)
def _get_summary(self) -> Dict: def _get_summary(self) -> Dict:
try: try: