"""Test text splitting functionality.""" import pytest from langchain.docstore.document import Document from langchain.text_splitter import ( CharacterTextSplitter, RecursiveCharacterTextSplitter, ) def test_character_text_splitter() -> None: """Test splitting by character count.""" text = "foo bar baz 123" splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3) output = splitter.split_text(text) expected_output = ["foo bar", "bar baz", "baz 123"] assert output == expected_output def test_character_text_splitter_empty_doc() -> None: """Test splitting by character count doesn't create empty documents.""" text = "foo bar" splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0) output = splitter.split_text(text) expected_output = ["foo", "bar"] assert output == expected_output def test_character_text_splitter_separtor_empty_doc() -> None: """Test edge cases are separators.""" text = "f b" splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0) output = splitter.split_text(text) expected_output = ["f", "b"] assert output == expected_output def test_character_text_splitter_long() -> None: """Test splitting by character count on long words.""" text = "foo bar baz a a" splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["foo", "bar", "baz", "a a"] assert output == expected_output def test_character_text_splitter_short_words_first() -> None: """Test splitting by character count when shorter words are first.""" text = "a a foo bar baz" splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["a a", "foo", "bar", "baz"] assert output == expected_output def test_character_text_splitter_longer_words() -> None: """Test splitting by characters when splits not found easily.""" text = "foo bar baz 123" splitter = CharacterTextSplitter(separator=" ", chunk_size=1, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["foo", "bar", "baz", "123"] assert output == expected_output def test_character_text_splitting_args() -> None: """Test invalid arguments.""" with pytest.raises(ValueError): CharacterTextSplitter(chunk_size=2, chunk_overlap=4) def test_merge_splits() -> None: """Test merging splits with a given separator.""" splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2) splits = ["foo", "bar", "baz"] expected_output = ["foo bar", "baz"] output = splitter._merge_splits(splits, separator=" ") assert output == expected_output def test_create_documents() -> None: """Test create documents method.""" texts = ["foo bar", "baz"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts) expected_docs = [ Document(page_content="foo"), Document(page_content="bar"), Document(page_content="baz"), ] assert docs == expected_docs def test_create_documents_with_metadata() -> None: """Test create documents with metadata method.""" texts = ["foo bar", "baz"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts, [{"source": "1"}, {"source": "2"}]) expected_docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "1"}), Document(page_content="baz", metadata={"source": "2"}), ] assert docs == expected_docs def test_metadata_not_shallow() -> None: """Test that metadatas are not shallow.""" texts = ["foo bar"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts, [{"source": "1"}]) expected_docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "1"}), ] assert docs == expected_docs docs[0].metadata["foo"] = 1 assert docs[0].metadata == {"source": "1", "foo": 1} assert docs[1].metadata == {"source": "1"} def test_iterative_text_splitter() -> None: """Test iterative text splitter.""" text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f. This is a weird text to write, but gotta test the splittingggg some how. Bye!\n\n-H.""" splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1) output = splitter.split_text(text) expected_output = [ "Hi.", "I'm", "Harrison.", "How? Are?", "You?", "Okay then", "f f f f.", "This is a", "a weird", "text to", "write, but", "gotta test", "the", "splittingg", "ggg", "some how.", "Bye!\n\n-H.", ] assert output == expected_output def test_split_documents() -> None: """Test split_documents.""" splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0) docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "2"}), Document(page_content="baz", metadata={"source": "1"}), ] expected_output = [ Document(page_content="f", metadata={"source": "1"}), Document(page_content="o", metadata={"source": "1"}), Document(page_content="o", metadata={"source": "1"}), Document(page_content="b", metadata={"source": "2"}), Document(page_content="a", metadata={"source": "2"}), Document(page_content="r", metadata={"source": "2"}), Document(page_content="b", metadata={"source": "1"}), Document(page_content="a", metadata={"source": "1"}), Document(page_content="z", metadata={"source": "1"}), ] assert splitter.split_documents(docs) == expected_output