You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/chains/qa_with_sources/__init__.py

120 lines
4.3 KiB
Python

"""Load question answering with sources chains."""
from typing import Any, Mapping, Optional, Protocol
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_with_sources import (
map_reduce_prompt,
refine_prompts,
stuff_prompt,
)
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
class LoadingCallable(Protocol):
    """Structural type for the chain-loader functions in this module.

    Any callable that accepts an LLM plus arbitrary keyword arguments and
    returns a ``BaseCombineDocumentsChain`` satisfies this protocol.
    """

    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
        """Build and return a combine-documents chain for ``llm``."""
        ...
def _load_stuff_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "summaries",
    document_prompt: BasePromptTemplate = stuff_prompt.EXAMPLE_PROMPT,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Build the ``"stuff"`` variant of the QA-with-sources chain.

    Args:
        llm: Language model wrapped in the underlying ``LLMChain``.
        prompt: Prompt given to that ``LLMChain``.
        document_variable_name: Forwarded to ``StuffDocumentsChain`` as its
            ``document_variable_name``.
        document_prompt: Forwarded to ``StuffDocumentsChain`` as its
            ``document_prompt``. Exposed as a parameter (with the previous
            hard-coded value as default) so callers can override it, matching
            the map-reduce and refine loaders.
        **kwargs: Extra keyword arguments forwarded to ``StuffDocumentsChain``.

    Returns:
        The configured ``StuffDocumentsChain``.
    """
    llm_chain = LLMChain(llm=llm, prompt=prompt)
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        # Previously fixed to stuff_prompt.EXAMPLE_PROMPT; passing
        # document_prompt via **kwargs then raised a duplicate-keyword TypeError.
        document_prompt=document_prompt,
        **kwargs,
    )
def _load_map_reduce_chain(
    llm: BaseLLM,
    question_prompt: BasePromptTemplate = map_reduce_prompt.QUESTION_PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.COMBINE_PROMPT,
    document_prompt: BasePromptTemplate = map_reduce_prompt.EXAMPLE_PROMPT,
    combine_document_variable_name: str = "summaries",
    map_reduce_document_variable_name: str = "context",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Build the ``"map_reduce"`` variant of the QA-with-sources chain.

    Args:
        llm: Language model shared by every ``LLMChain`` created here.
        question_prompt: Prompt for the ``LLMChain`` passed to
            ``MapReduceDocumentsChain`` as ``llm_chain``.
        combine_prompt: Prompt for the ``LLMChain`` inside the combine
            ``StuffDocumentsChain``.
        document_prompt: ``document_prompt`` for the combine chain and, when
            present, the collapse chain.
        combine_document_variable_name: ``document_variable_name`` for the
            combine chain and, when present, the collapse chain.
        map_reduce_document_variable_name: ``document_variable_name`` for the
            returned ``MapReduceDocumentsChain``.
        collapse_prompt: When not ``None``, a second ``StuffDocumentsChain``
            built from this prompt is passed as ``collapse_document_chain``;
            otherwise ``None`` is passed.
        **kwargs: Extra keyword arguments forwarded to
            ``MapReduceDocumentsChain``.

    Returns:
        The configured ``MapReduceDocumentsChain``.
    """

    def _stuff_over(prompt: BasePromptTemplate) -> StuffDocumentsChain:
        # The combine and collapse chains differ only in their prompt.
        return StuffDocumentsChain(
            llm_chain=LLMChain(llm=llm, prompt=prompt),
            document_variable_name=combine_document_variable_name,
            document_prompt=document_prompt,
        )

    per_document_chain = LLMChain(llm=llm, prompt=question_prompt)
    combining_chain = _stuff_over(combine_prompt)
    collapsing_chain = (
        None if collapse_prompt is None else _stuff_over(collapse_prompt)
    )
    return MapReduceDocumentsChain(
        llm_chain=per_document_chain,
        combine_document_chain=combining_chain,
        document_variable_name=map_reduce_document_variable_name,
        collapse_document_chain=collapsing_chain,
        **kwargs,
    )
def _load_refine_chain(
    llm: BaseLLM,
    question_prompt: BasePromptTemplate = refine_prompts.DEFAULT_TEXT_QA_PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.DEFAULT_REFINE_PROMPT,
    document_prompt: BasePromptTemplate = refine_prompts.EXAMPLE_PROMPT,
    document_variable_name: str = "context_str",
    initial_response_name: str = "existing_answer",
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Build the ``"refine"`` variant of the QA-with-sources chain.

    Args:
        llm: Language model shared by both ``LLMChain`` instances.
        question_prompt: Prompt for the ``LLMChain`` passed as
            ``initial_llm_chain``.
        refine_prompt: Prompt for the ``LLMChain`` passed as
            ``refine_llm_chain``.
        document_prompt: Forwarded as ``document_prompt``.
        document_variable_name: Forwarded as ``document_variable_name``.
        initial_response_name: Forwarded as ``initial_response_name``.
        **kwargs: Extra keyword arguments forwarded to
            ``RefineDocumentsChain``.

    Returns:
        The configured ``RefineDocumentsChain``.
    """
    # Keyword arguments are evaluated left to right, so the initial chain is
    # still constructed before the refine chain.
    return RefineDocumentsChain(
        initial_llm_chain=LLMChain(llm=llm, prompt=question_prompt),
        refine_llm_chain=LLMChain(llm=llm, prompt=refine_prompt),
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        document_prompt=document_prompt,
        **kwargs,
    )
def load_qa_with_sources_chain(
    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
) -> BaseCombineDocumentsChain:
    """Load a question-answering-with-sources chain of the requested type.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Which document-combining chain to build. One of "stuff",
            "map_reduce", or "refine".
        **kwargs: Passed through unchanged to the selected loader function.

    Returns:
        A chain to use for question answering with sources.

    Raises:
        ValueError: If ``chain_type`` is not one of the supported names.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type in loader_mapping:
        # Known type: dispatch straight to its loader.
        return loader_mapping[chain_type](llm, **kwargs)
    raise ValueError(
        f"Got unsupported chain type: {chain_type}. "
        f"Should be one of {loader_mapping.keys()}"
    )