Mirror of https://github.com/hwchase17/langchain, synced 2024-10-29 17:07:25 +00:00

Commit d56313acba
# Improve TextSplitter.split_documents, collect page_content and metadata in one iteration

## Who can review?

Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @eyurtsev

In the case where `documents` is a generator that can only be iterated once, this change is a huge help. Otherwise a silent issue occurs where metadata is empty for all documents whenever `documents` is a generator. So we expand the argument type from `List[Document]` to `Union[Iterable[Document], Sequence[Document]]`.

---------

Co-authored-by: Steven Tartakovsky <tartakovsky.developer@gmail.com>
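For context, here is a minimal one-pass sketch of the idea (the helper name `split_documents_one_pass` and the usage below are illustrative only, not the actual `TextSplitter.split_documents` implementation): collecting `page_content` and `metadata` in a single loop means a one-shot generator is never iterated twice, so the metadata can no longer end up silently empty.

```python
from typing import Iterable, List

from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter


def split_documents_one_pass(
    splitter: CharacterTextSplitter, documents: Iterable[Document]
) -> List[Document]:
    # Collect page_content and metadata in one iteration, so an iterable that
    # can only be consumed once (e.g. a generator) still yields both lists.
    texts, metadatas = [], []
    for doc in documents:
        texts.append(doc.page_content)
        metadatas.append(doc.metadata)
    return splitter.create_documents(texts, metadatas=metadatas)


splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
doc_gen = (  # a generator: iterating it a second time would produce nothing
    Document(page_content=t, metadata={"source": str(i)})
    for i, t in enumerate(["foo bar", "baz"])
)
chunks = split_documents_one_pass(splitter, doc_gen)
assert all(chunk.metadata for chunk in chunks)  # every chunk keeps its metadata
```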
171 lines
6.0 KiB
Python
"""Test text splitting functionality."""
|
|
import pytest
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.text_splitter import (
|
|
CharacterTextSplitter,
|
|
RecursiveCharacterTextSplitter,
|
|
)
|
|
|
|
|
|
def test_character_text_splitter() -> None:
|
|
"""Test splitting by character count."""
|
|
text = "foo bar baz 123"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo bar", "bar baz", "baz 123"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_empty_doc() -> None:
|
|
"""Test splitting by character count doesn't create empty documents."""
|
|
text = "foo bar"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_separtor_empty_doc() -> None:
|
|
"""Test edge cases are separators."""
|
|
text = "f b"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["f", "b"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_long() -> None:
|
|
"""Test splitting by character count on long words."""
|
|
text = "foo bar baz a a"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar", "baz", "a a"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_short_words_first() -> None:
|
|
"""Test splitting by character count when shorter words are first."""
|
|
text = "a a foo bar baz"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["a a", "foo", "bar", "baz"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_longer_words() -> None:
|
|
"""Test splitting by characters when splits not found easily."""
|
|
text = "foo bar baz 123"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=1, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar", "baz", "123"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitting_args() -> None:
|
|
"""Test invalid arguments."""
|
|
with pytest.raises(ValueError):
|
|
CharacterTextSplitter(chunk_size=2, chunk_overlap=4)
|
|
|
|
|
|
def test_merge_splits() -> None:
|
|
"""Test merging splits with a given separator."""
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2)
|
|
splits = ["foo", "bar", "baz"]
|
|
expected_output = ["foo bar", "baz"]
|
|
output = splitter._merge_splits(splits, separator=" ")
|
|
assert output == expected_output
|
|
|
|
|
|
def test_create_documents() -> None:
|
|
"""Test create documents method."""
|
|
texts = ["foo bar", "baz"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts)
|
|
expected_docs = [
|
|
Document(page_content="foo"),
|
|
Document(page_content="bar"),
|
|
Document(page_content="baz"),
|
|
]
|
|
assert docs == expected_docs
|
|
|
|
|
|
def test_create_documents_with_metadata() -> None:
|
|
"""Test create documents with metadata method."""
|
|
texts = ["foo bar", "baz"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts, [{"source": "1"}, {"source": "2"}])
|
|
expected_docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "1"}),
|
|
Document(page_content="baz", metadata={"source": "2"}),
|
|
]
|
|
assert docs == expected_docs
|
|
|
|
|
|
def test_metadata_not_shallow() -> None:
|
|
"""Test that metadatas are not shallow."""
|
|
texts = ["foo bar"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts, [{"source": "1"}])
|
|
expected_docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "1"}),
|
|
]
|
|
assert docs == expected_docs
|
|
docs[0].metadata["foo"] = 1
|
|
assert docs[0].metadata == {"source": "1", "foo": 1}
|
|
assert docs[1].metadata == {"source": "1"}
|
|
|
|
|
|
def test_iterative_text_splitter() -> None:
|
|
"""Test iterative text splitter."""
|
|
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
|
|
This is a weird text to write, but gotta test the splittingggg some how.
|
|
|
|
Bye!\n\n-H."""
|
|
splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = [
|
|
"Hi.",
|
|
"I'm",
|
|
"Harrison.",
|
|
"How? Are?",
|
|
"You?",
|
|
"Okay then",
|
|
"f f f f.",
|
|
"This is a",
|
|
"a weird",
|
|
"text to",
|
|
"write, but",
|
|
"gotta test",
|
|
"the",
|
|
"splittingg",
|
|
"ggg",
|
|
"some how.",
|
|
"Bye!\n\n-H.",
|
|
]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_split_documents() -> None:
|
|
"""Test split_documents."""
|
|
splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0)
|
|
docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "2"}),
|
|
Document(page_content="baz", metadata={"source": "1"}),
|
|
]
|
|
expected_output = [
|
|
Document(page_content="f", metadata={"source": "1"}),
|
|
Document(page_content="o", metadata={"source": "1"}),
|
|
Document(page_content="o", metadata={"source": "1"}),
|
|
Document(page_content="b", metadata={"source": "2"}),
|
|
Document(page_content="a", metadata={"source": "2"}),
|
|
Document(page_content="r", metadata={"source": "2"}),
|
|
Document(page_content="b", metadata={"source": "1"}),
|
|
Document(page_content="a", metadata={"source": "1"}),
|
|
Document(page_content="z", metadata={"source": "1"}),
|
|
]
|
|
assert splitter.split_documents(docs) == expected_output
|