forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
242 lines
9.1 KiB
Python
242 lines
9.1 KiB
Python
"""Combining documents by mapping a chain over them first, then combining results."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
|
|
|
|
from pydantic import Extra, root_validator
|
|
|
|
from langchain.callbacks.manager import Callbacks
|
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
class CombineDocsProtocol(Protocol):
    """Structural type for a callable that combines documents into a string."""

    def __call__(self, docs: List[Document], **kwargs: Any) -> str:
        """Combine ``docs`` (plus any extra keyword arguments) into one string."""
        ...
|
|
|
|
|
|
def _split_list_of_docs(
|
|
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
|
) -> List[List[Document]]:
|
|
new_result_doc_list = []
|
|
_sub_result_docs = []
|
|
for doc in docs:
|
|
_sub_result_docs.append(doc)
|
|
_num_tokens = length_func(_sub_result_docs, **kwargs)
|
|
if _num_tokens > token_max:
|
|
if len(_sub_result_docs) == 1:
|
|
raise ValueError(
|
|
"A single document was longer than the context length,"
|
|
" we cannot handle this."
|
|
)
|
|
if len(_sub_result_docs) == 2:
|
|
raise ValueError(
|
|
"A single document was so long it could not be combined "
|
|
"with another document, we cannot handle this."
|
|
)
|
|
new_result_doc_list.append(_sub_result_docs[:-1])
|
|
_sub_result_docs = _sub_result_docs[-1:]
|
|
new_result_doc_list.append(_sub_result_docs)
|
|
return new_result_doc_list
|
|
|
|
|
|
def _collapse_docs(
    docs: List[Document],
    combine_document_func: CombineDocsProtocol,
    **kwargs: Any,
) -> Document:
    """Merge ``docs`` into a single document.

    The page content is whatever ``combine_document_func(docs, **kwargs)``
    returns. Metadata values are stringified; when several documents share a
    key, their values are joined with ``", "`` in document order.

    Args:
        docs: Documents to merge; the first one seeds the metadata.
        combine_document_func: Callable producing the merged text.
        **kwargs: Forwarded verbatim to ``combine_document_func``.

    Returns:
        A new ``Document`` holding the merged text and merged metadata.
    """
    merged_text = combine_document_func(docs, **kwargs)

    merged_metadata = {key: str(value) for key, value in docs[0].metadata.items()}
    for later_doc in docs[1:]:
        for key, value in later_doc.metadata.items():
            merged_metadata[key] = (
                f"{merged_metadata[key]}, {value}"
                if key in merged_metadata
                else str(value)
            )

    return Document(page_content=merged_text, metadata=merged_metadata)
|
|
|
|
|
|
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then combining results.

    ``llm_chain`` is applied to each document individually (map step). The
    per-document outputs are wrapped back into documents and, while their
    measured prompt length exceeds ``token_max``, they are grouped and
    collapsed with the collapse chain. The remaining documents are passed to
    ``combine_document_chain`` to produce the final output (reduce step).
    """

    llm_chain: LLMChain
    """Chain to apply to each document individually."""
    combine_document_chain: BaseCombineDocumentsChain
    """Chain to use to combine results of applying llm_chain to documents."""
    collapse_document_chain: Optional[BaseCombineDocumentsChain] = None
    """Chain to use to collapse intermediary results if needed.
    If None, will use the combine_document_chain."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    return_intermediate_steps: bool = False
    """Return the results of the map steps in the output."""

    @property
    def output_keys(self) -> List[str]:
        """Return the output keys of this chain.

        These are the parent's output keys, plus ``"intermediate_steps"`` when
        ``return_intermediate_steps`` is enabled.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_return_intermediate_steps(cls, values: Dict) -> Dict:
        """Map the legacy ``return_map_steps`` key to ``return_intermediate_steps``.

        Kept for backwards compatibility.
        """
        if "return_map_steps" in values:
            values["return_intermediate_steps"] = values["return_map_steps"]
            del values["return_map_steps"]
        return values

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided.

        Raises:
            ValueError: If no name is given and the llm_chain prompt has more
                than one (or no) input variable, or if the given name is not
                one of the prompt's input variables.
        """
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        elif values["document_variable_name"] not in llm_chain_variables:
            raise ValueError(
                f"document_variable_name {values['document_variable_name']} was "
                f"not found in llm_chain input_variables: {llm_chain_variables}"
            )
        return values

    @property
    def _collapse_chain(self) -> BaseCombineDocumentsChain:
        """Chain used for collapsing; falls back to ``combine_document_chain``."""
        if self.collapse_document_chain is not None:
            return self.collapse_document_chain
        return self.combine_document_chain

    def combine_docs(
        self,
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).

        Args:
            docs: Documents to process.
            token_max: Prompt-length threshold above which the mapped results
                are collapsed before the final combine step.
            callbacks: Callbacks passed to every chain invocation.
            **kwargs: Extra inputs passed to the map, collapse and combine chains.

        Returns:
            The combined output and a dict of extra outputs (contains
            ``"intermediate_steps"`` when ``return_intermediate_steps`` is set).
        """
        results = self.llm_chain.apply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return self._process_results(
            results, docs, token_max, callbacks=callbacks, **kwargs
        )

    async def acombine_docs(
        self,
        docs: List[Document],
        callbacks: Callbacks = None,
        *,
        token_max: int = 3000,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Combine documents in a map reduce manner.

        Combine by mapping first chain over all documents, then reducing the results.
        This reducing can be done recursively if needed (if there are many documents).

        Args:
            docs: Documents to process.
            callbacks: Callbacks passed to every chain invocation.
            token_max: Prompt-length threshold above which the mapped results
                are collapsed before the final combine step. Keyword-only so
                existing positional callers keep working; it is consumed here
                rather than leaking into the chain inputs via ``kwargs``.
            **kwargs: Extra inputs passed to the map, collapse and combine chains.

        Returns:
            The combined output and a dict of extra outputs (contains
            ``"intermediate_steps"`` when ``return_intermediate_steps`` is set).
        """
        results = await self.llm_chain.aapply(
            # FYI - this is parallelized and so it is fast.
            [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
            callbacks=callbacks,
        )
        return await self._aprocess_results(
            results, docs, callbacks=callbacks, token_max=token_max, **kwargs
        )

    def _process_results_common(
        self,
        results: List[Dict],
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[List[Document], dict]:
        """Turn map results into documents and collapse them until they fit.

        Args:
            results: Outputs of ``llm_chain``, one dict per input document.
            docs: The original documents; ``docs[i]`` supplies the metadata
                for ``results[i]``.
            token_max: Prompt-length threshold that triggers collapsing.
            callbacks: Callbacks passed to the collapse chain.
            **kwargs: Extra inputs for the length function and collapse chain.

        Returns:
            The (possibly collapsed) result documents and the dict of extra
            outputs.
        """
        question_result_key = self.llm_chain.output_key
        # This uses metadata from the docs, and the textual results from `results`
        result_docs = [
            Document(page_content=r[question_result_key], metadata=docs[i].metadata)
            for i, r in enumerate(results)
        ]
        length_func = self.combine_document_chain.prompt_length
        num_tokens = length_func(result_docs, **kwargs)

        def _collapse_docs_func(
            docs_to_collapse: List[Document], **collapse_kwargs: Any
        ) -> str:
            # NOTE: runs the collapse chain synchronously, including when this
            # method is reached from the async path.
            return self._collapse_chain.run(
                input_documents=docs_to_collapse,
                callbacks=callbacks,
                **collapse_kwargs,
            )

        # A length of None is treated as "nothing to collapse".
        while num_tokens is not None and num_tokens > token_max:
            doc_groups = _split_list_of_docs(
                result_docs, length_func, token_max, **kwargs
            )
            result_docs = [
                _collapse_docs(group, _collapse_docs_func, **kwargs)
                for group in doc_groups
            ]
            num_tokens = length_func(result_docs, **kwargs)

        if self.return_intermediate_steps:
            _results = [r[question_result_key] for r in results]
            extra_return_dict = {"intermediate_steps": _results}
        else:
            extra_return_dict = {}
        return result_docs, extra_return_dict

    def _process_results(
        self,
        results: List[Dict],
        docs: List[Document],
        token_max: int = 3000,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Collapse the map results as needed, then run the combine chain."""
        result_docs, extra_return_dict = self._process_results_common(
            results, docs, token_max, callbacks=callbacks, **kwargs
        )
        output = self.combine_document_chain.run(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict

    async def _aprocess_results(
        self,
        results: List[Dict],
        docs: List[Document],
        callbacks: Callbacks = None,
        *,
        token_max: int = 3000,
        **kwargs: Any,
    ) -> Tuple[str, dict]:
        """Collapse the map results as needed, then await the combine chain.

        ``token_max`` is keyword-only and forwarded explicitly so that it is
        not passed on to ``combine_document_chain`` as a chain input.
        """
        result_docs, extra_return_dict = self._process_results_common(
            results, docs, token_max, callbacks=callbacks, **kwargs
        )
        output = await self.combine_document_chain.arun(
            input_documents=result_docs, callbacks=callbacks, **kwargs
        )
        return output, extra_return_dict

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "map_reduce_documents_chain"
|