From 94cf71ecfa544f650e6197b061b7b5a3d49def87 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 19 Sep 2023 16:03:08 -0700 Subject: [PATCH] Updated Neptune graph to use boto (#10121) ## Description This PR updates the `NeptuneGraph` class to start using the boto API for connecting to the Neptune service. With boto integration, the graph class now supports authenticating requests using Sigv4; this is encapsulated with the boto API, and users only have to ensure they have the correct AWS credentials setup in their workspace to work with the graph class. This PR also introduces a conditional prompt that uses a simpler prompt when using the `Anthropic` model provider. A simpler prompt have seemed to work better for generating cypher queries in our testing. **Note**: This version will require boto3 version 1.28.38 or greater to work. --- .../chains/graph_qa/neptune_cypher.py | 58 ++++++++- .../langchain/chains/graph_qa/prompts.py | 12 ++ .../langchain/graphs/neptune_graph.py | 116 ++++++++++++------ 3 files changed, 146 insertions(+), 40 deletions(-) diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index 8809a665fd..013c9622de 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -9,8 +9,10 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ( CYPHER_QA_PROMPT, NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT, ) from langchain.chains.llm import LLMChain +from langchain.chains.prompt_selector import ConditionalPromptSelector from langchain.graphs import NeptuneGraph from langchain.prompts.base import BasePromptTemplate from langchain.pydantic_v1 import Field @@ -18,6 +20,37 @@ from langchain.pydantic_v1 import Field INTERMEDIATE_STEPS_KEY = "intermediate_steps" +def trim_query(query: str) -> str: + keywords = ( + "CALL", + "CREATE", + "DELETE", + "DETACH", + "LIMIT", + "MATCH", + "MERGE", + "OPTIONAL", + "ORDER", + "REMOVE", + "RETURN", + "SET", + "SKIP", + "UNWIND", + "WITH", + "WHERE", + "//", + ) + + lines = query.split("\n") + new_query = "" + + for line in lines: + if line.strip().upper().startswith(keywords): + new_query += line + "\n" + + return new_query + + def extract_cypher(text: str) -> str: """Extract Cypher code from text using Regex.""" # The pattern to find Cypher code enclosed in triple backticks @@ -29,6 +62,24 @@ def extract_cypher(text: str) -> str: return matches[0] if matches else text +def use_simple_prompt(llm: BaseLanguageModel) -> bool: + """Decides whether to use the simple prompt""" + if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore + return True + + # Bedrock anthropic + if llm.model_id and "anthropic" in llm.model_id: # type: ignore + return True + + return False + + +PROMPT_SELECTOR = ConditionalPromptSelector( + default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)], +) + + class NeptuneOpenCypherQAChain(Chain): """Chain for question-answering against a Neptune graph by generating openCypher statements. @@ -77,12 +128,14 @@ class NeptuneOpenCypherQAChain(Chain): llm: BaseLanguageModel, *, qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + cypher_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> NeptuneOpenCypherQAChain: """Initialize from LLM.""" qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + + _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) + cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) return cls( qa_chain=qa_chain, @@ -108,6 +161,7 @@ class NeptuneOpenCypherQAChain(Chain): # Extract Cypher code if it is wrapped in backticks generated_cypher = extract_cypher(generated_cypher) + generated_cypher = trim_query(generated_cypher) _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text( diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index 0392a99955..d6cc4da11c 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -331,3 +331,15 @@ NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( input_variables=["schema", "question"], template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, ) + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ +Write an openCypher query to answer the following question. Do not explain the answer. Only return the query. +Question: "{question}". +Here is the property graph schema: +{schema} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( + input_variables=["schema", "question"], + template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, +) diff --git a/libs/langchain/langchain/graphs/neptune_graph.py b/libs/langchain/langchain/graphs/neptune_graph.py index 299c1ebbcc..ac6e98eb25 100644 --- a/libs/langchain/langchain/graphs/neptune_graph.py +++ b/libs/langchain/langchain/graphs/neptune_graph.py @@ -1,7 +1,4 @@ -import json -from typing import Any, Dict, List, Tuple, Union - -import requests +from typing import Any, Dict, List, Optional, Tuple, Union class NeptuneQueryException(Exception): @@ -23,8 +20,16 @@ class NeptuneQueryException(Exception): class NeptuneGraph: - """Neptune wrapper for graph operations. This version - does not support Sigv4 signing of requests. + """Neptune wrapper for graph operations. + + Args: + host: endpoint for the database instance + port: port number for the database instance, default is 8182 + use_https: whether to use secure connection, default is True + client: optional boto3 Neptune client + credentials_profile_name: optional AWS profile name + region_name: optional AWS region, e.g., us-west-2 + service: optional service name, default is neptunedata Example: .. code-block:: python @@ -35,25 +40,67 @@ class NeptuneGraph: ) """ - def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None: + def __init__( + self, + host: str, + port: int = 8182, + use_https: bool = True, + client: Any = None, + credentials_profile_name: Optional[str] = None, + region_name: Optional[str] = None, + service: str = "neptunedata", + ) -> None: """Create a new Neptune graph wrapper instance.""" - if use_https: - self.summary_url = ( - f"https://{host}:{port}/pg/statistics/summary?mode=detailed" - ) - self.query_url = f"https://{host}:{port}/openCypher" - else: - self.summary_url = ( - f"http://{host}:{port}/pg/statistics/summary?mode=detailed" + try: + if client is not None: + self.client = client + else: + import boto3 + + if credentials_profile_name is not None: + session = boto3.Session(profile_name=credentials_profile_name) + else: + # use default credentials + session = boto3.Session() + + client_params = {} + if region_name: + client_params["region_name"] = region_name + + protocol = "https" if use_https else "http" + + client_params["endpoint_url"] = f"{protocol}://{host}:{port}" + + self.client = session.client(service, **client_params) + + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." ) - self.query_url = f"http://{host}:{port}/openCypher" + except Exception as e: + if type(e).__name__ == "UnknownServiceError": + raise ModuleNotFoundError( + "NeptuneGraph requires a boto3 version 1.28.38 or greater." + "Please install it with `pip install -U boto3`." + ) from e + else: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e - # Set schema try: self._refresh_schema() - except NeptuneQueryException: - raise ValueError("Could not get schema for Neptune database") + except Exception as e: + raise NeptuneQueryException( + { + "message": "Could not get schema for Neptune database", + "detail": str(e), + } + ) @property def get_schema(self) -> str: @@ -62,32 +109,24 @@ class NeptuneGraph: def query(self, query: str, params: dict = {}) -> Dict[str, Any]: """Query Neptune database.""" - response = requests.post(url=self.query_url, data={"query": query}) - if response.ok: - results = json.loads(response.content.decode()) - return results - else: - raise NeptuneQueryException( - { - "message": "The generated query failed to execute", - "details": response.content.decode(), - } - ) + return self.client.execute_open_cypher_query(openCypherQuery=query) def _get_summary(self) -> Dict: - response = requests.get(url=self.summary_url) - if not response.ok: + try: + response = self.client.get_propertygraph_summary() + except Exception as e: raise NeptuneQueryException( { "message": ( "Summary API is not available for this instance of Neptune," "ensure the engine version is >=1.2.1.0" ), - "details": response.content.decode(), + "details": str(e), } ) + try: - summary = response.json()["payload"]["graphSummary"] + summary = response["payload"]["graphSummary"] except Exception: raise NeptuneQueryException( { @@ -107,13 +146,13 @@ class NeptuneGraph: def _get_triples(self, e_labels: List[str]) -> List[str]: triple_query = """ - MATCH (a)-[e:{e_label}]->(b) + MATCH (a)-[e:`{e_label}`]->(b) WITH a,e,b LIMIT 3000 RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to LIMIT 10 """ - triple_template = "(:{a})-[:{e}]->(:{b})" + triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)" triple_schema = [] for label in e_labels: q = triple_query.format(e_label=label) @@ -128,7 +167,7 @@ class NeptuneGraph: def _get_node_properties(self, n_labels: List[str], types: Dict) -> List: node_properties_query = """ - MATCH (a:{n_label}) + MATCH (a:`{n_label}`) RETURN properties(a) AS props LIMIT 100 """ @@ -151,7 +190,7 @@ class NeptuneGraph: def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List: edge_properties_query = """ - MATCH ()-[e:{e_label}]->() + MATCH ()-[e:`{e_label}`]->() RETURN properties(e) AS props LIMIT 100 """ @@ -183,6 +222,7 @@ class NeptuneGraph: "int": "INTEGER", "list": "LIST", "dict": "MAP", + "bool": "BOOLEAN", } n_labels, e_labels = self._get_labels() triple_schema = self._get_triples(e_labels)