No relevant documents may be found for a given question. In some use
cases, we could respond directly with a fixed message instead of making
an LLM call with an empty context. This PR exposes that behavior as an
option: response_if_no_docs_found.
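
For context, a minimal usage sketch assembled from the unit tests in this
PR; FakeListLLM and SequentialRetriever are test doubles, and a real
application would supply its own LLM and retriever:

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.llms.fake import FakeListLLM
from langchain.memory.buffer import ConversationBufferMemory
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever

# A retriever scripted to return no documents, so the fixed response kicks in.
retriever = SequentialRetriever(sequential_responses=[[]])
llm = FakeListLLM(responses=["I know the answer!"])  # never invoked in this case
memory = ConversationBufferMemory(
    output_key="answer", memory_key="chat_history", return_messages=True
)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    # New option from this PR: skip the combine-docs LLM call and return
    # this string whenever the retriever comes back with zero documents.
    response_if_no_docs_found="I don't know",
)
result = qa_chain("What is the answer?")
assert result["answer"] == "I don't know"  # fixed message, no LLM call made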

---------

Co-authored-by: Sudharsan Rangarajan <sudranga@nile-global.com>

@@ -79,6 +79,9 @@ class BaseConversationalRetrievalChain(Chain):
     get_chat_history: Optional[Callable[[List[CHAT_TURN_TYPE]], str]] = None
     """An optional function to get a string of the chat history.
     If None is provided, will use a default."""
+    response_if_no_docs_found: Optional[str]
+    """If specified, the chain will return a fixed response if no docs
+    are found for the question."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -143,14 +146,19 @@ class BaseConversationalRetrievalChain(Chain):
             docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
         else:
             docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
-        new_inputs = inputs.copy()
-        if self.rephrase_question:
-            new_inputs["question"] = new_question
-        new_inputs["chat_history"] = chat_history_str
-        answer = self.combine_docs_chain.run(
-            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
-        )
-        output: Dict[str, Any] = {self.output_key: answer}
+        output: Dict[str, Any] = {}
+        if self.response_if_no_docs_found is not None and len(docs) == 0:
+            output[self.output_key] = self.response_if_no_docs_found
+        else:
+            new_inputs = inputs.copy()
+            if self.rephrase_question:
+                new_inputs["question"] = new_question
+            new_inputs["chat_history"] = chat_history_str
+            answer = self.combine_docs_chain.run(
+                input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+            )
+            output[self.output_key] = answer
+
         if self.return_source_documents:
             output["source_documents"] = docs
         if self.return_generated_question:

@@ -0,0 +1,52 @@
+"""Test conversation chain and memory."""
+from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
+from langchain.llms.fake import FakeListLLM
+from langchain.memory.buffer import ConversationBufferMemory
+from langchain.schema import Document
+from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever
+
+
+def test_fixed_message_response_when_no_docs_found() -> None:
+    fixed_resp = "I don't know"
+    answer = "I know the answer!"
+    llm = FakeListLLM(responses=[answer])
+    retriever = SequentialRetriever(sequential_responses=[[]])
+    memory = ConversationBufferMemory(
+        k=1, output_key="answer", memory_key="chat_history", return_messages=True
+    )
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever,
+        return_source_documents=True,
+        rephrase_question=False,
+        response_if_no_docs_found=fixed_resp,
+        verbose=True,
+    )
+    got = qa_chain("What is the answer?")
+    assert got["chat_history"][1].content == fixed_resp
+    assert got["answer"] == fixed_resp
+
+
+def test_fixed_message_response_when_docs_found() -> None:
+    fixed_resp = "I don't know"
+    answer = "I know the answer!"
+    llm = FakeListLLM(responses=[answer])
+    retriever = SequentialRetriever(
+        sequential_responses=[[Document(page_content=answer)]]
+    )
+    memory = ConversationBufferMemory(
+        k=1, output_key="answer", memory_key="chat_history", return_messages=True
+    )
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        memory=memory,
+        retriever=retriever,
+        return_source_documents=True,
+        rephrase_question=False,
+        response_if_no_docs_found=fixed_resp,
+        verbose=True,
+    )
+    got = qa_chain("What is the answer?")
+    assert got["chat_history"][1].content == answer
+    assert got["answer"] == answer

@@ -0,0 +1,26 @@
+from typing import List
+
+from langchain.schema import BaseRetriever, Document
+
+
+class SequentialRetriever(BaseRetriever):
+    """Test util that returns a sequence of documents"""
+
+    sequential_responses: List[List[Document]]
+    response_index: int = 0
+
+    def _get_relevant_documents(  # type: ignore[override]
+        self,
+        query: str,
+    ) -> List[Document]:
+        if self.response_index >= len(self.sequential_responses):
+            return []
+        else:
+            self.response_index += 1
+            return self.sequential_responses[self.response_index - 1]
+
+    async def _aget_relevant_documents(  # type: ignore[override]
+        self,
+        query: str,
+    ) -> List[Document]:
+        return self._get_relevant_documents(query)
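
As an aside, a small hypothetical snippet showing how this test util
behaves: each call consumes the next scripted batch, and once the
responses are exhausted it keeps returning an empty list
(get_relevant_documents is the public BaseRetriever entry point that
dispatches to the method above):

from langchain.schema import Document

retriever = SequentialRetriever(
    sequential_responses=[
        [Document(page_content="scripted hit")],
        [],
    ]
)
assert len(retriever.get_relevant_documents("q")) == 1  # first scripted batch
assert retriever.get_relevant_documents("q") == []  # second batch is empty
assert retriever.get_relevant_documents("q") == []  # exhausted: always []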