mirror of https://github.com/hwchase17/langchain
11474 (#11519)
No relevant documents may be found for a given question. In that case, instead of making an LLM call with an empty context, we can respond directly with a fixed message. This PR exposes that as an option: response_if_no_docs_found.

---------

Co-authored-by: Sudharsan Rangarajan <sudranga@nile-global.com>
parent 1f7edcd08b
commit 9beb03e771
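For context, a minimal usage sketch of the new option. It mirrors the first test added in this commit: the fake LLM and the test-only SequentialRetriever stand in for real components, and the memory setup is copied from the diff below; nothing beyond what the diff shows is assumed.

from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.llms.fake import FakeListLLM
from langchain.memory.buffer import ConversationBufferMemory
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever

# A retriever configured to return no documents for the first query.
retriever = SequentialRetriever(sequential_responses=[[]])
llm = FakeListLLM(responses=["unused"])
memory = ConversationBufferMemory(
    output_key="answer", memory_key="chat_history", return_messages=True
)
chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=retriever,
    memory=memory,
    # Returned verbatim when retrieval comes back empty; no LLM call is made.
    response_if_no_docs_found="I don't know",
)
result = chain("What is the answer?")
assert result["answer"] == "I don't know"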
@@ -0,0 +1,52 @@
"""Test conversation chain and memory."""
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.llms.fake import FakeListLLM
from langchain.memory.buffer import ConversationBufferMemory
from langchain.schema import Document
from tests.unit_tests.retrievers.sequential_retriever import SequentialRetriever


def test_fixed_message_response_when_no_docs_found() -> None:
    fixed_resp = "I don't know"
    answer = "I know the answer!"
    llm = FakeListLLM(responses=[answer])
    # The retriever returns no documents, so the chain should fall back
    # to the fixed response instead of calling the LLM.
    retriever = SequentialRetriever(sequential_responses=[[]])
    memory = ConversationBufferMemory(
        k=1, output_key="answer", memory_key="chat_history", return_messages=True
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        return_source_documents=True,
        rephrase_question=False,
        response_if_no_docs_found=fixed_resp,
        verbose=True,
    )
    got = qa_chain("What is the answer?")
    assert got["chat_history"][1].content == fixed_resp
    assert got["answer"] == fixed_resp


def test_fixed_message_response_when_docs_found() -> None:
    fixed_resp = "I don't know"
    answer = "I know the answer!"
    llm = FakeListLLM(responses=[answer])
    # The retriever returns a document, so the LLM answer should be used
    # and the fixed response ignored.
    retriever = SequentialRetriever(
        sequential_responses=[[Document(page_content=answer)]]
    )
    memory = ConversationBufferMemory(
        k=1, output_key="answer", memory_key="chat_history", return_messages=True
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        memory=memory,
        retriever=retriever,
        return_source_documents=True,
        rephrase_question=False,
        response_if_no_docs_found=fixed_resp,
        verbose=True,
    )
    got = qa_chain("What is the answer?")
    assert got["chat_history"][1].content == answer
    assert got["answer"] == answer
tests/unit_tests/retrievers/sequential_retriever.py
@@ -0,0 +1,26 @@
from typing import List

from langchain.schema import BaseRetriever, Document


class SequentialRetriever(BaseRetriever):
    """Test util that returns a sequence of documents."""

    sequential_responses: List[List[Document]]
    response_index: int = 0

    def _get_relevant_documents(  # type: ignore[override]
        self,
        query: str,
    ) -> List[Document]:
        # Return each configured response once, in order; after the
        # responses are exhausted, return an empty list.
        if self.response_index >= len(self.sequential_responses):
            return []
        else:
            self.response_index += 1
            return self.sequential_responses[self.response_index - 1]

    async def _aget_relevant_documents(  # type: ignore[override]
        self,
        query: str,
    ) -> List[Document]:
        return self._get_relevant_documents(query)
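A small illustration of the draining behavior: the retriever yields each configured response once, then empty lists. This assumes the public get_relevant_documents wrapper that BaseRetriever exposes around the private method above.

retriever = SequentialRetriever(
    sequential_responses=[[Document(page_content="first")]]
)
retriever.get_relevant_documents("q")  # -> [Document(page_content="first")]
retriever.get_relevant_documents("q")  # -> []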