mirror of https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
add concept of prompt collection (#1507)
This commit is contained in:
parent
97e3666e0d
commit
c4a557bdd4
@@ -9,3 +9,12 @@ This is a specific type of chain where multiple other chains are run in sequence
 to the next. A subtype of this type of chain is the [`SimpleSequentialChain`](./generic/sequential_chains.html#simplesequentialchain), where all subchains have only one input and one output,
 and the output of one is therefore used as sole input to the next chain.
 
+## Prompt Selectors
+One thing we've noticed is that the best prompt to use really depends on the model you are using.
+Some prompts work really well with some models but not so well with others.
+One of our goals is to provide good chains that "just work" out of the box.
+A big part of chains like that is having prompts that "just work".
+So rather than hard-coding a default prompt for each chain, we are moving towards a paradigm where, if a prompt is not explicitly
+provided, we select one with a PromptSelector. This class takes in the model and returns a default prompt suited to it.
+The inner workings of the PromptSelector can look at any aspect of the model - LLM vs ChatModel, OpenAI vs Cohere, GPT-3 vs GPT-4, etc.
+Because this is a newer feature, it may not yet be implemented for all chains, but this is the direction we are moving in.
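As a rough sketch of the idea (illustrative only, not part of this commit; the prompt text and variable names here are made up, while ConditionalPromptSelector and is_chat_model come from the new langchain.chains.prompt_selector module added below):

from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate

# Completion-style prompt used by default (illustrative text).
PROMPT = PromptTemplate(
    template="Answer the following question.\n\nQuestion: {question}\nAnswer:",
    input_variables=["question"],
)

# Message-based prompt used when the model is a chat model (illustrative text).
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template("Answer the user's question."),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
)

PROMPT_SELECTOR = ConditionalPromptSelector(
    default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)

# Inside a chain constructor, an explicit prompt still wins; otherwise the
# selector picks a default that matches the model:
#     _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)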
@@ -1,17 +1,17 @@
 """Chain for chatting with a vector database."""
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel
 
 from langchain.chains.base import Chain
-from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel
 from langchain.vectorstores.base import VectorStore
 
 
@@ -58,10 +58,10 @@ class ChatVectorDBChain(Chain, BaseModel):
     @classmethod
     def from_llm(
         cls,
-        llm: BaseLLM,
+        llm: BaseLanguageModel,
         vectorstore: VectorStore,
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: BasePromptTemplate = QA_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
         chain_type: str = "stuff",
         **kwargs: Any,
     ) -> ChatVectorDBChain:
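With this change, from_llm accepts any BaseLanguageModel, and when qa_prompt is left as None the question-answering prompt is resolved by a prompt selector inside load_qa_chain. A minimal usage sketch (assuming ChatOpenAI is importable from langchain.chat_models and OpenAI credentials are configured; the helper function name is made up):

from langchain.chains import ChatVectorDBChain
from langchain.chat_models import ChatOpenAI  # assumed import path
from langchain.vectorstores.base import VectorStore


def build_chat_qa(vectorstore: VectorStore) -> ChatVectorDBChain:
    # qa_prompt is omitted, so load_qa_chain() falls back to a prompt
    # selector and returns a chat-friendly QA prompt for ChatOpenAI.
    return ChatVectorDBChain.from_llm(ChatOpenAI(temperature=0), vectorstore)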
langchain/chains/prompt_selector.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Tuple
+
+from pydantic import BaseModel, Field
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.base import BaseLLM
+from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel
+
+
+class BasePromptSelector(BaseModel, ABC):
+    @abstractmethod
+    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        """Get default prompt for a language model."""
+
+
+class ConditionalPromptSelector(BasePromptSelector):
+    """Prompt collection that goes through conditionals."""
+
+    default_prompt: BasePromptTemplate
+    conditionals: List[
+        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
+    ] = Field(default_factory=list)
+
+    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        for condition, prompt in self.conditionals:
+            if condition(llm):
+                return prompt
+        return self.default_prompt
+
+
+def is_llm(llm: BaseLanguageModel) -> bool:
+    return isinstance(llm, BaseLLM)
+
+
+def is_chat_model(llm: BaseLanguageModel) -> bool:
+    return isinstance(llm, BaseChatModel)
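get_prompt walks the conditionals in order and returns the first prompt whose predicate matches, falling back to default_prompt. A minimal usage sketch (the prompts are illustrative; OpenAI and ChatOpenAI are assumed importable as shown, with the openai package installed and credentials configured):

from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.chat_models import ChatOpenAI  # assumed import path
from langchain.llms import OpenAI  # assumed import path
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate

DEFAULT = PromptTemplate(template="Q: {question}\nA:", input_variables=["question"])
CHAT = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("{question}")]
)

selector = ConditionalPromptSelector(
    default_prompt=DEFAULT, conditionals=[(is_chat_model, CHAT)]
)

# A chat model matches is_chat_model, so it gets the message-based prompt;
# a plain LLM falls through to the default.
assert selector.get_prompt(ChatOpenAI()) is CHAT
assert selector.get_prompt(OpenAI()) is DEFAULT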
@@ -14,19 +14,21 @@ from langchain.chains.question_answering import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel
 
 
 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""
 
-    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(
+        self, llm: BaseLanguageModel, **kwargs: Any
+    ) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""
 
 
 def _load_map_rerank_chain(
-    llm: BaseLLM,
+    llm: BaseLanguageModel,
     prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
     verbose: bool = False,
     document_variable_name: str = "context",
@@ -50,13 +52,14 @@ def _load_map_rerank_chain(
 
 
 def _load_stuff_chain(
-    llm: BaseLLM,
-    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
+    llm: BaseLanguageModel,
+    prompt: Optional[BasePromptTemplate] = None,
     document_variable_name: str = "context",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
+    _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
     llm_chain = LLMChain(
         llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
     )
@@ -71,28 +74,34 @@ def _load_stuff_chain(
 
 
 def _load_map_reduce_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
-    combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
+    llm: BaseLanguageModel,
+    question_prompt: Optional[BasePromptTemplate] = None,
+    combine_prompt: Optional[BasePromptTemplate] = None,
     combine_document_variable_name: str = "summaries",
     map_reduce_document_variable_name: str = "context",
     collapse_prompt: Optional[BasePromptTemplate] = None,
-    reduce_llm: Optional[BaseLLM] = None,
-    collapse_llm: Optional[BaseLLM] = None,
+    reduce_llm: Optional[BaseLanguageModel] = None,
+    collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
+    _question_prompt = (
+        question_prompt or map_reduce_prompt.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
+    )
+    _combine_prompt = (
+        combine_prompt or map_reduce_prompt.COMBINE_PROMPT_SELECTOR.get_prompt(llm)
+    )
     map_chain = LLMChain(
         llm=llm,
-        prompt=question_prompt,
+        prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(
         llm=_reduce_llm,
-        prompt=combine_prompt,
+        prompt=_combine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
@@ -135,26 +144,32 @@ def _load_map_reduce_chain(
 
 
 def _load_refine_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
-    refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
+    llm: BaseLanguageModel,
+    question_prompt: Optional[BasePromptTemplate] = None,
+    refine_prompt: Optional[BasePromptTemplate] = None,
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
-    refine_llm: Optional[BaseLLM] = None,
+    refine_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
+    _question_prompt = (
+        question_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(llm)
+    )
+    _refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(
+        llm
+    )
     initial_chain = LLMChain(
         llm=llm,
-        prompt=question_prompt,
+        prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
     _refine_llm = refine_llm or llm
     refine_chain = LLMChain(
         llm=_refine_llm,
-        prompt=refine_prompt,
+        prompt=_refine_prompt,
        verbose=verbose,
         callback_manager=callback_manager,
     )
@@ -170,7 +185,7 @@ def _load_refine_chain(
 
 
 def load_qa_chain(
-    llm: BaseLLM,
+    llm: BaseLanguageModel,
     chain_type: str = "stuff",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
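With these loaders taking a BaseLanguageModel, load_qa_chain can be driven by a chat model, and when no prompts are passed the *_PROMPT_SELECTORs above supply chat or completion prompts to match. A hedged sketch (the ChatOpenAI import path and credentials are assumptions; docs stands for a list of Documents you already have):

from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI  # assumed import path

# No explicit prompts: the map_reduce loader asks
# map_reduce_prompt.QUESTION_PROMPT_SELECTOR and COMBINE_PROMPT_SELECTOR
# for defaults, which return the chat variants for a chat model.
chain = load_qa_chain(ChatOpenAI(temperature=0), chain_type="map_reduce")
# result = chain({"input_documents": docs, "question": "What is a prompt selector?"})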
@@ -1,5 +1,14 @@
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
+from langchain.prompts.chat import (
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    ChatPromptTemplate,
+)
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)
 
 question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
 Return any relevant text verbatim.
@@ -9,6 +18,20 @@ Relevant text, if any:"""
 QUESTION_PROMPT = PromptTemplate(
     template=question_prompt_template, input_variables=["context", "question"]
 )
+system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
+Return any relevant text verbatim.
+______________________
+{context}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=QUESTION_PROMPT, conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)]
+)
 
 combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer.
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
@@ -43,3 +66,18 @@ FINAL ANSWER:"""
 COMBINE_PROMPT = PromptTemplate(
     template=combine_prompt_template, input_variables=["summaries", "question"]
 )
+
+system_template = """Given the following extracted parts of a long document and a question, create a final answer.
+If you don't know the answer, just say that you don't know. Don't try to make up an answer.
+______________________
+{summaries}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_COMBINE_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+COMBINE_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=COMBINE_PROMPT, conditionals=[(is_chat_model, CHAT_COMBINE_PROMPT)]
+)
@@ -1,5 +1,16 @@
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
+from langchain.prompts.chat import (
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    ChatPromptTemplate,
+    AIMessagePromptTemplate,
+)
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)
+
 
 DEFAULT_REFINE_PROMPT_TMPL = (
     "The original question is as follows: {question}\n"
@@ -17,6 +28,26 @@ DEFAULT_REFINE_PROMPT = PromptTemplate(
     input_variables=["question", "existing_answer", "context_str"],
     template=DEFAULT_REFINE_PROMPT_TMPL,
 )
+refine_template = (
+    "We have the opportunity to refine the existing answer"
+    "(only if needed) with some more context below.\n"
+    "------------\n"
+    "{context_str}\n"
+    "------------\n"
+    "Given the new context, refine the original answer to better "
+    "answer the question. "
+    "If the context isn't useful, return the original answer."
+)
+messages = [
+    HumanMessagePromptTemplate.from_template("{question}"),
+    AIMessagePromptTemplate.from_template("{existing_answer}"),
+    HumanMessagePromptTemplate.from_template(refine_template),
+]
+CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(messages)
+REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=DEFAULT_REFINE_PROMPT,
+    conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
+)
 
 
 DEFAULT_TEXT_QA_PROMPT_TMPL = (
@@ -30,3 +61,20 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
     input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
 )
+chat_qa_prompt_template = (
+    "Context information is below. \n"
+    "---------------------\n"
+    "{context_str}"
+    "\n---------------------\n"
+    "Given the context information and not prior knowledge, "
+    "answer any questions"
+)
+messages = [
+    SystemMessagePromptTemplate.from_template(chat_qa_prompt_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
+QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=DEFAULT_TEXT_QA_PROMPT,
+    conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)],
+)
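The chat variant of the refine prompt is built as a Human/AI/Human exchange: the original question, the existing answer, and then the refine instructions with the new context. A hedged sketch of formatting it into messages (assuming the format_prompt/to_messages API of the chat prompt templates in this release):

from langchain.chains.question_answering.refine_prompts import CHAT_REFINE_PROMPT

# Fills the Human/AI/Human message sequence with concrete values.
prompt_value = CHAT_REFINE_PROMPT.format_prompt(
    question="What is a prompt selector?",
    existing_answer="A class that returns a default prompt for a model.",
    context_str="PromptSelectors can pick different prompts for chat models.",
)
for message in prompt_value.to_messages():
    print(type(message).__name__, "->", message.content)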
@@ -1,5 +1,15 @@
 # flake8: noqa
 from langchain.prompts import PromptTemplate
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+
 
 prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
 
@@ -10,3 +20,18 @@ Helpful Answer:"""
 PROMPT = PromptTemplate(
     template=prompt_template, input_variables=["context", "question"]
 )
+
+system_template = """Use the following pieces of context to answer the users question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+----------------
+{context}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
+)
@@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.chains.vector_db_qa.prompt import PROMPT
+from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores.base import VectorStore
@@ -78,8 +78,8 @@ class VectorDBQA(Chain, BaseModel):
                 raise ValueError(
                     "If `combine_documents_chain` not provided, `llm` should be."
                 )
-            prompt = values.pop("prompt", PROMPT)
             llm = values.pop("llm")
+            prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
             llm_chain = LLMChain(llm=llm, prompt=prompt)
             document_prompt = PromptTemplate(
                 input_variables=["page_content"], template="Context:\n{page_content}"
@@ -103,10 +103,11 @@ class VectorDBQA(Chain, BaseModel):
 
     @classmethod
     def from_llm(
-        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
+        cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
     ) -> VectorDBQA:
         """Initialize from LLM."""
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
+        llm_chain = LLMChain(llm=llm, prompt=_prompt)
         document_prompt = PromptTemplate(
             input_variables=["page_content"], template="Context:\n{page_content}"
         )
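VectorDBQA now resolves its QA prompt through the stuff-prompt selector whenever no prompt is supplied, so the same chain definition works for both completion and chat models. A quick check of that fallback (OpenAI/ChatOpenAI import paths, the openai package, and credentials are assumptions):

from langchain.chains.question_answering.stuff_prompt import CHAT_PROMPT, PROMPT, PROMPT_SELECTOR
from langchain.chat_models import ChatOpenAI  # assumed import path
from langchain.llms import OpenAI  # assumed import path

# The selector VectorDBQA consults when `prompt` is omitted:
assert PROMPT_SELECTOR.get_prompt(OpenAI()) is PROMPT
assert PROMPT_SELECTOR.get_prompt(ChatOpenAI()) is CHAT_PROMPT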