diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py
index 6b5f0741..5cec8cc2 100644
--- a/langchain/chains/conversational_retrieval/base.py
+++ b/langchain/chains/conversational_retrieval/base.py
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
 
 from langchain.chains.base import Chain
 from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
 from langchain.chains.llm import LLMChain
 from langchain.chains.question_answering import load_qa_chain
@@ -116,9 +117,31 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
     """Chain for chatting with an index."""
 
     retriever: BaseRetriever
+    """Index to connect to."""
+    max_tokens_limit: Optional[int] = None
+    """If set, restricts the docs to return from store based on tokens, enforced only
+    for StuffDocumentsChain"""
+
+    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+        num_docs = len(docs)
+
+        if self.max_tokens_limit and isinstance(
+            self.combine_docs_chain, StuffDocumentsChain
+        ):
+            tokens = [
+                self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
+                for doc in docs
+            ]
+            token_count = sum(tokens[:num_docs])
+            while token_count > self.max_tokens_limit:
+                num_docs -= 1
+                token_count -= tokens[num_docs]
+
+        return docs[:num_docs]
 
     def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
-        return self.retriever.get_relevant_documents(question)
+        docs = self.retriever.get_relevant_documents(question)
+        return self._reduce_tokens_below_limit(docs)
 
     @classmethod
     def from_llm(