forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.7 KiB
Python
81 lines
2.7 KiB
Python
# flake8: noqa
|
|
from langchain.prompts.prompt import PromptTemplate
|
|
from langchain.prompts.chat import (
|
|
SystemMessagePromptTemplate,
|
|
HumanMessagePromptTemplate,
|
|
ChatPromptTemplate,
|
|
AIMessagePromptTemplate,
|
|
)
|
|
from langchain.chains.prompt_selector import (
|
|
ConditionalPromptSelector,
|
|
is_chat_model,
|
|
)
|
|
|
|
|
|
# Prompt text asking the model to refine an existing answer using a new chunk
# of context, keeping the original answer if the context isn't useful.
# Placeholders: {question}, {existing_answer}, {context_str}.
DEFAULT_REFINE_PROMPT_TMPL = (
    "The original question is as follows: {question}\n"
    "We have provided an existing answer: {existing_answer}\n"
    # The trailing space matters: adjacent string literals are concatenated
    # verbatim, so without it the prompt renders as "answer(only if needed)".
    "We have the opportunity to refine the existing answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_str}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question. "
    "If the context isn't useful, return the original answer."
)
|
|
# Non-chat refine prompt: wraps the refine template text and declares the three
# placeholders that template contains.
DEFAULT_REFINE_PROMPT = PromptTemplate(
    template=DEFAULT_REFINE_PROMPT_TMPL,
    input_variables=["question", "existing_answer", "context_str"],
)
|
|
# Final human turn of the chat refine prompt. The question and the existing
# answer are supplied as earlier messages, so only {context_str} appears here.
refine_template = (
    # The trailing space matters: adjacent string literals are concatenated
    # verbatim, so without it the prompt renders as "answer(only if needed)".
    "We have the opportunity to refine the existing answer "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{context_str}\n"
    "------------\n"
    "Given the new context, refine the original answer to better "
    "answer the question. "
    "If the context isn't useful, return the original answer."
)
|
|
# Chat refine prompt, built as three turns: the question (human), the existing
# answer (AI), then the refine instructions plus new context (human).
messages = [
    message_cls.from_template(text)
    for message_cls, text in (
        (HumanMessagePromptTemplate, "{question}"),
        (AIMessagePromptTemplate, "{existing_answer}"),
        (HumanMessagePromptTemplate, refine_template),
    )
]
CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(messages)

# The selector is given DEFAULT_REFINE_PROMPT as its default and pairs
# CHAT_REFINE_PROMPT with the is_chat_model predicate. Presumably the chat
# variant is returned when the predicate matches the model; see
# ConditionalPromptSelector to confirm.
REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
    conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
    default_prompt=DEFAULT_REFINE_PROMPT,
)
|
|
|
|
|
|
# Question-answering prompt text. The model is told to answer {question} from
# the fenced {context_str} block rather than from prior knowledge.
# Each list item is one line; the trailing "" yields the final newline.
DEFAULT_TEXT_QA_PROMPT_TMPL = "\n".join(
    [
        "Context information is below. ",
        "---------------------",
        "{context_str}",
        "---------------------",
        "Given the context information and not prior knowledge, "
        "answer the question: {question}",
        "",
    ]
)
|
|
# Non-chat QA prompt: wraps the QA template text and declares the two
# placeholders that template contains.
DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
    template=DEFAULT_TEXT_QA_PROMPT_TMPL,
    input_variables=["context_str", "question"],
)
|
|
# System-message text for the chat QA prompt. It carries only {context_str};
# the question itself is sent as a separate human message.
# Each list item is one line, with no trailing newline after the last.
chat_qa_prompt_template = "\n".join(
    [
        "Context information is below. ",
        "---------------------",
        "{context_str}",
        "---------------------",
        "Given the context information and not prior knowledge, "
        "answer any questions",
    ]
)
|
|
# Chat QA prompt: the context goes into a system message and the question into
# a human message. This rebinds the module-level `messages` name that was
# previously used for the chat refine prompt.
messages = [
    message_cls.from_template(text)
    for message_cls, text in (
        (SystemMessagePromptTemplate, chat_qa_prompt_template),
        (HumanMessagePromptTemplate, "{question}"),
    )
]
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)

# The selector is given DEFAULT_TEXT_QA_PROMPT as its default and pairs
# CHAT_QUESTION_PROMPT with the is_chat_model predicate. Presumably the chat
# variant is returned when the predicate matches the model; see
# ConditionalPromptSelector to confirm.
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
    conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)],
    default_prompt=DEFAULT_TEXT_QA_PROMPT,
)
|