From c1b50b7b13545b1549c051182f547cfc2f8ed0be Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 15 Dec 2022 17:49:14 -0800 Subject: [PATCH] Harrison/map reduce merge (#344) Co-authored-by: John Nay --- docs/examples/chains/qa_with_sources.ipynb | 10 +- langchain/chains/combine_documents/base.py | 9 +- .../chains/combine_documents/map_reduce.py | 73 ++++++++++- langchain/chains/combine_documents/stuff.py | 16 ++- .../chains/test_combine_documents.py | 118 ++++++++++++++++++ 5 files changed, 217 insertions(+), 9 deletions(-) create mode 100644 tests/unit_tests/chains/test_combine_documents.py diff --git a/docs/examples/chains/qa_with_sources.ipynb b/docs/examples/chains/qa_with_sources.ipynb index 1135e86c1f..29454eae23 100644 --- a/docs/examples/chains/qa_with_sources.ipynb +++ b/docs/examples/chains/qa_with_sources.ipynb @@ -159,6 +159,14 @@ "id": "e417926a", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n", + "Token indices sequence length is longer than the specified maximum sequence length for this model (1546 > 1024). Running this sequence through the model will result in indexing errors\n" + ] + }, { "data": { "text/plain": [ @@ -204,7 +212,7 @@ { "data": { "text/plain": [ - "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament. He discussed the situation in Ukraine, the NATO Alliance, and the United States' response to Putin's attack on Ukraine. He spoke about the extensive preparation and coalition building that was done in advance of the attack, and the unified response from the European Union, Canada, Japan, Korea, Australia, New Zealand, and many other countries. He also discussed the economic sanctions that have been imposed on Russia, and the effects they have had on Putin's war fund. Source: 1, 2\"}" + "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament, which focused on building a coalition of freedom-loving nations to confront Putin, unifying European allies, countering Russia's lies with truth, and enforcing powerful economic sanctions. Source: 2\"}" ] }, "execution_count": 12, diff --git a/langchain/chains/combine_documents/base.py b/langchain/chains/combine_documents/base.py index 26cbcb89fe..7d5574caae 100644 --- a/langchain/chains/combine_documents/base.py +++ b/langchain/chains/combine_documents/base.py @@ -1,7 +1,7 @@ """Base interface for chains combining documents.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel @@ -31,6 +31,13 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC): """ return [self.output_key] + def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: + """Return the prompt length given the documents passed in. + + Returns None if the method does not depend on the prompt length. + """ + return None + @abstractmethod def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: """Combine documents into a single string.""" diff --git a/langchain/chains/combine_documents/map_reduce.py b/langchain/chains/combine_documents/map_reduce.py index b182e9e553..ff74bc4f85 100644 --- a/langchain/chains/combine_documents/map_reduce.py +++ b/langchain/chains/combine_documents/map_reduce.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List from pydantic import BaseModel, Extra, root_validator @@ -11,6 +11,47 @@ from langchain.chains.llm import LLMChain from langchain.docstore.document import Document +def _split_list_of_docs( + docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any +) -> List[List[Document]]: + new_result_doc_list = [] + _sub_result_docs = [] + for doc in docs: + _sub_result_docs.append(doc) + _num_tokens = length_func(_sub_result_docs, **kwargs) + if _num_tokens > token_max: + if len(_sub_result_docs) == 1: + raise ValueError( + "A single document was longer than the context length," + " we cannot handle this." + ) + if len(_sub_result_docs) == 2: + raise ValueError( + "A single document was so long it could not be combined " + "with another document, we cannot handle this." + ) + new_result_doc_list.append(_sub_result_docs[:-1]) + _sub_result_docs = _sub_result_docs[-1:] + new_result_doc_list.append(_sub_result_docs) + return new_result_doc_list + + +def _collapse_docs( + docs: List[Document], + combine_document_func: Callable, + **kwargs: Any, +) -> Document: + result = combine_document_func(docs, **kwargs) + combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()} + for doc in docs[1:]: + for k, v in doc.metadata.items(): + if k in combined_metadata: + combined_metadata[k] += f", {v}" + else: + combined_metadata[k] = str(v) + return Document(page_content=result, metadata=combined_metadata) + + class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): """Combining documents by mapping a chain over them, then combining results.""" @@ -49,14 +90,38 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel): ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: - """Combine by mapping first chain over all, then stuffing into final chain.""" + def combine_docs( + self, docs: List[Document], token_max: int = 3000, **kwargs: Any + ) -> str: + """Combine documents in a map reduce manner. + + Combine by mapping first chain over all documents, then reducing the results. + This reducing can be done recursively if needed (if there are many documents). + """ results = self.llm_chain.apply( + # FYI - this is parallelized and so it is fast. [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs] ) question_result_key = self.llm_chain.output_key result_docs = [ Document(page_content=r[question_result_key], metadata=docs[i].metadata) + # This uses metadata from the docs, and the textual results from `results` for i, r in enumerate(results) ] - return self.combine_document_chain.combine_docs(result_docs, **kwargs) + length_func = self.combine_document_chain.prompt_length + num_tokens = length_func(result_docs, **kwargs) + while num_tokens is not None and num_tokens > token_max: + new_result_doc_list = _split_list_of_docs( + result_docs, length_func, token_max, **kwargs + ) + result_docs = [] + for docs in new_result_doc_list: + new_doc = _collapse_docs( + docs, self.combine_document_chain.combine_docs, **kwargs + ) + result_docs.append(new_doc) + num_tokens = self.combine_document_chain.prompt_length( + result_docs, **kwargs + ) + output = self.combine_document_chain.combine_docs(result_docs, **kwargs) + return output diff --git a/langchain/chains/combine_documents/stuff.py b/langchain/chains/combine_documents/stuff.py index 7b56c2113a..796de39e37 100644 --- a/langchain/chains/combine_documents/stuff.py +++ b/langchain/chains/combine_documents/stuff.py @@ -1,6 +1,6 @@ """Chain that combines documents by stuffing into context.""" -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, Field, root_validator @@ -55,8 +55,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): ) return values - def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: - """Stuff all documents into one prompt and pass to LLM.""" + def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: # Get relevant information from each document. doc_dicts = [] for doc in docs: @@ -71,5 +70,16 @@ class StuffDocumentsChain(BaseCombineDocumentsChain, BaseModel): # Join the documents together to put them in the prompt. inputs = kwargs.copy() inputs[self.document_variable_name] = "\n\n".join(doc_strings) + return inputs + + def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]: + """Get the prompt length by formatting the prompt.""" + inputs = self._get_inputs(docs, **kwargs) + prompt = self.llm_chain.prompt.format(**inputs) + return self.llm_chain.llm.get_num_tokens(prompt) + + def combine_docs(self, docs: List[Document], **kwargs: Any) -> str: + """Stuff all documents into one prompt and pass to LLM.""" + inputs = self._get_inputs(docs, **kwargs) # Call predict on the LLM. return self.llm_chain.predict(**inputs) diff --git a/tests/unit_tests/chains/test_combine_documents.py b/tests/unit_tests/chains/test_combine_documents.py new file mode 100644 index 0000000000..81468f9718 --- /dev/null +++ b/tests/unit_tests/chains/test_combine_documents.py @@ -0,0 +1,118 @@ +"""Test functionality related to combining documents.""" + +from typing import List + +import pytest + +from langchain.chains.combine_documents.map_reduce import ( + _collapse_docs, + _split_list_of_docs, +) +from langchain.docstore.document import Document + + +def _fake_docs_len_func(docs: List[Document]) -> int: + return len(_fake_combine_docs_func(docs)) + + +def _fake_combine_docs_func(docs: List[Document]) -> str: + return "".join([d.page_content for d in docs]) + + +def test__split_list_long_single_doc() -> None: + """Test splitting of a long single doc.""" + docs = [Document(page_content="foo" * 100)] + with pytest.raises(ValueError): + _split_list_of_docs(docs, _fake_docs_len_func, 100) + + +def test__split_list_long_pair_doc() -> None: + """Test splitting of a list with two medium docs.""" + docs = [Document(page_content="foo" * 30)] * 2 + with pytest.raises(ValueError): + _split_list_of_docs(docs, _fake_docs_len_func, 100) + + +def test__split_list_single_doc() -> None: + """Test splitting works with just a single doc.""" + docs = [Document(page_content="foo")] + doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100) + assert doc_list == [docs] + + +def test__split_list_double_doc() -> None: + """Test splitting works with just two docs.""" + docs = [Document(page_content="foo"), Document(page_content="bar")] + doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100) + assert doc_list == [docs] + + +def test__split_list_works_correctly() -> None: + """Test splitting works correctly.""" + docs = [ + Document(page_content="foo"), + Document(page_content="bar"), + Document(page_content="baz"), + Document(page_content="foo" * 2), + Document(page_content="bar"), + Document(page_content="baz"), + ] + doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 10) + expected_result = [ + # Test a group of three. + [ + Document(page_content="foo"), + Document(page_content="bar"), + Document(page_content="baz"), + ], + # Test a group of two, where one is bigger. + [Document(page_content="foo" * 2), Document(page_content="bar")], + # Test no errors on last + [Document(page_content="baz")], + ] + assert doc_list == expected_result + + +def test__collapse_docs_no_metadata() -> None: + """Test collapse documents functionality when no metadata.""" + docs = [ + Document(page_content="foo"), + Document(page_content="bar"), + Document(page_content="baz"), + ] + output = _collapse_docs(docs, _fake_combine_docs_func) + expected_output = Document(page_content="foobarbaz") + assert output == expected_output + + +def test__collapse_docs_one_doc() -> None: + """Test collapse documents functionality when only one document present.""" + # Test with no metadata. + docs = [Document(page_content="foo")] + output = _collapse_docs(docs, _fake_combine_docs_func) + assert output == docs[0] + + # Test with metadata. + docs = [Document(page_content="foo", metadata={"source": "a"})] + output = _collapse_docs(docs, _fake_combine_docs_func) + assert output == docs[0] + + +def test__collapse_docs_metadata() -> None: + """Test collapse documents functionality when metadata exists.""" + metadata1 = {"source": "a", "foo": 2, "bar": "1", "extra1": "foo"} + metadata2 = {"source": "b", "foo": "3", "bar": 2, "extra2": "bar"} + docs = [ + Document(page_content="foo", metadata=metadata1), + Document(page_content="bar", metadata=metadata2), + ] + output = _collapse_docs(docs, _fake_combine_docs_func) + expected_metadata = { + "source": "a, b", + "foo": "2, 3", + "bar": "1, 2", + "extra1": "foo", + "extra2": "bar", + } + expected_output = Document(page_content="foobar", metadata=expected_metadata) + assert output == expected_output