From fd866d1801793d22dca5cabe200df4f2b80fa7a4 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Wed, 24 May 2023 17:31:30 +0200 Subject: [PATCH] Update Cypher QA prompt (#5173) # Improve Cypher QA prompt The current QA prompt is optimized for networkX answer generation, which returns all the possible triples. However, Cypher search is a bit more focused and doesn't necessary return all the context information. Due to that reason, the model sometimes refuses to generate an answer even though the information is provided: ![Screenshot from 2023-05-24 08-36-23](https://github.com/hwchase17/langchain/assets/19948365/351cf9c1-2567-447c-91fd-284ae3fa1ccf) To fix this issue, I have updated the prompt. Interestingly, I tried many variations with less instructions and they didn't work properly. However, the current fix works nicely. ![Screenshot from 2023-05-24 08-37-25](https://github.com/hwchase17/langchain/assets/19948365/fc830603-e6ec-4a23-8a86-eaf572996014) --- langchain/chains/graph_qa/cypher.py | 4 ++-- langchain/chains/graph_qa/prompts.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/langchain/chains/graph_qa/cypher.py b/langchain/chains/graph_qa/cypher.py index b06fb9ce..68966cd2 100644 --- a/langchain/chains/graph_qa/cypher.py +++ b/langchain/chains/graph_qa/cypher.py @@ -8,7 +8,7 @@ from pydantic import Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, PROMPT +from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.neo4j_graph import Neo4jGraph from langchain.prompts.base import BasePromptTemplate @@ -45,7 +45,7 @@ class GraphCypherQAChain(Chain): cls, llm: BaseLanguageModel, *, - qa_prompt: BasePromptTemplate = PROMPT, + qa_prompt: BasePromptTemplate = CYPHER_QA_PROMPT, cypher_prompt: BasePromptTemplate = CYPHER_GENERATION_PROMPT, **kwargs: Any, ) -> GraphCypherQAChain: diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index 5526c67c..b7008ab3 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -48,3 +48,16 @@ The question is: CYPHER_GENERATION_PROMPT = PromptTemplate( input_variables=["schema", "question"], template=CYPHER_GENERATION_TEMPLATE ) + +CYPHER_QA_TEMPLATE = """You are an assistant that helps to form nice and human understandable answers. +The information part contains the provided information that you can use to construct an answer. +The provided information is authorative, you must never doubt it or try to use your internal knowledge to correct it. +Make it sound like the information are coming from an AI assistant, but don't add any information. +Information: +{context} + +Question: {question} +Helpful Answer:""" +CYPHER_QA_PROMPT = PromptTemplate( + input_variables=["context", "question"], template=CYPHER_QA_TEMPLATE +)