diff --git a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py index 8809a665fd..013c9622de 100644 --- a/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py +++ b/libs/langchain/langchain/chains/graph_qa/neptune_cypher.py @@ -9,8 +9,10 @@ from langchain.chains.base import Chain from langchain.chains.graph_qa.prompts import ( CYPHER_QA_PROMPT, NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT, ) from langchain.chains.llm import LLMChain +from langchain.chains.prompt_selector import ConditionalPromptSelector from langchain.graphs import NeptuneGraph from langchain.prompts.base import BasePromptTemplate from langchain.pydantic_v1 import Field @@ -18,6 +20,37 @@ from langchain.pydantic_v1 import Field INTERMEDIATE_STEPS_KEY = "intermediate_steps" +def trim_query(query: str) -> str: + keywords = ( + "CALL", + "CREATE", + "DELETE", + "DETACH", + "LIMIT", + "MATCH", + "MERGE", + "OPTIONAL", + "ORDER", + "REMOVE", + "RETURN", + "SET", + "SKIP", + "UNWIND", + "WITH", + "WHERE", + "//", + ) + + lines = query.split("\n") + new_query = "" + + for line in lines: + if line.strip().upper().startswith(keywords): + new_query += line + "\n" + + return new_query + + def extract_cypher(text: str) -> str: """Extract Cypher code from text using Regex.""" # The pattern to find Cypher code enclosed in triple backticks @@ -29,6 +62,24 @@ def extract_cypher(text: str) -> str: return matches[0] if matches else text +def use_simple_prompt(llm: BaseLanguageModel) -> bool: + """Decides whether to use the simple prompt""" + if llm._llm_type and "anthropic" in llm._llm_type: # type: ignore + return True + + # Bedrock anthropic + if llm.model_id and "anthropic" in llm.model_id: # type: ignore + return True + + return False + + +PROMPT_SELECTOR = ConditionalPromptSelector( + default_prompt=NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + conditionals=[(use_simple_prompt, NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT)], +) + + class NeptuneOpenCypherQAChain(Chain): """Chain for question-answering against a Neptune graph by generating openCypher statements. @@ -77,12 +128,14 @@ class NeptuneOpenCypherQAChain(Chain): llm: BaseLanguageModel, *, qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, - cypher_prompt: BasePromptTemplate = NEPTUNE_OPENCYPHER_GENERATION_PROMPT, + cypher_prompt: Optional[BasePromptTemplate] = None, **kwargs: Any, ) -> NeptuneOpenCypherQAChain: """Initialize from LLM.""" qa_chain = LLMChain(llm=llm, prompt=qa_prompt) - cypher_generation_chain = LLMChain(llm=llm, prompt=cypher_prompt) + + _cypher_prompt = cypher_prompt or PROMPT_SELECTOR.get_prompt(llm) + cypher_generation_chain = LLMChain(llm=llm, prompt=_cypher_prompt) return cls( qa_chain=qa_chain, @@ -108,6 +161,7 @@ class NeptuneOpenCypherQAChain(Chain): # Extract Cypher code if it is wrapped in backticks generated_cypher = extract_cypher(generated_cypher) + generated_cypher = trim_query(generated_cypher) _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text( diff --git a/libs/langchain/langchain/chains/graph_qa/prompts.py b/libs/langchain/langchain/chains/graph_qa/prompts.py index 0392a99955..d6cc4da11c 100644 --- a/libs/langchain/langchain/chains/graph_qa/prompts.py +++ b/libs/langchain/langchain/chains/graph_qa/prompts.py @@ -331,3 +331,15 @@ NEPTUNE_OPENCYPHER_GENERATION_PROMPT = PromptTemplate( input_variables=["schema", "question"], template=NEPTUNE_OPENCYPHER_GENERATION_TEMPLATE, ) + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE = """ +Write an openCypher query to answer the following question. Do not explain the answer. Only return the query. +Question: "{question}". +Here is the property graph schema: +{schema} +\n""" + +NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_PROMPT = PromptTemplate( + input_variables=["schema", "question"], + template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, +) diff --git a/libs/langchain/langchain/graphs/neptune_graph.py b/libs/langchain/langchain/graphs/neptune_graph.py index 299c1ebbcc..ac6e98eb25 100644 --- a/libs/langchain/langchain/graphs/neptune_graph.py +++ b/libs/langchain/langchain/graphs/neptune_graph.py @@ -1,7 +1,4 @@ -import json -from typing import Any, Dict, List, Tuple, Union - -import requests +from typing import Any, Dict, List, Optional, Tuple, Union class NeptuneQueryException(Exception): @@ -23,8 +20,16 @@ class NeptuneQueryException(Exception): class NeptuneGraph: - """Neptune wrapper for graph operations. This version - does not support Sigv4 signing of requests. + """Neptune wrapper for graph operations. + + Args: + host: endpoint for the database instance + port: port number for the database instance, default is 8182 + use_https: whether to use secure connection, default is True + client: optional boto3 Neptune client + credentials_profile_name: optional AWS profile name + region_name: optional AWS region, e.g., us-west-2 + service: optional service name, default is neptunedata Example: .. code-block:: python @@ -35,25 +40,67 @@ class NeptuneGraph: ) """ - def __init__(self, host: str, port: int = 8182, use_https: bool = True) -> None: + def __init__( + self, + host: str, + port: int = 8182, + use_https: bool = True, + client: Any = None, + credentials_profile_name: Optional[str] = None, + region_name: Optional[str] = None, + service: str = "neptunedata", + ) -> None: """Create a new Neptune graph wrapper instance.""" - if use_https: - self.summary_url = ( - f"https://{host}:{port}/pg/statistics/summary?mode=detailed" - ) - self.query_url = f"https://{host}:{port}/openCypher" - else: - self.summary_url = ( - f"http://{host}:{port}/pg/statistics/summary?mode=detailed" + try: + if client is not None: + self.client = client + else: + import boto3 + + if credentials_profile_name is not None: + session = boto3.Session(profile_name=credentials_profile_name) + else: + # use default credentials + session = boto3.Session() + + client_params = {} + if region_name: + client_params["region_name"] = region_name + + protocol = "https" if use_https else "http" + + client_params["endpoint_url"] = f"{protocol}://{host}:{port}" + + self.client = session.client(service, **client_params) + + except ImportError: + raise ModuleNotFoundError( + "Could not import boto3 python package. " + "Please install it with `pip install boto3`." ) - self.query_url = f"http://{host}:{port}/openCypher" + except Exception as e: + if type(e).__name__ == "UnknownServiceError": + raise ModuleNotFoundError( + "NeptuneGraph requires a boto3 version 1.28.38 or greater." + "Please install it with `pip install -U boto3`." + ) from e + else: + raise ValueError( + "Could not load credentials to authenticate with AWS client. " + "Please check that credentials in the specified " + "profile name are valid." + ) from e - # Set schema try: self._refresh_schema() - except NeptuneQueryException: - raise ValueError("Could not get schema for Neptune database") + except Exception as e: + raise NeptuneQueryException( + { + "message": "Could not get schema for Neptune database", + "detail": str(e), + } + ) @property def get_schema(self) -> str: @@ -62,32 +109,24 @@ class NeptuneGraph: def query(self, query: str, params: dict = {}) -> Dict[str, Any]: """Query Neptune database.""" - response = requests.post(url=self.query_url, data={"query": query}) - if response.ok: - results = json.loads(response.content.decode()) - return results - else: - raise NeptuneQueryException( - { - "message": "The generated query failed to execute", - "details": response.content.decode(), - } - ) + return self.client.execute_open_cypher_query(openCypherQuery=query) def _get_summary(self) -> Dict: - response = requests.get(url=self.summary_url) - if not response.ok: + try: + response = self.client.get_propertygraph_summary() + except Exception as e: raise NeptuneQueryException( { "message": ( "Summary API is not available for this instance of Neptune," "ensure the engine version is >=1.2.1.0" ), - "details": response.content.decode(), + "details": str(e), } ) + try: - summary = response.json()["payload"]["graphSummary"] + summary = response["payload"]["graphSummary"] except Exception: raise NeptuneQueryException( { @@ -107,13 +146,13 @@ class NeptuneGraph: def _get_triples(self, e_labels: List[str]) -> List[str]: triple_query = """ - MATCH (a)-[e:{e_label}]->(b) + MATCH (a)-[e:`{e_label}`]->(b) WITH a,e,b LIMIT 3000 RETURN DISTINCT labels(a) AS from, type(e) AS edge, labels(b) AS to LIMIT 10 """ - triple_template = "(:{a})-[:{e}]->(:{b})" + triple_template = "(:`{a}`)-[:`{e}`]->(:`{b}`)" triple_schema = [] for label in e_labels: q = triple_query.format(e_label=label) @@ -128,7 +167,7 @@ class NeptuneGraph: def _get_node_properties(self, n_labels: List[str], types: Dict) -> List: node_properties_query = """ - MATCH (a:{n_label}) + MATCH (a:`{n_label}`) RETURN properties(a) AS props LIMIT 100 """ @@ -151,7 +190,7 @@ class NeptuneGraph: def _get_edge_properties(self, e_labels: List[str], types: Dict[str, Any]) -> List: edge_properties_query = """ - MATCH ()-[e:{e_label}]->() + MATCH ()-[e:`{e_label}`]->() RETURN properties(e) AS props LIMIT 100 """ @@ -183,6 +222,7 @@ class NeptuneGraph: "int": "INTEGER", "list": "LIST", "dict": "MAP", + "bool": "BOOLEAN", } n_labels, e_labels = self._get_labels() triple_schema = self._get_triples(e_labels)