add token max parameter (#7204)

This commit is contained in:
Harrison Chase 2023-07-05 12:09:25 -04:00 committed by GitHub
parent 7b585c7585
commit 8410c6a747
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 11 deletions

View File

@ -2,7 +2,7 @@
from __future__ import annotations
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from pydantic import Extra, root_validator
@ -198,7 +198,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
@ -229,7 +229,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:

View File

@ -152,6 +152,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
"""Chain to use to collapse documents if needed until they can all fit.
If None, will use the combine_documents_chain.
This is typically a StuffDocumentsChain."""
token_max: int = 3000
"""The maximum number of tokens to group documents into. For example, if
set to 3000 then documents will be grouped into chunks of no greater than
3000 tokens before trying to combine them into a smaller chunk."""
class Config:
"""Configuration for this pydantic object."""
@ -169,7 +173,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def combine_docs(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
@ -198,7 +202,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def acombine_docs(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[str, dict]:
@ -227,7 +231,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
def _collapse(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
@ -240,9 +244,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
_token_max = token_max or self.token_max
while num_tokens is not None and num_tokens > _token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
result_docs, length_func, _token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
@ -254,7 +259,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
async def _acollapse(
self,
docs: List[Document],
token_max: int = 3000,
token_max: Optional[int] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> Tuple[List[Document], dict]:
@ -267,9 +272,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
input_documents=docs, callbacks=callbacks, **kwargs
)
while num_tokens is not None and num_tokens > token_max:
_token_max = token_max or self.token_max
while num_tokens is not None and num_tokens > _token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
result_docs, length_func, _token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:

View File

@ -79,6 +79,7 @@ def _load_map_reduce_chain(
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
token_max: int = 3000,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@ -111,6 +112,8 @@ def _load_map_reduce_chain(
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
token_max=token_max,
verbose=verbose,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,

View File

@ -99,6 +99,7 @@ def _load_map_reduce_chain(
verbose: Optional[bool] = None,
callback_manager: Optional[BaseCallbackManager] = None,
callbacks: Callbacks = None,
token_max: int = 3000,
**kwargs: Any,
) -> MapReduceDocumentsChain:
_question_prompt = (
@ -154,6 +155,8 @@ def _load_map_reduce_chain(
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
token_max=token_max,
verbose=verbose,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,

View File

@ -48,6 +48,7 @@ def _load_map_reduce_chain(
reduce_llm: Optional[BaseLanguageModel] = None,
collapse_llm: Optional[BaseLanguageModel] = None,
verbose: Optional[bool] = None,
token_max: int = 3000,
**kwargs: Any,
) -> MapReduceDocumentsChain:
map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@ -79,6 +80,8 @@ def _load_map_reduce_chain(
reduce_documents_chain = ReduceDocumentsChain(
combine_documents_chain=combine_documents_chain,
collapse_documents_chain=collapse_chain,
token_max=token_max,
verbose=verbose,
)
return MapReduceDocumentsChain(
llm_chain=map_chain,