# NOTE(review): reconstructed from a whitespace-mangled `git format-patch` email
# (commit 0874872, "add token reduction to ConversationalRetrievalChain", #2075).
# The patch adds the fields and methods below to ConversationalRetrievalChain
# (class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel))
# in langchain/chains/conversational_retrieval/base.py.

retriever: BaseRetriever
"""Index to connect to."""
max_tokens_limit: Optional[int] = None
"""If set, restricts the docs to return from store based on tokens, enforced only
for StuffDocumentsChain."""

def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
    """Drop trailing documents until their combined token count fits the limit.

    Trimming is performed only when ``max_tokens_limit`` is set and the
    combine chain is a ``StuffDocumentsChain`` (the chain that stuffs every
    document into a single prompt); otherwise ``docs`` is returned unchanged.

    :param docs: Documents returned by the retriever, in relevance order.
    :return: A (possibly shorter) prefix of ``docs`` whose token total does
        not exceed ``max_tokens_limit``.
    """
    num_docs = len(docs)

    if self.max_tokens_limit and isinstance(
        self.combine_docs_chain, StuffDocumentsChain
    ):
        tokens = [
            self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
            for doc in docs
        ]
        # num_docs == len(docs) here, so summing the whole list is equivalent
        # to the original `sum(tokens[:num_docs])` without the redundant slice.
        token_count = sum(tokens)
        # Discard the least-relevant documents (from the end) until we fit.
        while token_count > self.max_tokens_limit:
            num_docs -= 1
            token_count -= tokens[num_docs]

    return docs[:num_docs]

def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    """Retrieve documents relevant to ``question``, trimmed to the token limit."""
    docs = self.retriever.get_relevant_documents(question)
    return self._reduce_tokens_below_limit(docs)