diff --git a/docs/modules/chains/key_concepts.md b/docs/modules/chains/key_concepts.md
index d8c0e734..3e781870 100644
--- a/docs/modules/chains/key_concepts.md
+++ b/docs/modules/chains/key_concepts.md
@@ -9,3 +9,12 @@ This is a specific type of chain where multiple other chains are run in sequence
 to the next. A subtype of this type of chain is the [`SimpleSequentialChain`](./generic/sequential_chains.html#simplesequentialchain), where all
 subchains have only one input and one output, and the output of one is therefore used as sole input to the next chain.

+## Prompt Selectors
+One thing that we've noticed is that the best prompt to use really depends on the model you use.
+Some prompts work really well with some models, but not so well with others.
+One of our goals is to provide good chains that "just work" out of the box.
+A big part of chains like that is having prompts that "just work".
+So rather than having a single default prompt per chain, we are moving towards a paradigm where, if a prompt is not explicitly
+provided, we select one with a PromptSelector. This class takes in the model and returns a default prompt suited to it.
+The inner workings of the PromptSelector can look at any aspect of the model - LLM vs ChatModel, OpenAI vs Cohere, GPT3 vs GPT4, etc.
+Due to this being a newer feature, it may not be implemented for all chains yet, but this is the direction we are moving in.
diff --git a/langchain/chains/chat_vector_db/base.py b/langchain/chains/chat_vector_db/base.py
index d030c6ac..22cbc5dd 100644
--- a/langchain/chains/chat_vector_db/base.py
+++ b/langchain/chains/chat_vector_db/base.py
@@ -1,17 +1,17 @@
 """Chain for chatting with a vector database."""
 from __future__ import annotations

-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from pydantic import BaseModel

 from langchain.chains.base import Chain
-from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel
 from langchain.vectorstores.base import VectorStore
@@ -58,10 +58,10 @@ class ChatVectorDBChain(Chain, BaseModel):
     @classmethod
     def from_llm(
         cls,
-        llm: BaseLLM,
+        llm: BaseLanguageModel,
         vectorstore: VectorStore,
         condense_question_prompt: BasePromptTemplate = CONDENSE_QUESTION_PROMPT,
-        qa_prompt: BasePromptTemplate = QA_PROMPT,
+        qa_prompt: Optional[BasePromptTemplate] = None,
         chain_type: str = "stuff",
         **kwargs: Any,
     ) -> ChatVectorDBChain:
diff --git a/langchain/chains/prompt_selector.py b/langchain/chains/prompt_selector.py
new file mode 100644
index 00000000..190907cc
--- /dev/null
+++ b/langchain/chains/prompt_selector.py
@@ -0,0 +1,38 @@
+from abc import ABC, abstractmethod
+from typing import Callable, List, Tuple
+
+from pydantic import BaseModel, Field
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.base import BaseLLM
+from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel
+
+
+class BasePromptSelector(BaseModel, ABC):
+    @abstractmethod
+    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        """Get default prompt for a language model."""
+
+
+class ConditionalPromptSelector(BasePromptSelector):
+    """Prompt collection that goes through conditionals."""
+
+    default_prompt: BasePromptTemplate
+    conditionals: List[
+        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
+    ] = Field(default_factory=list)
+
+    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
+        for condition, prompt in self.conditionals:
+            if condition(llm):
+                return prompt
+        return self.default_prompt
+
+
+def is_llm(llm: BaseLanguageModel) -> bool:
+    return isinstance(llm, BaseLLM)
+
+
+def is_chat_model(llm: BaseLanguageModel) -> bool:
+    return isinstance(llm, BaseChatModel)
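To make the new abstraction concrete, here is a minimal sketch of using a `ConditionalPromptSelector` on its own. The prompt strings are invented for illustration, and constructing `OpenAI`/`ChatOpenAI` assumes credentials are configured in the environment; only `ConditionalPromptSelector` and `is_chat_model` come from the module above.

```python
from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate

# Completion-style prompt used as the fallback.
DEFAULT_PROMPT = PromptTemplate(
    template="Summarize this text:\n\n{text}", input_variables=["text"]
)
# Chat-style prompt used only when the model is a chat model.
CHAT_PROMPT = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("Summarize this text:\n\n{text}")]
)

SELECTOR = ConditionalPromptSelector(
    default_prompt=DEFAULT_PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
)

# The first conditional whose predicate returns True wins; otherwise the default.
prompt_for_llm = SELECTOR.get_prompt(OpenAI(temperature=0))       # -> DEFAULT_PROMPT
prompt_for_chat = SELECTOR.get_prompt(ChatOpenAI(temperature=0))  # -> CHAT_PROMPT
```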
diff --git a/langchain/chains/question_answering/__init__.py b/langchain/chains/question_answering/__init__.py
index 101a6156..8041a038 100644
--- a/langchain/chains/question_answering/__init__.py
+++ b/langchain/chains/question_answering/__init__.py
@@ -14,19 +14,21 @@ from langchain.chains.question_answering import (
     refine_prompts,
     stuff_prompt,
 )
-from langchain.llms.base import BaseLLM
 from langchain.prompts.base import BasePromptTemplate
+from langchain.schema import BaseLanguageModel


 class LoadingCallable(Protocol):
     """Interface for loading the combine documents chain."""

-    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
+    def __call__(
+        self, llm: BaseLanguageModel, **kwargs: Any
+    ) -> BaseCombineDocumentsChain:
         """Callable to load the combine documents chain."""


 def _load_map_rerank_chain(
-    llm: BaseLLM,
+    llm: BaseLanguageModel,
     prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
     verbose: bool = False,
     document_variable_name: str = "context",
@@ -50,13 +52,14 @@ def _load_map_rerank_chain(


 def _load_stuff_chain(
-    llm: BaseLLM,
-    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
+    llm: BaseLanguageModel,
+    prompt: Optional[BasePromptTemplate] = None,
     document_variable_name: str = "context",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> StuffDocumentsChain:
+    _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
     llm_chain = LLMChain(
-        llm=llm, prompt=prompt, verbose=verbose, callback_manager=callback_manager
+        llm=llm, prompt=_prompt, verbose=verbose, callback_manager=callback_manager
     )
@@ -71,28 +74,34 @@ def _load_stuff_chain(


 def _load_map_reduce_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
-    combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
+    llm: BaseLanguageModel,
+    question_prompt: Optional[BasePromptTemplate] = None,
+    combine_prompt: Optional[BasePromptTemplate] = None,
     combine_document_variable_name: str = "summaries",
     map_reduce_document_variable_name: str = "context",
     collapse_prompt: Optional[BasePromptTemplate] = None,
-    reduce_llm: Optional[BaseLLM] = None,
-    collapse_llm: Optional[BaseLLM] = None,
+    reduce_llm: Optional[BaseLanguageModel] = None,
+    collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
+    _question_prompt = (
+        question_prompt or map_reduce_prompt.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
+    )
+    _combine_prompt = (
+        combine_prompt or map_reduce_prompt.COMBINE_PROMPT_SELECTOR.get_prompt(llm)
+    )
     map_chain = LLMChain(
         llm=llm,
-        prompt=question_prompt,
+        prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
     _reduce_llm = reduce_llm or llm
     reduce_chain = LLMChain(
         llm=_reduce_llm,
-        prompt=combine_prompt,
+        prompt=_combine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
@@ -135,26 +144,32 @@ def _load_map_reduce_chain(


 def _load_refine_chain(
-    llm: BaseLLM,
-    question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
-    refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
+    llm: BaseLanguageModel,
+    question_prompt: Optional[BasePromptTemplate] = None,
+    refine_prompt: Optional[BasePromptTemplate] = None,
     document_variable_name: str = "context_str",
     initial_response_name: str = "existing_answer",
-    refine_llm: Optional[BaseLLM] = None,
+    refine_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     **kwargs: Any,
 ) -> RefineDocumentsChain:
+    _question_prompt = (
+        question_prompt or refine_prompts.QUESTION_PROMPT_SELECTOR.get_prompt(llm)
+    )
+    _refine_prompt = refine_prompt or refine_prompts.REFINE_PROMPT_SELECTOR.get_prompt(
+        llm
+    )
     initial_chain = LLMChain(
         llm=llm,
-        prompt=question_prompt,
+        prompt=_question_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
     _refine_llm = refine_llm or llm
     refine_chain = LLMChain(
         llm=_refine_llm,
-        prompt=refine_prompt,
+        prompt=_refine_prompt,
         verbose=verbose,
         callback_manager=callback_manager,
     )
@@ -170,7 +185,7 @@ def _load_refine_chain(


 def load_qa_chain(
-    llm: BaseLLM,
+    llm: BaseLanguageModel,
     chain_type: str = "stuff",
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
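For callers of `load_qa_chain`, the practical effect is that the default prompt now adapts to the model. A rough sketch of the intended behavior is below; the document text is invented, credentials are assumed to be configured, and `my_prompt` in the commented line is a hypothetical user-supplied prompt.

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.llms import OpenAI

docs = [Document(page_content="LangChain added prompt selectors in this release.")]

# With a plain completion model, the stuff chain falls back to the
# completion-style PROMPT via stuff_prompt.PROMPT_SELECTOR.
completion_chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")

# With a chat model, the same call now picks the chat-style CHAT_PROMPT instead.
chat_chain = load_qa_chain(ChatOpenAI(temperature=0), chain_type="stuff")

# An explicitly provided prompt still overrides the selector:
# custom_chain = load_qa_chain(ChatOpenAI(), chain_type="stuff", prompt=my_prompt)

answer = chat_chain.run(input_documents=docs, question="What was added?")
```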
diff --git a/langchain/chains/question_answering/map_reduce_prompt.py b/langchain/chains/question_answering/map_reduce_prompt.py
index 6268050c..7c0efd77 100644
--- a/langchain/chains/question_answering/map_reduce_prompt.py
+++ b/langchain/chains/question_answering/map_reduce_prompt.py
@@ -1,5 +1,14 @@
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
+from langchain.prompts.chat import (
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    ChatPromptTemplate,
+)
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)

 question_prompt_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
 Return any relevant text verbatim.
 {context}
 Question: {question}
@@ -9,6 +18,20 @@ Relevant text, if any:"""
 QUESTION_PROMPT = PromptTemplate(
     template=question_prompt_template, input_variables=["context", "question"]
 )
+system_template = """Use the following portion of a long document to see if any of the text is relevant to answer the question.
+Return any relevant text verbatim.
+______________________
+{context}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=QUESTION_PROMPT, conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)]
+)

 combine_prompt_template = """Given the following extracted parts of a long document and a question, create a final answer.
 If you don't know the answer, just say that you don't know. Don't try to make up an answer.
@@ -43,3 +66,18 @@ FINAL ANSWER:"""
 COMBINE_PROMPT = PromptTemplate(
     template=combine_prompt_template, input_variables=["summaries", "question"]
 )
+
+system_template = """Given the following extracted parts of a long document and a question, create a final answer.
+If you don't know the answer, just say that you don't know. Don't try to make up an answer.
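As a quick illustration of the new chat-style question prompt, the sketch below formats it into messages; the context and question strings are made up, and no model call is involved.

```python
from langchain.chains.question_answering.map_reduce_prompt import CHAT_QUESTION_PROMPT

messages = CHAT_QUESTION_PROMPT.format_prompt(
    context="...one chunk of a long document...",
    question="What changed in this release?",
).to_messages()
# messages[0] is a SystemMessage carrying the instructions plus the document chunk;
# messages[1] is a HumanMessage carrying just the question.
```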
+______________________
+{summaries}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_COMBINE_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+COMBINE_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=COMBINE_PROMPT, conditionals=[(is_chat_model, CHAT_COMBINE_PROMPT)]
+)
diff --git a/langchain/chains/question_answering/refine_prompts.py b/langchain/chains/question_answering/refine_prompts.py
index 5bfbe28b..78c6dd77 100644
--- a/langchain/chains/question_answering/refine_prompts.py
+++ b/langchain/chains/question_answering/refine_prompts.py
@@ -1,5 +1,16 @@
 # flake8: noqa
-from langchain.prompts import PromptTemplate
+from langchain.prompts.prompt import PromptTemplate
+from langchain.prompts.chat import (
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+    ChatPromptTemplate,
+    AIMessagePromptTemplate,
+)
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)
+

 DEFAULT_REFINE_PROMPT_TMPL = (
     "The original question is as follows: {question}\n"
@@ -17,6 +28,26 @@ DEFAULT_REFINE_PROMPT = PromptTemplate(
     input_variables=["question", "existing_answer", "context_str"],
     template=DEFAULT_REFINE_PROMPT_TMPL,
 )
+refine_template = (
+    "We have the opportunity to refine the existing answer "
+    "(only if needed) with some more context below.\n"
+    "------------\n"
+    "{context_str}\n"
+    "------------\n"
+    "Given the new context, refine the original answer to better "
+    "answer the question. "
+    "If the context isn't useful, return the original answer."
+)
+messages = [
+    HumanMessagePromptTemplate.from_template("{question}"),
+    AIMessagePromptTemplate.from_template("{existing_answer}"),
+    HumanMessagePromptTemplate.from_template(refine_template),
+]
+CHAT_REFINE_PROMPT = ChatPromptTemplate.from_messages(messages)
+REFINE_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=DEFAULT_REFINE_PROMPT,
+    conditionals=[(is_chat_model, CHAT_REFINE_PROMPT)],
+)


 DEFAULT_TEXT_QA_PROMPT_TMPL = (
@@ -30,3 +61,20 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
 DEFAULT_TEXT_QA_PROMPT = PromptTemplate(
     input_variables=["context_str", "question"], template=DEFAULT_TEXT_QA_PROMPT_TMPL
 )
+chat_qa_prompt_template = (
+    "Context information is below. \n"
+    "---------------------\n"
+    "{context_str}"
+    "\n---------------------\n"
+    "Given the context information and not prior knowledge, "
+    "answer any questions"
+)
+messages = [
+    SystemMessagePromptTemplate.from_template(chat_qa_prompt_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_QUESTION_PROMPT = ChatPromptTemplate.from_messages(messages)
+QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=DEFAULT_TEXT_QA_PROMPT,
+    conditionals=[(is_chat_model, CHAT_QUESTION_PROMPT)],
+)
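The chat-style refine prompt replays the exchange so far (question, then the model's existing answer) and asks for a refinement in a new human turn. A small sketch with invented strings, no model call involved:

```python
from langchain.chains.question_answering.refine_prompts import CHAT_REFINE_PROMPT

msgs = CHAT_REFINE_PROMPT.format_prompt(
    question="What does this change add?",
    existing_answer="It adds prompt selectors.",
    context_str="...a newly retrieved document chunk...",
).to_messages()
# msgs -> [HumanMessage(question), AIMessage(existing answer), HumanMessage(refine request)]
```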
diff --git a/langchain/chains/question_answering/stuff_prompt.py b/langchain/chains/question_answering/stuff_prompt.py
index 9ebb89ea..968d2950 100644
--- a/langchain/chains/question_answering/stuff_prompt.py
+++ b/langchain/chains/question_answering/stuff_prompt.py
@@ -1,5 +1,15 @@
 # flake8: noqa
 from langchain.prompts import PromptTemplate
+from langchain.chains.prompt_selector import (
+    ConditionalPromptSelector,
+    is_chat_model,
+)
+from langchain.prompts.chat import (
+    ChatPromptTemplate,
+    SystemMessagePromptTemplate,
+    HumanMessagePromptTemplate,
+)
+

 prompt_template = """Use the following pieces of context to answer the question at the end.
 If you don't know the answer, just say that you don't know, don't try to make up an answer.
@@ -10,3 +20,18 @@ Helpful Answer:"""
 PROMPT = PromptTemplate(
     template=prompt_template, input_variables=["context", "question"]
 )
+
+system_template = """Use the following pieces of context to answer the user's question.
+If you don't know the answer, just say that you don't know, don't try to make up an answer.
+----------------
+{context}"""
+messages = [
+    SystemMessagePromptTemplate.from_template(system_template),
+    HumanMessagePromptTemplate.from_template("{question}"),
+]
+CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
+
+
+PROMPT_SELECTOR = ConditionalPromptSelector(
+    default_prompt=PROMPT, conditionals=[(is_chat_model, CHAT_PROMPT)]
+)
diff --git a/langchain/chains/vector_db_qa/base.py b/langchain/chains/vector_db_qa/base.py
index 882e05a1..3da04659 100644
--- a/langchain/chains/vector_db_qa/base.py
+++ b/langchain/chains/vector_db_qa/base.py
@@ -10,7 +10,7 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
 from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
-from langchain.chains.vector_db_qa.prompt import PROMPT
+from langchain.chains.question_answering.stuff_prompt import PROMPT_SELECTOR
 from langchain.llms.base import BaseLLM
 from langchain.prompts import PromptTemplate
 from langchain.vectorstores.base import VectorStore
@@ -78,8 +78,8 @@ class VectorDBQA(Chain, BaseModel):
             raise ValueError(
                 "If `combine_documents_chain` not provided, `llm` should be."
             )
-        prompt = values.pop("prompt", PROMPT)
         llm = values.pop("llm")
+        prompt = values.pop("prompt", PROMPT_SELECTOR.get_prompt(llm))
         llm_chain = LLMChain(llm=llm, prompt=prompt)
         document_prompt = PromptTemplate(
             input_variables=["page_content"], template="Context:\n{page_content}"
         )
@@ -103,10 +103,11 @@ class VectorDBQA(Chain, BaseModel):

     @classmethod
     def from_llm(
-        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
+        cls, llm: BaseLLM, prompt: Optional[PromptTemplate] = None, **kwargs: Any
     ) -> VectorDBQA:
         """Initialize from LLM."""
-        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        _prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)
+        llm_chain = LLMChain(llm=llm, prompt=_prompt)
         document_prompt = PromptTemplate(
             input_variables=["page_content"], template="Context:\n{page_content}"
         )
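Putting it together end to end, a rough sketch of how `VectorDBQA` is meant to behave with a chat model and no explicit prompt. It assumes OpenAI credentials are configured and FAISS is installed; the texts and question are invented for illustration.

```python
from langchain.chains import VectorDBQA
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

texts = ["Prompt selectors pick a default prompt based on the model type."]
vectorstore = FAISS.from_texts(texts, OpenAIEmbeddings())

# No prompt is passed: PROMPT_SELECTOR sees a chat model and picks CHAT_PROMPT,
# whereas an OpenAI completion model would get the original PROMPT.
qa = VectorDBQA.from_llm(llm=ChatOpenAI(temperature=0), vectorstore=vectorstore)
answer = qa.run("What do prompt selectors do?")
```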