diff --git a/langchain/chains/graph_qa/base.py b/langchain/chains/graph_qa/base.py index 44dceca8ed..b194f2e146 100644 --- a/langchain/chains/graph_qa/base.py +++ b/langchain/chains/graph_qa/base.py @@ -8,7 +8,7 @@ from pydantic import Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.chains.base import Chain -from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT +from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, GRAPH_QA_PROMPT from langchain.chains.llm import LLMChain from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities from langchain.schema import BasePromptTemplate @@ -44,7 +44,7 @@ class GraphQAChain(Chain): def from_llm( cls, llm: BaseLanguageModel, - qa_prompt: BasePromptTemplate = PROMPT, + qa_prompt: BasePromptTemplate = GRAPH_QA_PROMPT, entity_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT, **kwargs: Any, ) -> GraphQAChain: diff --git a/langchain/chains/graph_qa/prompts.py b/langchain/chains/graph_qa/prompts.py index ca68983b08..a0898e7aa6 100644 --- a/langchain/chains/graph_qa/prompts.py +++ b/langchain/chains/graph_qa/prompts.py @@ -23,14 +23,14 @@ ENTITY_EXTRACTION_PROMPT = PromptTemplate( input_variables=["input"], template=_DEFAULT_ENTITY_EXTRACTION_TEMPLATE ) -prompt_template = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. +_DEFAULT_GRAPH_QA_TEMPLATE = """Use the following knowledge triplets to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. {context} Question: {question} Helpful Answer:""" -PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] +GRAPH_QA_PROMPT = PromptTemplate( + template=_DEFAULT_GRAPH_QA_TEMPLATE, input_variables=["context", "question"] ) CYPHER_GENERATION_TEMPLATE = """Task:Generate Cypher statement to query a graph database.