mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
19c85aa990
@cnhhoang850 slightly more generic fix for #2944; works for whatever the expected metadata keys are, not just `source`
142 lines
4.7 KiB
Python
142 lines
4.7 KiB
Python
"""Test functionality related to combining documents."""
|
|
|
|
from typing import Any, List
|
|
|
|
import pytest
|
|
|
|
from langchain import PromptTemplate
|
|
from langchain.chains.combine_documents.base import format_document
|
|
from langchain.chains.combine_documents.map_reduce import (
|
|
_collapse_docs,
|
|
_split_list_of_docs,
|
|
)
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
def _fake_docs_len_func(docs: List[Document]) -> int:
    """Return the length of the text that the fake combiner builds from ``docs``."""
    combined_text = _fake_combine_docs_func(docs)
    return len(combined_text)
|
|
|
|
|
|
def _fake_combine_docs_func(docs: List[Document], **kwargs: Any) -> str:
|
|
return "".join([d.page_content for d in docs])
|
|
|
|
|
|
def test__split_list_long_single_doc() -> None:
    """Check that splitting a list holding one long doc raises ValueError."""
    # The lone document has 300 characters; 100 is the value given to the splitter.
    oversized = Document(page_content="foo" * 100)
    with pytest.raises(ValueError):
        _split_list_of_docs([oversized], _fake_docs_len_func, 100)
|
|
|
|
|
|
def test__split_list_long_pair_doc() -> None:
    """Check that splitting a list of two medium docs raises ValueError."""
    # Each entry is 90 characters, so the pair together is 180.
    medium = Document(page_content="foo" * 30)
    pair = [medium, medium]
    with pytest.raises(ValueError):
        _split_list_of_docs(pair, _fake_docs_len_func, 100)
|
|
|
|
|
|
def test__split_list_single_doc() -> None:
    """Check that a lone short document comes back as a single group."""
    single = [Document(page_content="foo")]
    groups = _split_list_of_docs(single, _fake_docs_len_func, 100)
    assert groups == [single]
|
|
|
|
|
|
def test__split_list_double_doc() -> None:
    """Check that two short documents come back together in one group."""
    pair = [Document(page_content=text) for text in ("foo", "bar")]
    groups = _split_list_of_docs(pair, _fake_docs_len_func, 100)
    assert groups == [pair]
|
|
|
|
|
|
def test__split_list_works_correctly() -> None:
    """Check the grouping produced for a longer list of documents."""
    contents = ["foo", "bar", "baz", "foo" * 2, "bar", "baz"]
    docs = [Document(page_content=text) for text in contents]
    doc_list = _split_list_of_docs(docs, _fake_docs_len_func, 10)

    # A group of three.
    first_group = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    # A group of two, where one doc is bigger.
    second_group = [
        Document(page_content="foo" * 2),
        Document(page_content="bar"),
    ]
    # The last doc on its own; no error is expected for it.
    last_group = [Document(page_content="baz")]

    assert doc_list == [first_group, second_group, last_group]
|
|
|
|
|
|
def test__collapse_docs_no_metadata() -> None:
    """Check collapsing of documents that carry no metadata."""
    docs = [Document(page_content=text) for text in ("foo", "bar", "baz")]
    collapsed = _collapse_docs(docs, _fake_combine_docs_func)
    assert collapsed == Document(page_content="foobarbaz")
|
|
|
|
|
|
def test__collapse_docs_one_doc() -> None:
    """Check collapsing when the input holds exactly one document."""
    # Case 1: the document has no metadata.
    bare = Document(page_content="foo")
    assert _collapse_docs([bare], _fake_combine_docs_func) == bare

    # Case 2: the document carries metadata.
    tagged = Document(page_content="foo", metadata={"source": "a"})
    assert _collapse_docs([tagged], _fake_combine_docs_func) == tagged
|
|
|
|
|
|
def test__collapse_docs_metadata() -> None:
    """Check the metadata of the document produced by collapsing two docs."""
    first = Document(
        page_content="foo",
        metadata={"source": "a", "foo": 2, "bar": "1", "extra1": "foo"},
    )
    second = Document(
        page_content="bar",
        metadata={"source": "b", "foo": "3", "bar": 2, "extra2": "bar"},
    )
    collapsed = _collapse_docs([first, second], _fake_combine_docs_func)

    # Keys present in both docs are expected as "x, y" strings; keys present
    # in only one doc are expected with their original value.
    merged_metadata = {
        "source": "a, b",
        "foo": "2, 3",
        "bar": "1, 2",
        "extra1": "foo",
        "extra2": "bar",
    }
    assert collapsed == Document(page_content="foobar", metadata=merged_metadata)
|
|
|
|
|
|
def test_format_doc_with_metadata() -> None:
    """Check formatting of a document whose metadata has the prompt's key."""
    document = Document(page_content="foo", metadata={"bar": "baz"})
    template = PromptTemplate(
        template="{page_content}, {bar}",
        input_variables=["page_content", "bar"],
    )
    assert format_document(document, template) == "foo, baz"
|
|
|
|
|
|
def test_format_doc_missing_metadata() -> None:
    """Check that formatting raises ValueError when metadata is missing."""
    document = Document(page_content="foo")
    template = PromptTemplate(
        template="{page_content}, {bar}",
        input_variables=["page_content", "bar"],
    )
    # The prompt declares "bar", but the document is built without metadata.
    with pytest.raises(ValueError):
        format_document(document, template)
|