diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index 68966cd2..b18e5d36 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -1,6 +1,7 @@ """Question answering over a graph.""" from __future__ import annotations +import re from typing import Any, Dict, List, Optional from pydantic import Field @@ -14,6 +15,16 @@ from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.prompts.base import BasePromptTemplate +def extract_cypher(text: str) -> str: + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements.""" @@ -73,6 +84,9 @@ class GraphCypherQAChain(Chain): {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks ) + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text( generated_cypher, color="green", end="\n", verbose=self.verbose diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index b7008ab3..aefb2489 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -50,9 +50,10 @@ CYPHER_GENERATION_PROMPT = PromptTemplate( ) CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. -The information part contains the provided information that you can use to construct an answer. +The information part contains the provided information that you must use to construct an answer. The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it. -Make it sound like the information are coming from an AI assistant, but don't add any information. +Make the answer sound as a response to the question. Do not mention that you based the result on the given information. +If the provided information is empty, say that you don't know the answer. Information: {context} diff --git a/tests/unit_tests/chains/test_graph_qa.py b/tests/unit_tests/chains/test_graph_qa.py new file mode 100644 index 00000000..4577994b --- /dev/null +++ b/tests/unit_tests/chains/test_graph_qa.py @@ -0,0 +1,15 @@ +from langchain.chains.graph_qa.cypher import extract_cypher + + +def test_no_backticks() -> None: + """Test if there are no backticks, so the original text should be returned.""" + query = "MATCH (n) RETURN n" + output = extract_cypher(query) + assert output == query + + +def test_backticks() -> None: + """Test if there are backticks. Query from within backticks should be returned.""" + query = "You can use the following query: ```MATCH (n) RETURN n```" + output = extract_cypher(query) + assert output == "MATCH (n) RETURN n"