forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
"""Load summarizing chains."""
|
|
from typing import Any, Mapping, Optional, Protocol
|
|
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
|
|
from langchain.chains.combine_documents.refine import RefineDocumentsChain
|
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
|
|
from langchain.llms.base import BaseLLM
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
|
|
|
|
class LoadingCallable(Protocol):
    """Structural type shared by the chain-loader functions below.

    Any callable that takes an LLM as its first positional argument, accepts
    arbitrary keyword arguments, and returns a ``BaseCombineDocumentsChain``
    satisfies this protocol.
    """

    def __call__(self, llm: BaseLLM, **kwargs: Any) -> BaseCombineDocumentsChain:
        """Build a combine-documents chain around ``llm``."""
        ...
|
|
|
|
|
|
def _load_stuff_chain(
    llm: BaseLLM,
    prompt: BasePromptTemplate = stuff_prompt.PROMPT,
    document_variable_name: str = "text",
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Construct the "stuff" summarization chain.

    Args:
        llm: Language model wrapped in the inner ``LLMChain``.
        prompt: Prompt for the inner ``LLMChain``.
        document_variable_name: Passed to ``StuffDocumentsChain`` as its
            ``document_variable_name``.
        **kwargs: Forwarded unchanged to ``StuffDocumentsChain``.

    Returns:
        A ``StuffDocumentsChain`` whose ``llm_chain`` pairs ``llm`` with
        ``prompt``.
    """
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=prompt),
        document_variable_name=document_variable_name,
        **kwargs,
    )
|
|
|
|
|
|
def _load_map_reduce_chain(
    llm: BaseLLM,
    map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
    combine_document_variable_name: str = "text",
    map_reduce_document_variable_name: str = "text",
    collapse_prompt: Optional[BasePromptTemplate] = None,
    reduce_llm: Optional[BaseLLM] = None,
    collapse_llm: Optional[BaseLLM] = None,
    **kwargs: Any,
) -> MapReduceDocumentsChain:
    """Construct the map-reduce summarization chain.

    Args:
        llm: Model for the map ``LLMChain``; also the default model for the
            combine and collapse steps.
        map_prompt: Prompt for the map ``LLMChain``.
        combine_prompt: Prompt for the ``LLMChain`` inside the combining
            ``StuffDocumentsChain``.
        combine_document_variable_name: ``document_variable_name`` for the
            combining chain and, when built, the collapse chain.
        map_reduce_document_variable_name: ``document_variable_name`` passed to
            ``MapReduceDocumentsChain``.
        collapse_prompt: Prompt for an optional collapse chain. When ``None``
            no collapse chain is configured.
        reduce_llm: Optional model for the combine step. Defaults to ``llm``.
        collapse_llm: Optional model for the collapse step. Defaults to
            ``llm``. Only meaningful together with ``collapse_prompt``.
        **kwargs: Forwarded unchanged to ``MapReduceDocumentsChain``.

    Returns:
        The configured ``MapReduceDocumentsChain``.

    Raises:
        ValueError: If ``collapse_llm`` is given without ``collapse_prompt``.
    """
    map_chain = LLMChain(llm=llm, prompt=map_prompt)
    # Callers may choose a different model for the combine step; fall back to
    # the map model so the default behavior is unchanged.
    reduce_model = llm if reduce_llm is None else reduce_llm
    reduce_chain = LLMChain(llm=reduce_model, prompt=combine_prompt)
    # TODO: document prompt
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name=combine_document_variable_name
    )
    if collapse_prompt is None:
        if collapse_llm is not None:
            # No collapse chain is built without a prompt, so a collapse model
            # would be silently ignored; surface the misconfiguration instead.
            raise ValueError(
                "collapse_llm was provided without collapse_prompt: provide a "
                "collapse_prompt or stop passing collapse_llm."
            )
        collapse_chain = None
    else:
        collapse_model = llm if collapse_llm is None else collapse_llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(llm=collapse_model, prompt=collapse_prompt),
            document_variable_name=combine_document_variable_name,
        )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name=map_reduce_document_variable_name,
        collapse_document_chain=collapse_chain,
        **kwargs,
    )
|
|
|
|
|
|
def _load_refine_chain(
    llm: BaseLLM,
    question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
    refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
    document_variable_name: str = "text",
    initial_response_name: str = "existing_answer",
    refine_llm: Optional[BaseLLM] = None,
    **kwargs: Any,
) -> RefineDocumentsChain:
    """Construct the refine summarization chain.

    Args:
        llm: Model for the initial ``LLMChain``; also the default model for
            the refine step.
        question_prompt: Prompt for the initial ``LLMChain``.
        refine_prompt: Prompt for the refine ``LLMChain``.
        document_variable_name: Passed to ``RefineDocumentsChain`` as its
            ``document_variable_name``.
        initial_response_name: Passed to ``RefineDocumentsChain`` as its
            ``initial_response_name``.
        refine_llm: Optional model for the refine ``LLMChain``. Defaults to
            ``llm``.
        **kwargs: Forwarded unchanged to ``RefineDocumentsChain``.

    Returns:
        The configured ``RefineDocumentsChain``.
    """
    initial_chain = LLMChain(llm=llm, prompt=question_prompt)
    # Callers may choose a different model for the refine step; fall back to
    # ``llm`` so the default behavior is unchanged.
    refine_model = llm if refine_llm is None else refine_llm
    refine_chain = LLMChain(llm=refine_model, prompt=refine_prompt)
    return RefineDocumentsChain(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        **kwargs,
    )
|
|
|
|
|
|
def load_summarize_chain(
    llm: BaseLLM, chain_type: str = "stuff", **kwargs: Any
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        **kwargs: Extra keyword arguments forwarded unchanged to the loader
            selected by ``chain_type``.

    Returns:
        A chain to use for summarizing.

    Raises:
        ValueError: If ``chain_type`` is not a supported chain type.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    loader = loader_mapping.get(chain_type)
    if loader is None:
        # list(...) renders as "['stuff', 'map_reduce', 'refine']" rather than
        # the noisy "dict_keys([...])" repr a bare .keys() view would produce.
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {list(loader_mapping)}"
        )
    return loader(llm, **kwargs)
|