|
|
|
@ -24,9 +24,21 @@ class AsyncCombineDocsProtocol(Protocol):
|
|
|
|
|
"""Async interface for the combine_docs method."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split_list_of_docs(
|
|
|
|
|
def split_list_of_docs(
|
|
|
|
|
docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
|
|
|
|
|
) -> List[List[Document]]:
|
|
|
|
|
"""Split Documents into subsets that each meet a cumulative length constraint.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
docs: The full list of Documents.
|
|
|
|
|
length_func: Function for computing the cumulative length of a set of Documents.
|
|
|
|
|
token_max: The maximum cumulative length of any subset of Documents.
|
|
|
|
|
**kwargs: Arbitrary additional keyword params to pass to each call of the
|
|
|
|
|
length_func.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A List[List[Document]].
|
|
|
|
|
"""
|
|
|
|
|
new_result_doc_list = []
|
|
|
|
|
_sub_result_docs = []
|
|
|
|
|
for doc in docs:
|
|
|
|
@ -44,11 +56,27 @@ def _split_list_of_docs(
|
|
|
|
|
return new_result_doc_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collapse_docs(
|
|
|
|
|
def collapse_docs(
|
|
|
|
|
docs: List[Document],
|
|
|
|
|
combine_document_func: CombineDocsProtocol,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Document:
|
|
|
|
|
"""Execute a collapse function on a set of documents and merge their metadatas.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
docs: A list of Documents to combine.
|
|
|
|
|
combine_document_func: A function that takes in a list of Documents and
|
|
|
|
|
optionally addition keyword parameters and combines them into a single
|
|
|
|
|
string.
|
|
|
|
|
**kwargs: Arbitrary additional keyword params to pass to the
|
|
|
|
|
combine_document_func.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A single Document with the output of combine_document_func for the page content
|
|
|
|
|
and the combined metadata's of all the input documents. All metadata values
|
|
|
|
|
are strings, and where there are overlapping keys across documents the
|
|
|
|
|
values are joined by ", ".
|
|
|
|
|
"""
|
|
|
|
|
result = combine_document_func(docs, **kwargs)
|
|
|
|
|
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
|
|
|
|
for doc in docs[1:]:
|
|
|
|
@ -60,11 +88,27 @@ def _collapse_docs(
|
|
|
|
|
return Document(page_content=result, metadata=combined_metadata)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _acollapse_docs(
|
|
|
|
|
async def acollapse_docs(
|
|
|
|
|
docs: List[Document],
|
|
|
|
|
combine_document_func: AsyncCombineDocsProtocol,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> Document:
|
|
|
|
|
"""Execute a collapse function on a set of documents and merge their metadatas.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
docs: A list of Documents to combine.
|
|
|
|
|
combine_document_func: A function that takes in a list of Documents and
|
|
|
|
|
optionally addition keyword parameters and combines them into a single
|
|
|
|
|
string.
|
|
|
|
|
**kwargs: Arbitrary additional keyword params to pass to the
|
|
|
|
|
combine_document_func.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A single Document with the output of combine_document_func for the page content
|
|
|
|
|
and the combined metadata's of all the input documents. All metadata values
|
|
|
|
|
are strings, and where there are overlapping keys across documents the
|
|
|
|
|
values are joined by ", ".
|
|
|
|
|
"""
|
|
|
|
|
result = await combine_document_func(docs, **kwargs)
|
|
|
|
|
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
|
|
|
|
|
for doc in docs[1:]:
|
|
|
|
@ -245,12 +289,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
|
|
|
|
|
|
|
_token_max = token_max or self.token_max
|
|
|
|
|
while num_tokens is not None and num_tokens > _token_max:
|
|
|
|
|
new_result_doc_list = _split_list_of_docs(
|
|
|
|
|
new_result_doc_list = split_list_of_docs(
|
|
|
|
|
result_docs, length_func, _token_max, **kwargs
|
|
|
|
|
)
|
|
|
|
|
result_docs = []
|
|
|
|
|
for docs in new_result_doc_list:
|
|
|
|
|
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
|
|
|
|
|
new_doc = collapse_docs(docs, _collapse_docs_func, **kwargs)
|
|
|
|
|
result_docs.append(new_doc)
|
|
|
|
|
num_tokens = length_func(result_docs, **kwargs)
|
|
|
|
|
return result_docs, {}
|
|
|
|
@ -273,12 +317,12 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
|
|
|
|
|
|
|
|
|
|
_token_max = token_max or self.token_max
|
|
|
|
|
while num_tokens is not None and num_tokens > _token_max:
|
|
|
|
|
new_result_doc_list = _split_list_of_docs(
|
|
|
|
|
new_result_doc_list = split_list_of_docs(
|
|
|
|
|
result_docs, length_func, _token_max, **kwargs
|
|
|
|
|
)
|
|
|
|
|
result_docs = []
|
|
|
|
|
for docs in new_result_doc_list:
|
|
|
|
|
new_doc = await _acollapse_docs(docs, _collapse_docs_func, **kwargs)
|
|
|
|
|
new_doc = await acollapse_docs(docs, _collapse_docs_func, **kwargs)
|
|
|
|
|
result_docs.append(new_doc)
|
|
|
|
|
num_tokens = length_func(result_docs, **kwargs)
|
|
|
|
|
return result_docs, {}
|
|
|
|
|