langchain/tests/unit_tests/test_text_splitter.py
Harrison Chase 72f99ff953
Harrison/text splitter (#5417)
adds support for keeping separators around when using recursive text
splitter
2023-05-29 16:56:31 -07:00

197 lines
6.5 KiB
Python

"""Test text splitting functionality."""
import pytest
from langchain.docstore.document import Document
from langchain.text_splitter import (
CharacterTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
)
# Python-source fixture passed to PythonCodeTextSplitter.split_text in
# test_python_text_splitter.
# NOTE(review): the expected splits in that test contain a blank line and an
# indented `def bar():` after `class Foo:`, which this literal does not show —
# whitespace inside the literal may have been collapsed; confirm against the
# upstream file before relying on it.
FAKE_PYTHON_TEXT = """
class Foo:
def bar():
def foo():
def testing_func():
def bar():
"""
def test_character_text_splitter() -> None:
    """Check chunking by character count, with overlap between neighbours."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
    chunks = chunker.split_text("foo bar baz 123")
    # With chunk_overlap=3 each chunk repeats the last word of the previous one.
    assert chunks == ["foo bar", "bar baz", "baz 123"]
def test_character_text_splitter_empty_doc() -> None:
    """Check that splitting by character count yields no empty documents."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
    chunks = chunker.split_text("foo bar")
    # Both words are longer than chunk_size; each must come back whole.
    assert chunks == ["foo", "bar"]
def test_character_text_splitter_separtor_empty_doc() -> None:
    """Check the edge case where single characters surround the separator."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
    chunks = chunker.split_text("f b")
    assert chunks == ["f", "b"]
def test_character_text_splitter_long() -> None:
    """Check character-count splitting when long words precede short ones."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
    chunks = chunker.split_text("foo bar baz a a")
    # The trailing short words fit together inside a single 3-character chunk.
    assert chunks == ["foo", "bar", "baz", "a a"]
def test_character_text_splitter_short_words_first() -> None:
    """Check character-count splitting when the short words come first."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
    chunks = chunker.split_text("a a foo bar baz")
    # The leading short words fit together inside a single 3-character chunk.
    assert chunks == ["a a", "foo", "bar", "baz"]
def test_character_text_splitter_longer_words() -> None:
    """Check splitting when no word fits inside the requested chunk size."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=1, chunk_overlap=1)
    chunks = chunker.split_text("foo bar baz 123")
    # Every word exceeds chunk_size=1, so each is expected back unbroken.
    assert chunks == ["foo", "bar", "baz", "123"]
def test_character_text_splitting_args() -> None:
    """Check that invalid constructor arguments raise ``ValueError``."""
    # The overlap (4) is larger than the chunk size (2).
    invalid_kwargs = {"chunk_size": 2, "chunk_overlap": 4}
    with pytest.raises(ValueError):
        CharacterTextSplitter(**invalid_kwargs)
def test_merge_splits() -> None:
    """Check that ``_merge_splits`` joins pieces using the given separator."""
    pieces = ["foo", "bar", "baz"]
    merger = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2)
    merged = merger._merge_splits(pieces, separator=" ")
    assert merged == ["foo bar", "baz"]
def test_create_documents() -> None:
    """Check that ``create_documents`` returns one Document per chunk."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
    produced = chunker.create_documents(["foo bar", "baz"])
    wanted = [Document(page_content=word) for word in ("foo", "bar", "baz")]
    assert produced == wanted
def test_create_documents_with_metadata() -> None:
    """Check that ``create_documents`` attaches each text's metadata to its chunks."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
    metadatas = [{"source": "1"}, {"source": "2"}]
    produced = chunker.create_documents(["foo bar", "baz"], metadatas)
    # (chunk, source) pairs: both chunks of the first text share source "1".
    chunk_sources = [("foo", "1"), ("bar", "1"), ("baz", "2")]
    wanted = [
        Document(page_content=chunk, metadata={"source": source})
        for chunk, source in chunk_sources
    ]
    assert produced == wanted
def test_metadata_not_shallow() -> None:
    """Check that chunks of one text do not share a single metadata dict."""
    chunker = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
    docs = chunker.create_documents(["foo bar"], [{"source": "1"}])
    assert docs == [
        Document(page_content=chunk, metadata={"source": "1"})
        for chunk in ("foo", "bar")
    ]

    # Mutating the first chunk's metadata must leave the second one untouched.
    first_meta = docs[0].metadata
    first_meta["foo"] = 1
    assert docs[0].metadata == {"source": "1", "foo": 1}
    assert docs[1].metadata == {"source": "1"}
def test_iterative_text_splitter() -> None:
    """Check RecursiveCharacterTextSplitter output on a mixed-separator text."""
    # The sample mixes paragraph breaks, single newlines, spaces and one
    # over-long word ("splittingggg") that cannot fit in a 10-character chunk.
    sample = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
This is a weird text to write, but gotta test the splittingggg some how.
Bye!\n\n-H."""
    wanted = [
        "Hi.",
        "I'm",
        "Harrison.",
        "How? Are?",
        "You?",
        "Okay then",
        "f f f f.",
        "This is a",
        "weird",
        "text to",
        "write,",
        "but gotta",
        "test the",
        "splitting",
        "gggg",
        "some how.",
        "Bye!",
        "-H.",
    ]
    recursive = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
    assert recursive.split_text(sample) == wanted
def test_split_documents() -> None:
    """Check that ``split_documents`` splits each Document and keeps its metadata."""
    # (content, source) for each input document, in order.
    inputs = [("foo", "1"), ("bar", "2"), ("baz", "1")]
    docs = [
        Document(page_content=content, metadata={"source": source})
        for content, source in inputs
    ]

    # With an empty separator and chunk_size=1, every character is expected as
    # its own Document carrying the metadata of the document it came from.
    wanted = []
    for content, source in inputs:
        wanted.extend(
            Document(page_content=char, metadata={"source": source})
            for char in content
        )

    per_char = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0)
    assert per_char.split_documents(docs) == wanted
def test_python_text_splitter() -> None:
    """Check PythonCodeTextSplitter output for the FAKE_PYTHON_TEXT fixture."""
    code_splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
    wanted = [
        "class Foo:\n\n def bar():",
        "def foo():",
        "def testing_func():",
        "def bar():",
    ]
    assert code_splitter.split_text(FAKE_PYTHON_TEXT) == wanted