From b82cbd1be0c6066e0aa0226f77f115bdfb91ceca Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Mon, 10 Apr 2023 03:47:59 +0200 Subject: [PATCH] Use `run` and `arun` in place of `combine_docs` and `acombine_docs` (#2635) `combine_docs` does not go through the standard chain call path, so chain callbacks are not triggered and QA chains are not traced properly. Calling `run`/`arun` instead fixes that. Also fix several errors in the chat_vector_db notebook. --- .../index_examples/chat_vector_db.ipynb | 134 +++++++++++------- .../chains/combine_documents/map_reduce.py | 14 +- .../chains/conversational_retrieval/base.py | 4 +- langchain/chains/mapreduce.py | 2 +- langchain/chains/qa_with_sources/base.py | 4 +- langchain/chains/retrieval_qa/base.py | 8 +- .../chains/test_combine_documents.py | 8 +- 7 files changed, 105 insertions(+), 69 deletions(-) diff --git a/docs/modules/chains/index_examples/chat_vector_db.ipynb b/docs/modules/chains/index_examples/chat_vector_db.ipynb index ffb3e316..b5e28b51 100644 --- a/docs/modules/chains/index_examples/chat_vector_db.ipynb +++ b/docs/modules/chains/index_examples/chat_vector_db.ipynb @@ -5,14 +5,14 @@ "id": "134a0785", "metadata": {}, "source": [ - "# Chat Index\n", + "# Chat Over Documents with Chat History\n", "\n", - "This notebook goes over how to set up a chain to chat with an index. The only difference between this chain and the [RetrievalQAChain](./vector_db_qa.ipynb) is that this allows for passing in of a chat history which can be used to allow for follow up questions." + "This notebook goes over how to set up a chain to chat over documents with chat history using a `ConversationalRetrievalChain`. The only difference between this chain and the [RetrievalQAChain](./vector_db_qa.ipynb) is that this allows for passing in of a chat history which can be used to allow for follow up questions." 
] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "70c4e529", "metadata": { "tags": [] @@ -36,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "id": "01c46e92", "metadata": { "tags": [] @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "433363a5", "metadata": { "tags": [] @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "a8930cf7", "metadata": { "tags": [] @@ -109,12 +109,12 @@ "id": "3c96b118", "metadata": {}, "source": [ - "We now initialize the ConversationalRetrievalChain" + "We now initialize the `ConversationalRetrievalChain`" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "7b4110f3", "metadata": { "tags": [] @@ -134,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "7fe3e730", "metadata": { "tags": [] @@ -148,7 +148,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "bfff9cc8", "metadata": { "tags": [] @@ -160,7 +160,7 @@ "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. 
He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "00b4cf00", "metadata": { "tags": [] @@ -193,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "f01828d1", "metadata": { "tags": [] @@ -202,10 +202,10 @@ { "data": { "text/plain": [ - "' Justice Stephen Breyer'" + "' Ketanji Brown Jackson succeeded Justice Stephen Breyer on the United States Supreme Court.'" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -225,9 +225,11 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "562769c6", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), return_source_documents=True)" @@ -235,9 +237,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "ea478300", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = []\n", @@ -247,17 +251,19 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "4cb75b4e", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { "text/plain": [ - "Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. 
\\n\\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', lookup_str='', metadata={'source': '../../state_of_the_union.txt'}, lookup_index=0)" + "Document(page_content='Tonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections. \\n\\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service. \\n\\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court. \\n\\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. 
One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.', metadata={'source': '../../state_of_the_union.txt'})" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -277,9 +283,11 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "5ed8d612", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "vectordbkwargs = {\"search_distance\": 0.9}" @@ -287,9 +295,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "6a7b3459", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), return_source_documents=True)\n", @@ -309,21 +319,25 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "id": "e53a9d66", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from langchain.chains import LLMChain\n", "from langchain.chains.question_answering import load_qa_chain\n", - "from langchain.chains.chat_index.prompts import CONDENSE_QUESTION_PROMPT" + "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT" ] }, { "cell_type": "code", "execution_count": 19, "id": "bf205e35", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "llm = OpenAI(temperature=0)\n", @@ -341,7 +355,9 @@ "cell_type": "code", "execution_count": 20, "id": "78155887", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = []\n", @@ -353,7 +369,9 @@ "cell_type": "code", "execution_count": 21, "id": "e54b5fa2", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { @@ -384,7 +402,9 @@ "cell_type": "code", "execution_count": 22, "id": "d1058fd2", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "from 
langchain.chains.qa_with_sources import load_qa_with_sources_chain" @@ -394,7 +414,9 @@ "cell_type": "code", "execution_count": 23, "id": "a6594482", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "llm = OpenAI(temperature=0)\n", @@ -412,7 +434,9 @@ "cell_type": "code", "execution_count": 24, "id": "e2badd21", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = []\n", @@ -424,7 +448,9 @@ "cell_type": "code", "execution_count": 25, "id": "edb31fe5", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { @@ -453,7 +479,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 27, "id": "2efacec3-2690-4b05-8de3-a32fd2ac3911", "metadata": { "tags": [] @@ -463,7 +489,7 @@ "from langchain.chains.llm import LLMChain\n", "from langchain.callbacks.base import CallbackManager\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", - "from langchain.chains.chat_index.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n", + "from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT, QA_PROMPT\n", "from langchain.chains.question_answering import load_qa_chain\n", "\n", "# Construct a ConversationalRetrievalChain with a streaming llm for combine docs\n", @@ -480,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 28, "id": "fd6d43f4-7428-44a4-81bc-26fe88a98762", "metadata": { "tags": [] @@ -502,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 29, "id": "5ab38978-f3e8-4fa7-808c-c79dec48379a", "metadata": { "tags": [] @@ -512,7 +538,7 @@ "name": "stdout", "output_type": "stream", "text": [ - " Justice Stephen Breyer" + " Ketanji Brown Jackson succeeded Justice Stephen Breyer on the United States Supreme Court." 
] } ], @@ -533,9 +559,11 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "id": "a7ba9d8c", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "def get_chat_history(inputs) -> str:\n", @@ -543,14 +571,16 @@ " for human, ai in inputs:\n", " res.append(f\"Human:{human}\\nAI:{ai}\")\n", " return \"\\n\".join(res)\n", - "qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore, get_chat_history=get_chat_history)" + "qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), vectorstore.as_retriever(), get_chat_history=get_chat_history)" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "id": "a3e33c0d", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "chat_history = []\n", @@ -560,9 +590,11 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "id": "936dc62f", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { @@ -570,7 +602,7 @@ "\" The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a former top litigator in private practice, a former federal public defender, and from a family of public school educators and police officers. 
He also said that she is a consensus builder and has received a broad range of support from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\"" ] }, - "execution_count": 31, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -604,7 +636,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index 19d5478b..b439870d 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -14,7 +14,7 @@ from langchain.docstore.document import Document class CombineDocsProtocol(Protocol): """Interface for the combine_docs method.""" - def __call__(self, docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: + def __call__(self, docs: List[Document], **kwargs: Any) -> str: """Interface for the combine_docs method.""" @@ -48,7 +48,7 @@ def _collapse_docs( combine_document_func: CombineDocsProtocol, **kwargs: Any, ) -> Document: - result, _ = combine_document_func(docs, **kwargs) + result = combine_document_func(docs, **kwargs) combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()} for doc in docs[1:]: for k, v in doc.metadata.items(): @@ -171,15 +171,17 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): ] length_func = self.combine_document_chain.prompt_length num_tokens = length_func(result_docs, **kwargs) + + def _collapse_docs_func(docs: List[Document], **kwargs: Any) -> str: + return self._collapse_chain.run(input_documents=docs, **kwargs) + while num_tokens is not None and num_tokens > token_max: new_result_doc_list = _split_list_of_docs( result_docs, length_func, token_max, **kwargs ) result_docs = [] for docs in new_result_doc_list: - new_doc = _collapse_docs( - docs, self._collapse_chain.combine_docs, **kwargs - ) + new_doc 
= _collapse_docs(docs, _collapse_docs_func, **kwargs) result_docs.append(new_doc) num_tokens = self.combine_document_chain.prompt_length( result_docs, **kwargs @@ -189,7 +191,7 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain): extra_return_dict = {"intermediate_steps": _results} else: extra_return_dict = {} - output, _ = self.combine_document_chain.combine_docs(result_docs, **kwargs) + output = self.combine_document_chain.run(input_documents=result_docs, **kwargs) return output, extra_return_dict @property diff --git a/langchain/chains/conversational_retrieval/base.py b/langchain/chains/conversational_retrieval/base.py index b1df5fb1..97424ecb 100644 --- a/langchain/chains/conversational_retrieval/base.py +++ b/langchain/chains/conversational_retrieval/base.py @@ -80,7 +80,7 @@ class BaseConversationalRetrievalChain(Chain): new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer, _ = self.combine_docs_chain.combine_docs(docs, **new_inputs) + answer = self.combine_docs_chain.run(input_documents=docs, **new_inputs) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: @@ -104,7 +104,7 @@ class BaseConversationalRetrievalChain(Chain): new_inputs = inputs.copy() new_inputs["question"] = new_question new_inputs["chat_history"] = chat_history_str - answer, _ = await self.combine_docs_chain.acombine_docs(docs, **new_inputs) + answer = await self.combine_docs_chain.arun(input_documents=docs, **new_inputs) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: diff --git a/langchain/chains/mapreduce.py b/langchain/chains/mapreduce.py index bcaccd3a..062a9431 100644 --- a/langchain/chains/mapreduce.py +++ b/langchain/chains/mapreduce.py @@ -70,5 +70,5 @@ class MapReduceChain(Chain): # Split the larger text into smaller chunks. 
texts = self.text_splitter.split_text(inputs[self.input_key]) docs = [Document(page_content=text) for text in texts] - outputs, _ = self.combine_documents_chain.combine_docs(docs) + outputs = self.combine_documents_chain.run(input_documents=docs) return {self.output_key: outputs} diff --git a/langchain/chains/qa_with_sources/base.py b/langchain/chains/qa_with_sources/base.py index fd3d2372..5c6317ed 100644 --- a/langchain/chains/qa_with_sources/base.py +++ b/langchain/chains/qa_with_sources/base.py @@ -116,7 +116,7 @@ class BaseQAWithSourcesChain(Chain, ABC): def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]: docs = self._get_docs(inputs) - answer, _ = self.combine_documents_chain.combine_docs(docs, **inputs) + answer = self.combine_documents_chain.run(input_documents=docs, **inputs) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: @@ -135,7 +135,7 @@ class BaseQAWithSourcesChain(Chain, ABC): async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, Any]: docs = await self._aget_docs(inputs) - answer, _ = await self.combine_documents_chain.acombine_docs(docs, **inputs) + answer = await self.combine_documents_chain.arun(input_documents=docs, **inputs) if re.search(r"SOURCES:\s", answer): answer, sources = re.split(r"SOURCES:\s", answer) else: diff --git a/langchain/chains/retrieval_qa/base.py b/langchain/chains/retrieval_qa/base.py index cf89c99d..bcc26b16 100644 --- a/langchain/chains/retrieval_qa/base.py +++ b/langchain/chains/retrieval_qa/base.py @@ -107,7 +107,9 @@ class BaseRetrievalQA(Chain): question = inputs[self.input_key] docs = self._get_docs(question) - answer, _ = self.combine_documents_chain.combine_docs(docs, question=question) + answer = self.combine_documents_chain.run( + input_documents=docs, question=question + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} @@ -133,8 +135,8 @@ class BaseRetrievalQA(Chain): question = inputs[self.input_key] docs = 
await self._aget_docs(question) - answer, _ = await self.combine_documents_chain.acombine_docs( - docs, question=question + answer = await self.combine_documents_chain.arun( + input_documents=docs, question=question ) if self.return_source_documents: diff --git a/tests/unit_tests/chains/test_combine_documents.py b/tests/unit_tests/chains/test_combine_documents.py index fca09f4a..095f216d 100644 --- a/tests/unit_tests/chains/test_combine_documents.py +++ b/tests/unit_tests/chains/test_combine_documents.py @@ -1,6 +1,6 @@ """Test functionality related to combining documents.""" -from typing import Any, List, Tuple +from typing import Any, List import pytest @@ -12,11 +12,11 @@ from langchain.docstore.document import Document def _fake_docs_len_func(docs: List[Document]) -> int: - return len(_fake_combine_docs_func(docs)[0]) + return len(_fake_combine_docs_func(docs)) -def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> Tuple[str, dict]: - return "".join([d.page_content for d in docs]), {} +def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str: + return "".join([d.page_content for d in docs]) def test__split_list_long_single_doc() -> None: