Harrison/map reduce merge (#344)

Co-authored-by: John Nay <JohnNay@users.noreply.github.com>
Harrison Chase committed by GitHub
parent ed143b598f
commit c1b50b7b13

@@ -159,6 +159,14 @@
    "id": "e417926a",
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n",
+      "Token indices sequence length is longer than the specified maximum sequence length for this model (1546 > 1024). Running this sequence through the model will result in indexing errors\n"
+     ]
+    },
     {
      "data": {
       "text/plain": [
@@ -204,7 +212,7 @@
     {
      "data": {
       "text/plain": [
-       "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament. He discussed the situation in Ukraine, the NATO Alliance, and the United States' response to Putin's attack on Ukraine. He spoke about the extensive preparation and coalition building that was done in advance of the attack, and the unified response from the European Union, Canada, Japan, Korea, Australia, New Zealand, and many other countries. He also discussed the economic sanctions that have been imposed on Russia, and the effects they have had on Putin's war fund. Source: 1, 2\"}"
+       "{'output_text': \"\\n\\nThe president did not mention Justice Breyer in his speech to the European Parliament, which focused on building a coalition of freedom-loving nations to confront Putin, unifying European allies, countering Russia's lies with truth, and enforcing powerful economic sanctions. Source: 2\"}"
      ]
     },
     "execution_count": 12,

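For context, the two outputs above come from the map_reduce question-answering notebook. The stderr lines appear to be emitted by the GPT-2 tokenizer that the OpenAI wrapper loads via `transformers` to measure prompt lengths: no deep-learning backend is installed, and the combined prompt is longer than that tokenizer's nominal 1024-token window, so both warnings are expected and harmless here. A hedged reconstruction of the cell that produces these outputs (variable names such as `docs` and `query` are assumptions, not copied from the notebook):

from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI

# `docs` is assumed to be the list of Document chunks built earlier in the notebook.
chain = load_qa_chain(OpenAI(temperature=0), chain_type="map_reduce")
query = "What did the president say about Justice Breyer?"
chain({"input_documents": docs, "question": query}, return_only_outputs=True)
# -> {'output_text': "\n\nThe president did not mention Justice Breyer ..."}
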
@@ -1,7 +1,7 @@
 """Base interface for chains combining documents."""
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import BaseModel
@@ -31,6 +31,13 @@ class BaseCombineDocumentsChain(Chain, BaseModel, ABC):
         """
         return [self.output_key]

+    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
+        """Return the prompt length given the documents passed in.
+
+        Returns None if the method does not depend on the prompt length.
+        """
+        return None
+
     @abstractmethod
     def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
         """Combine documents into a single string."""

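The `prompt_length` hook added above is the new contract: a combining chain may report how long its prompt would be for a given list of documents, and returning `None` (the default) signals that prompt length is not a constraint, which is exactly the sentinel the map-reduce loop below checks for. A minimal illustrative subclass, not part of this PR, might override it like this:

from typing import Any, List, Optional

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.docstore.document import Document


class JoinDocumentsChain(BaseCombineDocumentsChain):
    """Toy chain: joins documents and measures length in characters."""

    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
        # Character count as a crude stand-in for a real token count.
        return sum(len(doc.page_content) for doc in docs)

    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
        return "\n\n".join(doc.page_content for doc in docs)
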
@@ -2,7 +2,7 @@
 from __future__ import annotations

-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, List

 from pydantic import BaseModel, Extra, root_validator
@@ -11,6 +11,47 @@ from langchain.chains.llm import LLMChain
 from langchain.docstore.document import Document


+def _split_list_of_docs(
+    docs: List[Document], length_func: Callable, token_max: int, **kwargs: Any
+) -> List[List[Document]]:
+    new_result_doc_list = []
+    _sub_result_docs = []
+    for doc in docs:
+        _sub_result_docs.append(doc)
+        _num_tokens = length_func(_sub_result_docs, **kwargs)
+        if _num_tokens > token_max:
+            if len(_sub_result_docs) == 1:
+                raise ValueError(
+                    "A single document was longer than the context length,"
+                    " we cannot handle this."
+                )
+            if len(_sub_result_docs) == 2:
+                raise ValueError(
+                    "A single document was so long it could not be combined "
+                    "with another document, we cannot handle this."
+                )
+            new_result_doc_list.append(_sub_result_docs[:-1])
+            _sub_result_docs = _sub_result_docs[-1:]
+    new_result_doc_list.append(_sub_result_docs)
+    return new_result_doc_list
+
+
+def _collapse_docs(
+    docs: List[Document],
+    combine_document_func: Callable,
+    **kwargs: Any,
+) -> Document:
+    result = combine_document_func(docs, **kwargs)
+    combined_metadata = {k: str(v) for k, v in docs[0].metadata.items()}
+    for doc in docs[1:]:
+        for k, v in doc.metadata.items():
+            if k in combined_metadata:
+                combined_metadata[k] += f", {v}"
+            else:
+                combined_metadata[k] = str(v)
+    return Document(page_content=result, metadata=combined_metadata)
+
+
 class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
     """Combining documents by mapping a chain over them, then combining results."""
@@ -49,14 +90,38 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain, BaseModel):
         )
         return values

-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
-        """Combine by mapping first chain over all, then stuffing into final chain."""
+    def combine_docs(
+        self, docs: List[Document], token_max: int = 3000, **kwargs: Any
+    ) -> str:
+        """Combine documents in a map reduce manner.
+
+        Combine by mapping first chain over all documents, then reducing the results.
+        This reducing can be done recursively if needed (if there are many documents).
+        """
         results = self.llm_chain.apply(
+            # FYI - this is parallelized and so it is fast.
             [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs]
         )
         question_result_key = self.llm_chain.output_key
         result_docs = [
             Document(page_content=r[question_result_key], metadata=docs[i].metadata)
+            # This uses metadata from the docs, and the textual results from `results`
             for i, r in enumerate(results)
         ]
-        return self.combine_document_chain.combine_docs(result_docs, **kwargs)
+        length_func = self.combine_document_chain.prompt_length
+        num_tokens = length_func(result_docs, **kwargs)
+        while num_tokens is not None and num_tokens > token_max:
+            new_result_doc_list = _split_list_of_docs(
+                result_docs, length_func, token_max, **kwargs
+            )
+            result_docs = []
+            for docs in new_result_doc_list:
+                new_doc = _collapse_docs(
+                    docs, self.combine_document_chain.combine_docs, **kwargs
+                )
+                result_docs.append(new_doc)
+            num_tokens = self.combine_document_chain.prompt_length(
+                result_docs, **kwargs
+            )
+        output = self.combine_document_chain.combine_docs(result_docs, **kwargs)
+        return output

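The two helpers divide the collapsing work: `_split_list_of_docs` greedily fills each group until adding one more document would exceed `token_max` (raising if even one or two documents blow the budget, since collapsing could then never shrink them), and `_collapse_docs` merges a group into one `Document`, joining metadata values key by key. A self-contained sketch, with a character count standing in for a real prompt length (all names here are illustrative, not from the PR):

from langchain.chains.combine_documents.map_reduce import (
    _collapse_docs,
    _split_list_of_docs,
)
from langchain.docstore.document import Document


def char_len(docs, **kwargs):
    # Stand-in for combine_document_chain.prompt_length: total characters.
    return sum(len(d.page_content) for d in docs)


def join_docs(docs, **kwargs):
    # Stand-in for combine_document_chain.combine_docs, which would
    # normally call the LLM and return a shorter summary.
    return " ".join(d.page_content for d in docs)


docs = [Document(page_content="word " * 4, metadata={"source": str(i)}) for i in range(6)]
groups = _split_list_of_docs(docs, char_len, 50)
print([len(g) for g in groups])  # [2, 2, 2] -- each pair stays under the 50-char budget
collapsed = [_collapse_docs(g, join_docs) for g in groups]
print(collapsed[0].metadata)  # {'source': '0, 1'} -- metadata merged per key
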
@@ -1,6 +1,6 @@
 """Chain that combines documents by stuffing into context."""
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, Field, root_validator
@@ -55,8 +55,7 @@
         )
         return values

-    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
-        """Stuff all documents into one prompt and pass to LLM."""
+    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
         # Get relevant information from each document.
         doc_dicts = []
         for doc in docs:
@@ -71,5 +70,16 @@
         # Join the documents together to put them in the prompt.
         inputs = kwargs.copy()
         inputs[self.document_variable_name] = "\n\n".join(doc_strings)
+        return inputs
+
+    def prompt_length(self, docs: List[Document], **kwargs: Any) -> Optional[int]:
+        """Get the prompt length by formatting the prompt."""
+        inputs = self._get_inputs(docs, **kwargs)
+        prompt = self.llm_chain.prompt.format(**inputs)
+        return self.llm_chain.llm.get_num_tokens(prompt)
+
+    def combine_docs(self, docs: List[Document], **kwargs: Any) -> str:
+        """Stuff all documents into one prompt and pass to LLM."""
+        inputs = self._get_inputs(docs, **kwargs)
         # Call predict on the LLM.
         return self.llm_chain.predict(**inputs)

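Factoring `_get_inputs` out of `combine_docs` means `prompt_length` formats exactly the prompt that `combine_docs` would send, so the token count it reports is faithful. A hedged usage sketch (`stuff_chain` is assumed to be an already-configured `StuffDocumentsChain`, for example the reduce chain inside a map_reduce QA chain, and `question` whatever extra input its prompt expects):

# `stuff_chain` and `docs` are assumed to exist; neither is defined in this PR.
num_tokens = stuff_chain.prompt_length(docs, question="What did the president say?")
if num_tokens is not None and num_tokens > 3000:
    # Over budget: this is the condition under which MapReduceDocumentsChain
    # splits and collapses `docs` before the final combine step.
    ...
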
@@ -0,0 +1,118 @@
+"""Test functionality related to combining documents."""
+
+from typing import List
+
+import pytest
+
+from langchain.chains.combine_documents.map_reduce import (
+    _collapse_docs,
+    _split_list_of_docs,
+)
+from langchain.docstore.document import Document
+
+
+def _fake_docs_len_func(docs: List[Document]) -> int:
+    return len(_fake_combine_docs_func(docs))
+
+
+def _fake_combine_docs_func(docs: List[Document]) -> str:
+    return "".join([d.page_content for d in docs])
+
+
+def test__split_list_long_single_doc() -> None:
+    """Test splitting of a long single doc."""
+    docs = [Document(page_content="foo" * 100)]
+    with pytest.raises(ValueError):
+        _split_list_of_docs(docs, _fake_docs_len_func, 100)
+
+
+def test__split_list_long_pair_doc() -> None:
+    """Test splitting of a list with two medium docs."""
+    docs = [Document(page_content="foo" * 30)] * 2
+    with pytest.raises(ValueError):
+        _split_list_of_docs(docs, _fake_docs_len_func, 100)
+
+
+def test__split_list_single_doc() -> None:
+    """Test splitting works with just a single doc."""
+    docs = [Document(page_content="foo")]
+    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100)
+    assert doc_list == [docs]
+
+
+def test__split_list_double_doc() -> None:
+    """Test splitting works with just two docs."""
+    docs = [Document(page_content="foo"), Document(page_content="bar")]
+    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 100)
+    assert doc_list == [docs]
+
+
+def test__split_list_works_correctly() -> None:
+    """Test splitting works correctly."""
+    docs = [
+        Document(page_content="foo"),
+        Document(page_content="bar"),
+        Document(page_content="baz"),
+        Document(page_content="foo" * 2),
+        Document(page_content="bar"),
+        Document(page_content="baz"),
+    ]
+    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 10)
+    expected_result = [
+        # Test a group of three.
+        [
+            Document(page_content="foo"),
+            Document(page_content="bar"),
+            Document(page_content="baz"),
+        ],
+        # Test a group of two, where one is bigger.
+        [Document(page_content="foo" * 2), Document(page_content="bar")],
+        # Test no errors on last
+        [Document(page_content="baz")],
+    ]
+    assert doc_list == expected_result
+
+
+def test__collapse_docs_no_metadata() -> None:
+    """Test collapse documents functionality when no metadata."""
+    docs = [
+        Document(page_content="foo"),
+        Document(page_content="bar"),
+        Document(page_content="baz"),
+    ]
+    output = _collapse_docs(docs, _fake_combine_docs_func)
+    expected_output = Document(page_content="foobarbaz")
+    assert output == expected_output
+
+
+def test__collapse_docs_one_doc() -> None:
+    """Test collapse documents functionality when only one document present."""
+    # Test with no metadata.
+    docs = [Document(page_content="foo")]
+    output = _collapse_docs(docs, _fake_combine_docs_func)
+    assert output == docs[0]
+
+    # Test with metadata.
+    docs = [Document(page_content="foo", metadata={"source": "a"})]
+    output = _collapse_docs(docs, _fake_combine_docs_func)
+    assert output == docs[0]
+
+
+def test__collapse_docs_metadata() -> None:
+    """Test collapse documents functionality when metadata exists."""
+    metadata1 = {"source": "a", "foo": 2, "bar": "1", "extra1": "foo"}
+    metadata2 = {"source": "b", "foo": "3", "bar": 2, "extra2": "bar"}
+    docs = [
+        Document(page_content="foo", metadata=metadata1),
+        Document(page_content="bar", metadata=metadata2),
+    ]
+    output = _collapse_docs(docs, _fake_combine_docs_func)
+    expected_metadata = {
+        "source": "a, b",
+        "foo": "2, 3",
+        "bar": "1, 2",
+        "extra1": "foo",
+        "extra2": "bar",
+    }
+    expected_output = Document(page_content="foobar", metadata=expected_metadata)
+    assert output == expected_output
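
The grouping asserted in test__split_list_works_correctly can be checked by hand. A simplified, runnable trace of the same budget arithmetic (this closes a group just before it would overflow, which for this input yields the same result as the helper's add-then-backtrack loop):

contents = ["foo", "bar", "baz", "foofoo", "bar", "baz"]
budget = 10  # same token_max the test passes

groups, group, total = [], [], 0
for text in contents:
    if total + len(text) > budget:
        # Close the current group before it would exceed the budget.
        groups.append(group)
        group, total = [], 0
    group.append(text)
    total += len(text)
groups.append(group)
print(groups)  # [['foo', 'bar', 'baz'], ['foofoo', 'bar'], ['baz']]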