add concept of prompt collection (#1507)

Harrison Chase 2023-03-08 08:31:29 -08:00 committed by GitHub
parent 97e3666e0d
commit c4a557bdd4
8 changed files with 204 additions and 30 deletions

View File

@@ -9,3 +9,12 @@ This is a specific type of chain where multiple other chains are run in sequence
to the next. A subtype of this type of chain is the [`SimpleSequentialChain`](./generic/sequential_chains.html#simplesequentialchain), where all subchains have only one input and one output,
and the output of one is therefore used as sole input to the next chain.
## Prompt Selectors
One thing we've noticed is that the best prompt to use really depends on the model you are using.
Some prompts work really well with some models, but not so well with others.
One of our goals is to provide good chains that "just work" out of the box.
A big part of chains like that is having prompts that "just work".
So rather than having a single default prompt for each chain, we are moving towards a paradigm where, if a prompt is not explicitly
provided, we select one with a PromptSelector. This class takes in the model being used and returns a default prompt suited to it.
The inner workings of the PromptSelector can look at any aspect of the model - LLM vs ChatModel, OpenAI vs Cohere, GPT3 vs GPT4, etc.
Because this is a newer feature, it may not yet be implemented for all chains, but this is the direction we are moving in.
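As an illustrative sketch (not part of this diff), a selector built with the `ConditionalPromptSelector` and `is_chat_model` helpers added in this commit might look like the following; the prompt texts and the Cohere-specific branch are made-up examples rather than defaults shipped by any chain:

from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.llms import Cohere  # only used for an example conditional
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate

# Default plain-text prompt for completion-style LLMs (placeholder wording).
DEFAULT_PROMPT = PromptTemplate(
    template="Answer the question.\n\n{question}", input_variables=["question"]
)

# The same task phrased as messages, used when the model is a chat model.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template("Answer the question."),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)

# A provider-specific variant, to show that a conditional can look at any
# aspect of the model, not just LLM vs ChatModel.
COHERE_PROMPT = PromptTemplate(
    template="Question: {question}\nAnswer:", input_variables=["question"]
)

PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=DEFAULT_PROMPT,
    conditionals=[
        (is_chat_model, CHAT_PROMPT),
        (lambda llm: isinstance(llm, Cohere), COHERE_PROMPT),
    ],
)

# A chain that accepts an optional prompt can then fall back to the selector:
#     prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)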

View File

@@ -1,17 +1,17 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
from langchain.vectorstores.base import VectorStore
@@ -58,10 +58,10 @@ class ChatVectorDBChain(Chain, BaseModel):
@classmethod
def from_llm(
cls,
llm: BaseLLM,
llm: BaseLanguageModel,
vectorstore: VectorStore,
condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
qa_prompt: BasePromptTemplate = QA_PROMPT,
qa_prompt: Optional[BasePromptTemplate] = None,
chain_type: str = "stuff",
**kwargs: Any,
) -> ChatVectorDBChain:

View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from pydantic import BaseModel, Field
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class BasePromptSelector(BaseModel, ABC):
@abstractmethod
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
"""Get default prompt for a language model."""
class ConditionalPromptSelector(BasePromptSelector):
"""Prompt collection that goes through conditionals."""
default_prompt: BasePromptTemplate
conditionals: List[
Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
] = Field(default_factory=list)
def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
for condition, prompt in self.conditionals:
if condition(llm):
return prompt
return self.default_prompt
def is_llm(llm: BaseLanguageModel) -> bool:
return isinstance(llm, BaseLLM)
def is_chat_model(llm: BaseLanguageModel) -> bool:
return isinstance(llm, BaseChatModel)
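In the class above, `conditionals` is an ordered list of `(predicate, prompt)` pairs: `get_prompt` returns the prompt paired with the first predicate that is true for the given model, and falls back to `default_prompt` otherwise. The `is_llm` and `is_chat_model` helpers are plain `isinstance` checks; a minimal sketch of how they resolve (the `OpenAI` and `ChatOpenAI` classes, and the API credentials needed to instantiate them, are assumptions):

from langchain.chains.prompt_selector import is_chat_model, is_llm
from langchain.chat_models import ChatOpenAI  # assumed chat model class
from langchain.llms import OpenAI  # assumed completion-style LLM class

# The predicates split BaseLanguageModel into its two families.
is_llm(OpenAI())             # True: OpenAI subclasses BaseLLM
is_chat_model(OpenAI())      # False
is_chat_model(ChatOpenAI())  # True: ChatOpenAI subclasses BaseChatModel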

View File

@@ -14,19 +14,21 @@ from langchain.chains.question_answering import (
refine_prompts,
stuff_prompt,
)
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseLanguageModel
class LoadingCallable(Protocol):
"""Interface for loading the combine documents chain."""
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
def __call__(
self, llm: BaseLanguageModel, **kwargs: Any
) -> BaseCombineDocumentsChain:
"""Callable to load the combine documents chain."""
def _load_map_rerank_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
verbose: bool = False,
document_variable_name: str = "context",
@@ -50,13 +52,14 @@ def _load_map_rerank_chain(
def _load_stuff_chain(
llm: BaseLLM,
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
llm: BaseLanguageModel,
prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> StuffDocumentsChain:
_prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(
llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
)
@@ -71,28 +74,34 @@ def _load_stuff_chain(
def _load_map_reduce_chain(
llm: BaseLLM,
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
llm: BaseLanguageModel,
question_prompt: Optional[BasePromptTemplate] = None,
combine_prompt: Optional[BasePromptTemplate] = None,
combine_document_variable_name: str = "summaries",
map_reduce_document_variable_name: str = "context",
collapse_prompt: Optional[BasePromptTemplate] = None,
reduce_llm: Optional[BaseLLM] = None,
collapse_llm: Optional[BaseLLM] = None,
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> MapReduceDocumentsChain:
_question_prompt = (
question_prompt or map_reduce_prompt.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
)
_combine_prompt = (
combine_prompt or map_reduce_prompt.COMBINE_PROMPT_SELECTOR.get_prompt(llm)
)
map_chain = LLMChain(
llm=llm,
prompt=question_prompt,
prompt=_question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_reduce_llm = reduce_llm or llm
reduce_chain = LLMChain(
llm=_reduce_llm,
prompt=combine_prompt,
prompt=_combine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
@@ -135,26 +144,32 @@ def _load_map_reduce_chain(
def _load_refine_chain(
llm: BaseLLM,
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
llm: BaseLanguageModel,
question_prompt: Optional[BasePromptTemplate] = None,
refine_prompt: Optional[BasePromptTemplate] = None,
document_variable_name: str = "context_str",
initial_response_name: str = "existing_answer",
refine_llm: Optional[BaseLLM] = None,
refine_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
**kwargs: Any,
) -> RefineDocumentsChain:
_question_prompt = (
question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
)
_refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(
llm
)
initial_chain = LLMChain(
llm=llm,
prompt=question_prompt,
prompt=_question_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
_refine_llm = refine_llm or llm
refine_chain = LLMChain(
llm=_refine_llm,
prompt=refine_prompt,
prompt=_refine_prompt,
verbose=verbose,
callback_manager=callback_manager,
)
@@ -170,7 +185,7 @@ def _load_refine_chain(
def load_qa_chain(
llm: BaseLLM,
llm: BaseLanguageModel,
chain_type: str = "stuff",
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
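With these changes, `load_qa_chain` no longer needs an explicit prompt when given a chat model; the per-chain-type prompt selectors pick an appropriate one based on the model passed in. A minimal usage sketch (it assumes `ChatOpenAI` is available in `langchain.chat_models` and that OpenAI API credentials are configured):

from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI  # assumed chat model class
from langchain.llms import OpenAI

# For a chat model, PROMPT_SELECTOR resolves to the ChatPromptTemplate variant...
chat_chain = load_qa_chain(ChatOpenAI(temperature=0), chain_type="stuff")

# ...while a plain completion-style LLM still gets the original PromptTemplate.
llm_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")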

View File

@@ -1,5 +1,14 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
)
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)
question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
Return any relevant text verbatim.
@@ -9,6 +18,20 @@ Relevant text, if any:"""
QUESTION_PROMPT = PromptTemplate(
template=question_prompt_template, input_variables=["context", "question"]
)
system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
Return any relevant text verbatim.
______________________
{context}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=QUESTION_PROMPT, conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)]
)
combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
@@ -43,3 +66,18 @@ FINAL ANSWER:"""
COMBINE_PROMPT = PromptTemplate(
template=combine_prompt_template, input_variables=["summaries", "question"]
)
system_template = """Given the following extracted parts of a long document and a question, create a final answer.
If you don't know the answer, just say that you don't know. Don't try to make up an answer.
______________________
{summaries}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_COMBINE_PROMPT = ChatPromptTemplate.from_messages(messages)
COMBINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=COMBINE_PROMPT, conditionals=[(is_chat_model, CHAT_COMBINE_PROMPT)]
)

View File

@@ -1,5 +1,16 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
ChatPromptTemplate,
AIMessagePromptTemplate,
)
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)
DEFAULT_REFINE_PROMPT_TMPL = (
"The original question is as follows: {question}\n"
@@ -17,6 +28,26 @@ DEFAULT_REFINE_PROMPT = PromptTemplate(
input_variables=["question", "existing_answer", "context_str"],
template=DEFAULT_REFINE_PROMPT_TMPL,
)
refine_template = (
"We have the opportunity to refine the existing answer"
"(only if needed) with some more context below.\n"
"------------\n"
"{context_str}\n"
"------------\n"
"Given the new context, refine the original answer to better "
"answer the question. "
"If the context isn't useful, return the original answer."
)
messages = [
HumanMessagePromptTemplate.from_template("{question}"),
AIMessagePromptTemplate.from_template("{existing_answer}"),
HumanMessagePromptTemplate.from_template(refine_template),
]
CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(messages)
REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_REFINE_PROMPT,
conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
)
DEFAULT_TEXT_QA_PROMPT_TMPL = (
@@ -30,3 +61,20 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
)
chat_qa_prompt_template = (
"Context information is below. \n"
"---------------------\n"
"{context_str}"
"\n---------------------\n"
"Given the context information and not prior knowledge, "
"answer any questions"
)
messages = [
SystemMessagePromptTemplate.from_template(chat_qa_prompt_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=DEFAULT_TEXT_QA_PROMPT,
conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)],
)

View File

@@ -1,5 +1,15 @@
# flake8: noqa
from langchain.prompts import PromptTemplate
from langchain.chains.prompt_selector import (
ConditionalPromptSelector,
is_chat_model,
)
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -10,3 +20,18 @@ Helpful Answer:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
----------------
{context}"""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}"),
]
CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
PROMPT_SELECTOR = ConditionalPromptSelector(
default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)

View File

@@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.vector_db_qa.prompt import PROMPT
from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
from langchain.llms.base import BaseLLM
from langchain.prompts import PromptTemplate
from langchain.vectorstores.base import VectorStore
@@ -78,8 +78,8 @@ class VectorDBQA(Chain, BaseModel):
raise ValueError(
"If `combine_documents_chain` not provided, `llm` should be."
)
prompt = values.pop("prompt", PROMPT)
llm = values.pop("llm")
prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
llm_chain = LLMChain(llm=llm, prompt=prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
@@ -103,10 +103,11 @@
@classmethod
def from_llm(
cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
) -> VectorDBQA:
"""Initialize from LLM."""
llm_chain = LLMChain(llm=llm, prompt=prompt)
_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
llm_chain = LLMChain(llm=llm, prompt=_prompt)
document_prompt = PromptTemplate(
input_variables=["page_content"], template="Context:\n{page_content}"
)