Use run and arun in place of combine_docs and acombine_docs (#2635)

`combine_docs` does not go through the standard chain call path which
means that chain callbacks won't be triggered, meaning QA chains won't
be traced properly, this fixes that.

Also fix several errors in the chat_vector_db notebook
This commit is contained in:
Ankush Gola 2023-04-10 03:47:59 +02:00 committed by GitHub
parent 50c511d75f
commit b82cbd1be0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 105 additions and 69 deletions

View File

@ -5,14 +5,14 @@
"id": "134a0785",
"metadata": {},
"source": [
"# Chat Index\n",
"# Chat Over Documents with Chat History\n",
"\n",
"This notebook goes over how to set up a chain to chat with an index. The only difference between this chain and the [RetrievalQAChain](./vector_db_qa.ipynb) is that this allows for passing in of a chat history which can be used to allow for follow up questions."
"This notebook goes over how to set up a chain to chat over documents with chat history using a `ConversationalRetrievalChain`. The only difference between this chain and the [RetrievalQAChain](./vector_db_qa.ipynb) is that this allows for passing in of a chat history which can be used to allow for follow up questions."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "70c4e529",
"metadata": {
"tags": []
@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"id": "01c46e92",
"metadata": {
"tags": []
@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "433363a5",
"metadata": {
"tags": []
@ -81,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "a8930cf7",
"metadata": {
"tags": []
@ -109,12 +109,12 @@
"id": "3c96b118",
"metadata": {},
"source": [
"We now initialize the ConversationalRetrievalChain"
"We now initialize the `ConversationalRetrievalChain`"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "7b4110f3",
"metadata": {
"tags": []
@ -134,7 +134,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 8,
"id": "7fe3e730",
"metadata": {
"tags": []
@ -148,7 +148,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 9,
"id": "bfff9cc8",
"metadata": {
"tags": []
@ -160,7 +160,7 @@
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
]
},
"execution_count": 7,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -179,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 10,
"id": "00b4cf00",
"metadata": {
"tags": []
@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"id": "f01828d1",
"metadata": {
"tags": []
@ -202,10 +202,10 @@
{
"data": {
"text/plain": [
"' Justice Stephen Breyer'"
"' Ketanji Brown Jackson succeeded Justice Stephen Breyer on the United States Supreme Court.'"
]
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -225,9 +225,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "562769c6",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), return_source_documents=True)"
@ -235,9 +237,11 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "ea478300",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = []\n",
@ -247,17 +251,19 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "4cb75b4e",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0)"
"Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while youre at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, Id like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nations top legal minds, who will continue Justice Breyers legacy of excellence.', metadata={'source': '../../state_of_the_union.txt'})"
]
},
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -277,9 +283,11 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "5ed8d612",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"vectordbkwargs = {\"search_distance\": 0.9}"
@ -287,9 +295,11 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "6a7b3459",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), return_source_documents=True)\n",
@ -309,21 +319,25 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "e53a9d66",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.chains.chat_index.prompts import CONDENSE_QUESTION_PROMPT"
"from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bf205e35",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
@ -341,7 +355,9 @@
"cell_type": "code",
"execution_count": 20,
"id": "78155887",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = []\n",
@ -353,7 +369,9 @@
"cell_type": "code",
"execution_count": 21,
"id": "e54b5fa2",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
@ -384,7 +402,9 @@
"cell_type": "code",
"execution_count": 22,
"id": "d1058fd2",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chains.qa_with_sources import load_qa_with_sources_chain"
@ -394,7 +414,9 @@
"cell_type": "code",
"execution_count": 23,
"id": "a6594482",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = OpenAI(temperature=0)\n",
@ -412,7 +434,9 @@
"cell_type": "code",
"execution_count": 24,
"id": "e2badd21",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = []\n",
@ -424,7 +448,9 @@
"cell_type": "code",
"execution_count": 25,
"id": "edb31fe5",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
@ -453,7 +479,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"id": "2efacec3-2690-4b05-8de3-a32fd2ac3911",
"metadata": {
"tags": []
@ -463,7 +489,7 @@
"from langchain.chains.llm import LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
"from langchain.chains.chat_index.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
"from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"\n",
"# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n",
@ -480,7 +506,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"id": "fd6d43f4-7428-44a4-81bc-26fe88a98762",
"metadata": {
"tags": []
@ -502,7 +528,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"id": "5ab38978-f3e8-4fa7-808c-c79dec48379a",
"metadata": {
"tags": []
@ -512,7 +538,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
" Justice Stephen Breyer"
" Ketanji Brown Jackson succeeded Justice Stephen Breyer on the United States Supreme Court."
]
}
],
@ -533,9 +559,11 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 31,
"id": "a7ba9d8c",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def get_chat_history(inputs) -> str:\n",
@ -543,14 +571,16 @@
" for human, ai in inputs:\n",
" res.append(f\"Human:{human}\\nAI:{ai}\")\n",
" return \"\\n\".join(res)\n",
"qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore, get_chat_history=get_chat_history)"
"qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), get_chat_history=get_chat_history)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 32,
"id": "a3e33c0d",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"chat_history = []\n",
@ -560,9 +590,11 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 33,
"id": "936dc62f",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
@ -570,7 +602,7 @@
"\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\""
]
},
"execution_count": 31,
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
@ -604,7 +636,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

View File

@ -14,7 +14,7 @@ from langchain.docstore.document import Document
class CombineDocsProtocol(Protocol):
"""Interface for the combine_docs method."""
def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
def __call__(self, docs: List[Document], **kwargs: Any) -> str:
"""Interface for the combine_docs method."""
@ -48,7 +48,7 @@ def _collapse_docs(
combine_document_func: CombineDocsProtocol,
**kwargs: Any,
) -> Document:
result, _ = combine_document_func(docs, **kwargs)
result = combine_document_func(docs, **kwargs)
combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
for doc in docs[1:]:
for k, v in doc.metadata.items():
@ -171,15 +171,17 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
]
length_func = self.combine_document_chain.prompt_length
num_tokens = length_func(result_docs, **kwargs)
def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str:
return self._collapse_chain.run(input_documents=docs, **kwargs)
while num_tokens is not None and num_tokens > token_max:
new_result_doc_list = _split_list_of_docs(
result_docs, length_func, token_max, **kwargs
)
result_docs = []
for docs in new_result_doc_list:
new_doc = _collapse_docs(
docs, self._collapse_chain.combine_docs, **kwargs
)
new_doc = _collapse_docs(docs, _collapse_docs_func, **kwargs)
result_docs.append(new_doc)
num_tokens = self.combine_document_chain.prompt_length(
result_docs, **kwargs
@ -189,7 +191,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
extra_return_dict = {"intermediate_steps": _results}
else:
extra_return_dict = {}
output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs)
output = self.combine_document_chain.run(input_documents=result_docs, **kwargs)
return output, extra_return_dict
@property

View File

@ -80,7 +80,7 @@ class BaseConversationalRetrievalChain(Chain):
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs)
answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
@ -104,7 +104,7 @@ class BaseConversationalRetrievalChain(Chain):
new_inputs = inputs.copy()
new_inputs["question"] = new_question
new_inputs["chat_history"] = chat_history_str
answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs)
answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:

View File

@ -70,5 +70,5 @@ class MapReduceChain(Chain):
# Split the larger text into smaller chunks.
texts = self.text_splitter.split_text(inputs[self.input_key])
docs = [Document(page_content=text) for text in texts]
outputs, _ = self.combine_documents_chain.combine_docs(docs)
outputs = self.combine_documents_chain.run(input_documents=docs)
return {self.output_key: outputs}

View File

@ -116,7 +116,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
docs = self._get_docs(inputs)
answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs)
answer = self.combine_documents_chain.run(input_documents=docs, **inputs)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:
@ -135,7 +135,7 @@ class BaseQAWithSourcesChain(Chain, ABC):
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
docs = await self._aget_docs(inputs)
answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs)
answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs)
if re.search(r"SOURCES:\s", answer):
answer, sources = re.split(r"SOURCES:\s", answer)
else:

View File

@ -107,7 +107,9 @@ class BaseRetrievalQA(Chain):
question = inputs[self.input_key]
docs = self._get_docs(question)
answer, _ = self.combine_documents_chain.combine_docs(docs, question=question)
answer = self.combine_documents_chain.run(
input_documents=docs, question=question
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
@ -133,8 +135,8 @@ class BaseRetrievalQA(Chain):
question = inputs[self.input_key]
docs = await self._aget_docs(question)
answer, _ = await self.combine_documents_chain.acombine_docs(
docs, question=question
answer = await self.combine_documents_chain.arun(
input_documents=docs, question=question
)
if self.return_source_documents:

View File

@ -1,6 +1,6 @@
"""Test functionality related to combining documents."""
from typing import Any, List, Tuple
from typing import Any, List
import pytest
@ -12,11 +12,11 @@ from langchain.docstore.document import Document
def _fake_docs_len_func(docs: List[Document]) -> int:
return len(_fake_combine_docs_func(docs)[0])
return len(_fake_combine_docs_func(docs))
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]:
return "".join([d.page_content for d in docs]), {}
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str:
return "".join([d.page_content for d in docs])
def test__split_list_long_single_doc() -> None: