From 28efbb05bf2f76e99853b6cde7c2d52059054149 Mon Sep 17 00:00:00 2001
From: Smit Shah
Date: Fri, 27 Jan 2023 09:13:01 +0530
Subject: [PATCH] Add params to reduce K dynamically to keep it below the
 token limit (#739)

Referring to #687, I implemented the functionality to reduce K when the
retrieved documents would exceed the token limit.

Edit: I should have run `make lint` locally. Also, this only applies to
`StuffDocumentsChain`.
---
 langchain/chains/qa_with_sources/vector_db.py | 29 ++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/langchain/chains/qa_with_sources/vector_db.py b/langchain/chains/qa_with_sources/vector_db.py
index 78cdf2fa..6da7ff33 100644
--- a/langchain/chains/qa_with_sources/vector_db.py
+++ b/langchain/chains/qa_with_sources/vector_db.py
@@ -1,8 +1,10 @@
 """Question-answering with sources over a vector database."""
+
 from typing import Any, Dict, List
 
 from pydantic import BaseModel, Field
 
+from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
 from langchain.docstore.document import Document
 from langchain.vectorstores.base import VectorStore
@@ -15,11 +17,36 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain, BaseModel):
     """Vector Database to connect to."""
     k: int = 4
     """Number of results to return from store"""
+    reduce_k_below_max_tokens: bool = False
+    """Reduce the number of results to return from store based on the token limit"""
+    max_tokens_limit: int = 3375
+    """Restrict the docs to return from store based on tokens,
+    enforced only for StuffDocumentsChain and if reduce_k_below_max_tokens is set to true"""
     search_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Extra search args."""
 
+    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
+        num_docs = len(docs)
+
+        if self.reduce_k_below_max_tokens and isinstance(
+            self.combine_documents_chain, StuffDocumentsChain
+        ):
+            tokens = [
+                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
+                    doc.page_content
+                )
+                for doc in docs
+            ]
+            token_count = sum(tokens[:num_docs])
+            while token_count > self.max_tokens_limit:
+                num_docs -= 1
+                token_count -= tokens[num_docs]
+
+        return docs[:num_docs]
+
     def _get_docs(self, inputs: Dict[str, Any]) -> List[Document]:
         question = inputs[self.question_key]
-        return self.vectorstore.similarity_search(
+        docs = self.vectorstore.similarity_search(
+            question, k=self.k, **self.search_kwargs
         )
+        return self._reduce_tokens_below_limit(docs)
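
For reviewers, a minimal usage sketch of the new flags. It assumes the chain's inherited `from_llm` constructor forwards extra keyword arguments to the model fields; the FAISS store, corpus, and question are purely illustrative:

```python
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.vectorstores.faiss import FAISS

# Illustrative corpus; any VectorStore implementation works the same way.
texts = [
    "LangChain provides chains for question answering with sources.",
    "FAISS is a library for efficient similarity search.",
    "StuffDocumentsChain stuffs all retrieved documents into one prompt.",
]
docsearch = FAISS.from_texts(texts, OpenAIEmbeddings())

chain = VectorDBQAWithSourcesChain.from_llm(
    OpenAI(temperature=0),
    vectorstore=docsearch,
    k=10,                            # deliberately over-fetch candidates
    reduce_k_below_max_tokens=True,  # new flag from this patch
    max_tokens_limit=3375,           # default; leaves headroom for the prompt
)
result = chain({"question": "What does StuffDocumentsChain do?"})
print(result["answer"], result["sources"])
```

With `reduce_k_below_max_tokens=True`, over-fetching with a large `k` becomes safe: the chain trims the result list before stuffing it into the prompt.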
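
And a standalone sketch of what `_reduce_tokens_below_limit` does, using made-up per-document token counts. Because `similarity_search` returns documents ordered by relevance, trimming from the end of the list drops the least relevant documents first:

```python
# Hypothetical token counts for 4 retrieved docs, most relevant first.
tokens = [1200, 1100, 900, 800]
max_tokens_limit = 3375

num_docs = len(tokens)
token_count = sum(tokens)            # 4000: over the limit
while token_count > max_tokens_limit:
    num_docs -= 1                    # drop the lowest-ranked doc...
    token_count -= tokens[num_docs]  # ...and stop counting its tokens

print(num_docs, token_count)         # 3 3200: the 800-token doc was dropped
```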