|
|
|
@ -13,19 +13,21 @@ from langchain.chains.qa_with_sources import (
|
|
|
|
|
stuff_prompt,
|
|
|
|
|
)
|
|
|
|
|
from langchain.chains.question_answering import map_rerank_prompt
|
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
from langchain.schema import BaseLanguageModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoadingCallable(Protocol):
|
|
|
|
|
"""Interface for loading the combine documents chain."""
|
|
|
|
|
|
|
|
|
|
def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
|
|
|
|
|
def __call__(
|
|
|
|
|
self, llm: BaseLanguageModel, **kwargs: Any
|
|
|
|
|
) -> BaseCombineDocumentsChain:
|
|
|
|
|
"""Callable to load the combine documents chain."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_map_rerank_chain(
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
prompt: BasePromptTemplate = map_rerank_prompt.PROMPT,
|
|
|
|
|
verbose: bool = False,
|
|
|
|
|
document_variable_name: str = "context",
|
|
|
|
@ -44,7 +46,7 @@ def _load_map_rerank_chain(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_stuff_chain(
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
prompt: BasePromptTemplate = stuff_prompt.PROMPT,
|
|
|
|
|
document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
|
|
|
|
|
document_variable_name: str = "summaries",
|
|
|
|
@ -62,15 +64,15 @@ def _load_stuff_chain(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_map_reduce_chain(
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
|
|
|
|
|
combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
|
|
|
|
|
document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
|
|
|
|
|
combine_document_variable_name: str = "summaries",
|
|
|
|
|
map_reduce_document_variable_name: str = "context",
|
|
|
|
|
collapse_prompt: Optional[BasePromptTemplate] = None,
|
|
|
|
|
reduce_llm: Optional[BaseLLM] = None,
|
|
|
|
|
collapse_llm: Optional[BaseLLM] = None,
|
|
|
|
|
reduce_llm: Optional[BaseLanguageModel] = None,
|
|
|
|
|
collapse_llm: Optional[BaseLanguageModel] = None,
|
|
|
|
|
verbose: Optional[bool] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> MapReduceDocumentsChain:
|
|
|
|
@ -112,13 +114,13 @@ def _load_map_reduce_chain(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_refine_chain(
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
|
|
|
|
|
refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
|
|
|
|
|
document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
|
|
|
|
|
document_variable_name: str = "context_str",
|
|
|
|
|
initial_response_name: str = "existing_answer",
|
|
|
|
|
refine_llm: Optional[BaseLLM] = None,
|
|
|
|
|
refine_llm: Optional[BaseLanguageModel] = None,
|
|
|
|
|
verbose: Optional[bool] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> RefineDocumentsChain:
|
|
|
|
@ -137,7 +139,7 @@ def _load_refine_chain(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_qa_with_sources_chain(
|
|
|
|
|
llm: BaseLLM,
|
|
|
|
|
llm: BaseLanguageModel,
|
|
|
|
|
chain_type: str = "stuff",
|
|
|
|
|
verbose: Optional[bool] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|