Mirror of https://github.com/hwchase17/langchain (synced 2024-11-08 07:10:35 +00:00)

Commit 8410c6a747 (parent 7b585c7585): add token max parameter (#7204)

This commit makes the token budget for map-reduce document combination configurable: ReduceDocumentsChain gains a token_max field (default 3000), combine_docs/acombine_docs and the internal collapse helpers take an optional per-call override, and the map-reduce chain loaders expose token_max as a parameter.
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import Extra, root_validator
 
@@ -198,7 +198,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -229,7 +229,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
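Taken together, these hunks make the per-call budget optional on MapReduceDocumentsChain: passing token_max=None defers to the new ReduceDocumentsChain.token_max field shown in the next file. A minimal end-to-end sketch, not from the commit itself, assuming the post-refactor API in which MapReduceDocumentsChain wraps a ReduceDocumentsChain; FakeListLLM is a deterministic stand-in model, and token counting falls back to the default GPT-2 tokenizer (so transformers must be installed):

    from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
    from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
    from langchain.chains.combine_documents.stuff import StuffDocumentsChain
    from langchain.chains.llm import LLMChain
    from langchain.llms.fake import FakeListLLM
    from langchain.prompts import PromptTemplate
    from langchain.schema import Document

    llm = FakeListLLM(responses=["summary"] * 16)  # deterministic stand-in
    prompt = PromptTemplate.from_template("Summarize: {text}")

    reduce_chain = ReduceDocumentsChain(
        combine_documents_chain=StuffDocumentsChain(
            llm_chain=LLMChain(llm=llm, prompt=prompt),
            document_variable_name="text",
        ),
        token_max=500,  # instance-level default introduced by this commit
    )
    chain = MapReduceDocumentsChain(
        llm_chain=LLMChain(llm=llm, prompt=prompt),
        reduce_documents_chain=reduce_chain,
        document_variable_name="text",
    )

    docs = [Document(page_content="lorem " * 100) for _ in range(4)]
    print(chain.combine_docs(docs)[0])                  # falls back to token_max=500
    print(chain.combine_docs(docs, token_max=2000)[0])  # per-call override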
@@ -152,6 +152,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     """Chain to use to collapse documents if needed until they can all fit.
     If None, will use the combine_documents_chain.
     This is typically a StuffDocumentsChain."""
+    token_max: int = 3000
+    """The maximum number of tokens to group documents into. For example, if
+    set to 3000 then documents will be grouped into chunks of no greater than
+    3000 tokens before trying to combine them into a smaller chunk."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -169,7 +173,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def combine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -198,7 +202,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def acombine_docs(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[str, dict]:
@@ -227,7 +231,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     def _collapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
@@ -240,9 +244,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 input_documents=docs, callbacks=callbacks, **kwargs
             )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
             )
             result_docs = []
             for docs in new_result_doc_list:
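This fallback is the heart of the change: a per-call token_max wins when provided, otherwise the instance field documented above applies, and oversized groups are split and re-collapsed until they fit. A standalone sketch of the same pattern (illustrative names, not langchain API); note that `token_max or self.token_max` also falls back when a caller passes 0, since 0 is falsy in Python:

    from typing import List, Optional


    class Collapser:
        """Illustrative stand-in for ReduceDocumentsChain's collapse step."""

        def __init__(self, token_max: int = 3000) -> None:
            # Instance-level default, mirroring the new ReduceDocumentsChain field.
            self.token_max = token_max

        def collapse(self, chunks: List[str], token_max: Optional[int] = None) -> List[str]:
            # Per-call value wins when given; None falls back to the default.
            _token_max = token_max or self.token_max
            groups: List[str] = []
            buf = ""
            for chunk in chunks:
                # Start a new group once adding the chunk would exceed the
                # budget (character counts stand in for token counts here).
                if buf and len(buf) + len(chunk) > _token_max:
                    groups.append(buf)
                    buf = ""
                buf += chunk
            if buf:
                groups.append(buf)
            return groups


    chunks = ["a" * 2000, "b" * 2000]
    print(len(Collapser().collapse(chunks)))        # 2 groups under the 3000 default
    print(len(Collapser().collapse(chunks, 5000)))  # 1 group with a per-call override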
@@ -254,7 +259,7 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
     async def _acollapse(
         self,
         docs: List[Document],
-        token_max: int = 3000,
+        token_max: Optional[int] = None,
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Tuple[List[Document], dict]:
@@ -267,9 +272,10 @@ class ReduceDocumentsChain(BaseCombineDocumentsChain):
                 input_documents=docs, callbacks=callbacks, **kwargs
             )
 
-        while num_tokens is not None and num_tokens > token_max:
+        _token_max = token_max or self.token_max
+        while num_tokens is not None and num_tokens > _token_max:
             new_result_doc_list = _split_list_of_docs(
-                result_docs, length_func, token_max, **kwargs
+                result_docs, length_func, _token_max, **kwargs
            )
             result_docs = []
             for docs in new_result_doc_list:
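The async collapse mirrors the sync one line for line, so the same override semantics apply to acombine_docs. A hedged sketch, reusing the chain and docs from the earlier MapReduceDocumentsChain example and assuming the stand-in model supports async generation (swap in any async-capable LLM otherwise):

    import asyncio

    async def main() -> None:
        # Same per-call override semantics as combine_docs.
        output, _extra = await chain.acombine_docs(docs, token_max=2000)
        print(output)

    asyncio.run(main())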
@@ -79,6 +79,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)
@@ -111,6 +112,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
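This loader builds its map step from question_prompt, which matches the qa-with-sources map-reduce loader; with token_max now threaded through to ReduceDocumentsChain, callers can set the budget from the public entry point. A hedged usage sketch (FakeListLLM as a stand-in; the kwarg only applies to the map_reduce chain type):

    from langchain.chains.qa_with_sources.loading import load_qa_with_sources_chain
    from langchain.llms.fake import FakeListLLM

    qa_sources_chain = load_qa_with_sources_chain(
        FakeListLLM(responses=["answer"]),
        chain_type="map_reduce",
        token_max=1000,  # forwarded to ReduceDocumentsChain by this commit
    )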
@@ -99,6 +99,7 @@ def _load_map_reduce_chain(
     verbose: Optional[bool] = None,
     callback_manager: Optional[BaseCallbackManager] = None,
     callbacks: Callbacks = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     _question_prompt = (
@@ -154,6 +155,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
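This variant, with callback_manager and the _question_prompt fallback, matches the question-answering map-reduce loader, so the same kwarg works there. A hedged sketch:

    from langchain.chains.question_answering import load_qa_chain
    from langchain.llms.fake import FakeListLLM

    qa_chain = load_qa_chain(
        FakeListLLM(responses=["answer"]),
        chain_type="map_reduce",
        token_max=1000,
    )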
@@ -48,6 +48,7 @@ def _load_map_reduce_chain(
     reduce_llm: Optional[BaseLanguageModel] = None,
     collapse_llm: Optional[BaseLanguageModel] = None,
     verbose: Optional[bool] = None,
+    token_max: int = 3000,
     **kwargs: Any,
 ) -> MapReduceDocumentsChain:
     map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
@@ -79,6 +80,8 @@ def _load_map_reduce_chain(
     reduce_documents_chain = ReduceDocumentsChain(
         combine_documents_chain=combine_documents_chain,
         collapse_documents_chain=collapse_chain,
+        token_max=token_max,
+        verbose=verbose,
     )
     return MapReduceDocumentsChain(
         llm_chain=map_chain,
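The last loader builds its map step from map_prompt, which matches the summarize loader, completing the rollout across all three map-reduce entry points. A hedged sketch:

    from langchain.chains.summarize import load_summarize_chain
    from langchain.llms.fake import FakeListLLM

    summarize_chain = load_summarize_chain(
        FakeListLLM(responses=["summary"]),
        chain_type="map_reduce",
        token_max=1000,
    )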