@ -10,6 +10,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
from langchain . chains . base import Chain
from langchain . chains . combine_documents . base import BaseCombineDocumentsChain
from langchain . chains . combine_documents . stuff import StuffDocumentsChain
from langchain . chains . conversational_retrieval . prompts import CONDENSE_QUESTION_PROMPT
from langchain . chains . llm import LLMChain
from langchain . chains . question_answering import load_qa_chain
@ -116,9 +117,31 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
""" Chain for chatting with an index. """
# Retriever that _get_docs queries for documents relevant to the question.
retriever : BaseRetriever
""" Index to connect to. """
# Optional token budget consumed by _reduce_tokens_below_limit. A falsy value
# (None or 0) disables trimming, and the budget only takes effect when
# combine_docs_chain is a StuffDocumentsChain.
max_tokens_limit : Optional [ int ] = None
""" If set, restricts the docs to return from store based on tokens, enforced only
for StuffDocumentChain """
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
    """Drop trailing documents until their combined token count fits the limit.

    The limit is enforced only when ``max_tokens_limit`` is set (truthy) and
    ``combine_docs_chain`` is a ``StuffDocumentsChain``; otherwise all
    documents are returned unchanged.

    Args:
        docs: Candidate documents, in the order they should be kept.

    Returns:
        A new list containing the leading documents of ``docs`` that remain
        after documents are removed from the end until the summed token count
        no longer exceeds ``max_tokens_limit``. May be empty.
    """
    limit = self.max_tokens_limit
    if not limit or not isinstance(self.combine_docs_chain, StuffDocumentsChain):
        return docs[:]

    llm = self.combine_docs_chain.llm_chain.llm
    token_counts = [llm.get_num_tokens(doc.page_content) for doc in docs]

    num_docs = len(docs)
    token_count = sum(token_counts)
    # Remove documents from the end while over budget. The ``num_docs > 0``
    # guard keeps the index from going negative (which would wrap around and
    # eventually raise IndexError) when the limit can never be satisfied,
    # e.g. a negative ``max_tokens_limit``.
    while num_docs > 0 and token_count > limit:
        num_docs -= 1
        token_count -= token_counts[num_docs]
    return docs[:num_docs]
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
    """Retrieve documents for ``question`` and trim them to the token limit.

    Args:
        question: Query string passed to the retriever.
        inputs: Chain inputs; not used by this implementation.

    Returns:
        The retriever's documents after ``_reduce_tokens_below_limit`` has
        been applied.
    """
    # The earlier unconditional ``return`` of the raw retriever result made
    # the token-limit step unreachable; the result must pass through it.
    docs = self.retriever.get_relevant_documents(question)
    return self._reduce_tokens_below_limit(docs)
@classmethod
def from_llm (