add token reduction to ConversationalRetrievalChain (#2075)

This worked for me, but I'm not sure if its the right way to approach
something like this, so I'm open to suggestions.

Adds class properties `reduce_k_below_max_tokens: bool` and
`max_tokens_limit: int` to the `ConversationalRetrievalChain`. The code
is basically copied from
[`RetreivalQAWithSourcesChain`](46d141c6cb/langchain/chains/qa_with_sources/retrieval.py (L24))
This commit is contained in:
Nick 2023-03-28 18:07:31 -04:00 committed by GitHub
parent ef25904ecb
commit 0874872dee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain from langchain.chains.question_answering import load_qa_chain
@ -116,9 +117,31 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
"""Chain for chatting with an index.""" """Chain for chatting with an index."""
retriever: BaseRetriever retriever: BaseRetriever
"""Index to connect to."""
max_tokens_limit: Optional[int] = None
"""If set, restricts the docs to return from store based on tokens, enforced only
for StuffDocumentChain"""
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
num_docs = len(docs)
if self.max_tokens_limit and isinstance(
self.combine_docs_chain, StuffDocumentsChain
):
tokens = [
self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
for doc in docs
]
token_count = sum(tokens[:num_docs])
while token_count > self.max_tokens_limit:
num_docs -= 1
token_count -= tokens[num_docs]
return docs[:num_docs]
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]: def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
return self.retriever.get_relevant_documents(question) docs = self.retriever.get_relevant_documents(question)
return self._reduce_tokens_below_limit(docs)
@classmethod @classmethod
def from_llm( def from_llm(