mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
add token reduction to ConversationalRetrievalChain (#2075)
This worked for me, but I'm not sure if its the right way to approach
something like this, so I'm open to suggestions.
Adds class properties `reduce_k_below_max_tokens: bool` and
`max_tokens_limit: int` to the `ConversationalRetrievalChain`. The code
is basically copied from
[`RetreivalQAWithSourcesChain`](46d141c6cb/langchain/chains/qa_with_sources/retrieval.py (L24)
)
This commit is contained in:
parent
ef25904ecb
commit
0874872dee
@ -10,6 +10,7 @@ from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
@ -116,9 +117,31 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain, BaseModel):
|
||||
"""Chain for chatting with an index."""
|
||||
|
||||
retriever: BaseRetriever
|
||||
"""Index to connect to."""
|
||||
max_tokens_limit: Optional[int] = None
|
||||
"""If set, restricts the docs to return from store based on tokens, enforced only
|
||||
for StuffDocumentChain"""
|
||||
|
||||
def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
|
||||
num_docs = len(docs)
|
||||
|
||||
if self.max_tokens_limit and isinstance(
|
||||
self.combine_docs_chain, StuffDocumentsChain
|
||||
):
|
||||
tokens = [
|
||||
self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
|
||||
for doc in docs
|
||||
]
|
||||
token_count = sum(tokens[:num_docs])
|
||||
while token_count > self.max_tokens_limit:
|
||||
num_docs -= 1
|
||||
token_count -= tokens[num_docs]
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
def _get_docs(self, question: str, inputs: Dict[str, Any]) -> List[Document]:
|
||||
return self.retriever.get_relevant_documents(question)
|
||||
docs = self.retriever.get_relevant_documents(question)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@classmethod
|
||||
def from_llm(
|
||||
|
Loading…
Reference in New Issue
Block a user