From a0ea6f6b6bf0fbe8d06c855815ea1e3a6775a84d Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Mon, 5 Jun 2023 21:48:13 +0200 Subject: [PATCH] Cypher search: Check if generated Cypher is provided in backticks (#5541) # Check if generated Cypher code is wrapped in backticks Some LLMs like the VertexAI like to explain how they generated the Cypher statement and wrap the actual code in three backticks: ![Screenshot from 2023-06-01 08-08-23](https://github.com/hwchase17/langchain/assets/19948365/1d8eecb3-d26c-4882-8f5b-6a9bc7e93690) I have observed a similar pattern with OpenAI chat models in a conversational settings, where multiple user and assistant message are provided to the LLM to generate Cypher statements, where then the LLM wants to maybe apologize for previous steps or explain its thoughts. Interestingly, both OpenAI and VertexAI wrap the code in three backticks if they are doing any explaining or apologizing. Checking if the generated cypher is wrapped in backticks seems like a low-hanging fruit to expand the cypher search to other LLMs and conversational settings. --- langchain/chains/graph_qa/cypher.py | 14 ++++++++++++++ langchain/chains/graph_qa/prompts.py | 5 +++-- tests/unit_tests/chains/test_graph_qa.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/unit_tests/chains/test_graph_qa.py diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index 68966cd2..b18e5d36 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -1,6 +1,7 @@ """Question answering over a graph.""" from __future__ import annotations +import re from typing import Any, Dict, List, Optional from pydantic import Field @@ -14,6 +15,16 @@ from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.prompts.base import BasePromptTemplate +def extract_cypher(text: str) -> str: + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + class GraphCypherQAChain(Chain): """Chain for question-answering against a graph by generating Cypher statements.""" @@ -73,6 +84,9 @@ class GraphCypherQAChain(Chain): {"question": question, "schema": self.graph.get_schema}, callbacks=callbacks ) + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) _run_manager.on_text( generated_cypher, color="green", end="\n", verbose=self.verbose diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index b7008ab3..aefb2489 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -50,9 +50,10 @@ CYPHER_GENERATION_PROMPT = PromptTemplate( ) CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. -The information part contains the provided information that you can use to construct an answer. +The information part contains the provided information that you must use to construct an answer. The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it. -Make it sound like the information are coming from an AI assistant, but don't add any information. +Make the answer sound as a response to the question. Do not mention that you based the result on the given information. +If the provided information is empty, say that you don't know the answer. Information: {context} diff --git a/tests/unit_tests/chains/test_graph_qa.py b/tests/unit_tests/chains/test_graph_qa.py new file mode 100644 index 00000000..4577994b --- /dev/null +++ b/tests/unit_tests/chains/test_graph_qa.py @@ -0,0 +1,15 @@ +from langchain.chains.graph_qa.cypher import extract_cypher + + +def test_no_backticks() -> None: + """Test if there are no backticks, so the original text should be returned.""" + query = "MATCH (n) RETURN n" + output = extract_cypher(query) + assert output == query + + +def test_backticks() -> None: + """Test if there are backticks. Query from within backticks should be returned.""" + query = "You can use the following query: ```MATCH (n) RETURN n```" + output = extract_cypher(query) + assert output == "MATCH (n) RETURN n"