"""Test text splitting functionality.""" from typing import List import pytest from langchain.docstore.document import Document from langchain.text_splitter import ( CharacterTextSplitter, Language, MarkdownHeaderTextSplitter, PythonCodeTextSplitter, RecursiveCharacterTextSplitter, ) FAKE_PYTHON_TEXT = """ class Foo: def bar(): def foo(): def testing_func(): def bar(): """ def test_character_text_splitter() -> None: """Test splitting by character count.""" text = "foo bar baz 123" splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3) output = splitter.split_text(text) expected_output = ["foo bar", "bar baz", "baz 123"] assert output == expected_output def test_character_text_splitter_empty_doc() -> None: """Test splitting by character count doesn't create empty documents.""" text = "foo bar" splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0) output = splitter.split_text(text) expected_output = ["foo", "bar"] assert output == expected_output def test_character_text_splitter_separtor_empty_doc() -> None: """Test edge cases are separators.""" text = "f b" splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0) output = splitter.split_text(text) expected_output = ["f", "b"] assert output == expected_output def test_character_text_splitter_long() -> None: """Test splitting by character count on long words.""" text = "foo bar baz a a" splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["foo", "bar", "baz", "a a"] assert output == expected_output def test_character_text_splitter_short_words_first() -> None: """Test splitting by character count when shorter words are first.""" text = "a a foo bar baz" splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["a a", "foo", "bar", "baz"] assert output == expected_output def test_character_text_splitter_longer_words() -> None: """Test splitting by characters when splits not found easily.""" text = "foo bar baz 123" splitter = CharacterTextSplitter(separator=" ", chunk_size=1, chunk_overlap=1) output = splitter.split_text(text) expected_output = ["foo", "bar", "baz", "123"] assert output == expected_output def test_character_text_splitting_args() -> None: """Test invalid arguments.""" with pytest.raises(ValueError): CharacterTextSplitter(chunk_size=2, chunk_overlap=4) def test_merge_splits() -> None: """Test merging splits with a given separator.""" splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2) splits = ["foo", "bar", "baz"] expected_output = ["foo bar", "baz"] output = splitter._merge_splits(splits, separator=" ") assert output == expected_output def test_create_documents() -> None: """Test create documents method.""" texts = ["foo bar", "baz"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts) expected_docs = [ Document(page_content="foo"), Document(page_content="bar"), Document(page_content="baz"), ] assert docs == expected_docs def test_create_documents_with_metadata() -> None: """Test create documents with metadata method.""" texts = ["foo bar", "baz"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts, [{"source": "1"}, {"source": "2"}]) expected_docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "1"}), Document(page_content="baz", metadata={"source": "2"}), ] assert docs == expected_docs def test_create_documents_with_start_index() -> None: """Test create documents method.""" texts = ["foo bar baz 123"] splitter = CharacterTextSplitter( separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True ) docs = splitter.create_documents(texts) expected_docs = [ Document(page_content="foo bar", metadata={"start_index": 0}), Document(page_content="bar baz", metadata={"start_index": 4}), Document(page_content="baz 123", metadata={"start_index": 8}), ] assert docs == expected_docs def test_metadata_not_shallow() -> None: """Test that metadatas are not shallow.""" texts = ["foo bar"] splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0) docs = splitter.create_documents(texts, [{"source": "1"}]) expected_docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "1"}), ] assert docs == expected_docs docs[0].metadata["foo"] = 1 assert docs[0].metadata == {"source": "1", "foo": 1} assert docs[1].metadata == {"source": "1"} def test_iterative_text_splitter_keep_separator() -> None: chunk_size = 5 output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=True) assert output == [ "....5", "X..3", "Y...4", "X....5", "Y...", ] def test_iterative_text_splitter_discard_separator() -> None: chunk_size = 5 output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=False) assert output == [ "....5", "..3", "...4", "....5", "...", ] def __test_iterative_text_splitter(chunk_size: int, keep_separator: bool) -> List[str]: chunk_size += 1 if keep_separator else 0 splitter = RecursiveCharacterTextSplitter( chunk_size=chunk_size, chunk_overlap=0, separators=["X", "Y"], keep_separator=keep_separator, ) text = "....5X..3Y...4X....5Y..." output = splitter.split_text(text) for chunk in output: assert len(chunk) <= chunk_size, f"Chunk is larger than {chunk_size}" return output def test_iterative_text_splitter() -> None: """Test iterative text splitter.""" text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f. This is a weird text to write, but gotta test the splittingggg some how. Bye!\n\n-H.""" splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1) output = splitter.split_text(text) expected_output = [ "Hi.", "I'm", "Harrison.", "How? Are?", "You?", "Okay then", "f f f f.", "This is a", "weird", "text to", "write,", "but gotta", "test the", "splitting", "gggg", "some how.", "Bye!", "-H.", ] assert output == expected_output def test_split_documents() -> None: """Test split_documents.""" splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0) docs = [ Document(page_content="foo", metadata={"source": "1"}), Document(page_content="bar", metadata={"source": "2"}), Document(page_content="baz", metadata={"source": "1"}), ] expected_output = [ Document(page_content="f", metadata={"source": "1"}), Document(page_content="o", metadata={"source": "1"}), Document(page_content="o", metadata={"source": "1"}), Document(page_content="b", metadata={"source": "2"}), Document(page_content="a", metadata={"source": "2"}), Document(page_content="r", metadata={"source": "2"}), Document(page_content="b", metadata={"source": "1"}), Document(page_content="a", metadata={"source": "1"}), Document(page_content="z", metadata={"source": "1"}), ] assert splitter.split_documents(docs) == expected_output def test_python_text_splitter() -> None: splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0) splits = splitter.split_text(FAKE_PYTHON_TEXT) split_0 = """class Foo:\n\n def bar():""" split_1 = """def foo():""" split_2 = """def testing_func():""" split_3 = """def bar():""" expected_splits = [split_0, split_1, split_2, split_3] assert splits == expected_splits CHUNK_SIZE = 16 def test_python_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ def hello_world(): print("Hello, World!") # Call the function hello_world() """ chunks = splitter.split_text(code) assert chunks == [ "def", "hello_world():", 'print("Hello,', 'World!")', "# Call the", "function", "hello_world()", ] def test_golang_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ package main import "fmt" func helloWorld() { fmt.Println("Hello, World!") } func main() { helloWorld() } """ chunks = splitter.split_text(code) assert chunks == [ "package main", 'import "fmt"', "func", "helloWorld() {", 'fmt.Println("He', "llo,", 'World!")', "}", "func main() {", "helloWorld()", "}", ] def test_rst_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ Sample Document =============== Section ------- This is the content of the section. Lists ----- - Item 1 - Item 2 - Item 3 Comment ******* Not a comment .. This is a comment """ chunks = splitter.split_text(code) assert chunks == [ "Sample Document", "===============", "Section", "-------", "This is the", "content of the", "section.", "Lists", "-----", "- Item 1", "- Item 2", "- Item 3", "Comment", "*******", "Not a comment", ".. This is a", "comment", ] def test_proto_file_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ syntax = "proto3"; package example; message Person { string name = 1; int32 age = 2; repeated string hobbies = 3; } """ chunks = splitter.split_text(code) assert chunks == [ "syntax =", '"proto3";', "package", "example;", "message Person", "{", "string name", "= 1;", "int32 age =", "2;", "repeated", "string hobbies", "= 3;", "}", ] def test_javascript_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ function helloWorld() { console.log("Hello, World!"); } // Call the function helloWorld(); """ chunks = splitter.split_text(code) assert chunks == [ "function", "helloWorld() {", 'console.log("He', "llo,", 'World!");', "}", "// Call the", "function", "helloWorld();", ] def test_java_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ public class HelloWorld { public static void main(String[] args) { System.out.println("Hello, World!"); } } """ chunks = splitter.split_text(code) assert chunks == [ "public class", "HelloWorld {", "public", "static void", "main(String[]", "args) {", "System.out.prin", 'tln("Hello,', 'World!");', "}\n}", ] def test_cpp_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ #include int main() { std::cout << "Hello, World!" << std::endl; return 0; } """ chunks = splitter.split_text(code) assert chunks == [ "#include", "", "int main() {", "std::cout", '<< "Hello,', 'World!" <<', "std::endl;", "return 0;\n}", ] def test_scala_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ object HelloWorld { def main(args: Array[String]): Unit = { println("Hello, World!") } } """ chunks = splitter.split_text(code) assert chunks == [ "object", "HelloWorld {", "def", "main(args:", "Array[String]):", "Unit = {", 'println("Hello,', 'World!")', "}\n}", ] def test_ruby_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ def hello_world puts "Hello, World!" end hello_world """ chunks = splitter.split_text(code) assert chunks == [ "def hello_world", 'puts "Hello,', 'World!"', "end", "hello_world", ] def test_php_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ """ chunks = splitter.split_text(code) assert chunks == [ "", ] def test_swift_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ func helloWorld() { print("Hello, World!") } helloWorld() """ chunks = splitter.split_text(code) assert chunks == [ "func", "helloWorld() {", 'print("Hello,', 'World!")', "}", "helloWorld()", ] def test_rust_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ fn main() { println!("Hello, World!"); } """ chunks = splitter.split_text(code) assert chunks == ["fn main() {", 'println!("Hello', ",", 'World!");', "}"] def test_markdown_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.MARKDOWN, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """ # Sample Document ## Section This is the content of the section. ## Lists - Item 1 - Item 2 - Item 3 ### Horizontal lines *********** ____________ ------------------- #### Code blocks ``` This is a code block ``` """ chunks = splitter.split_text(code) assert chunks == [ "# Sample", "Document", "## Section", "This is the", "content of the", "section.", "## Lists", "- Item 1", "- Item 2", "- Item 3", "### Horizontal", "lines", "***********", "____________", "---------------", "----", "#### Code", "blocks", "```", "This is a code", "block", "```", ] def test_html_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.HTML, chunk_size=60, chunk_overlap=0 ) code = """

Sample Document

Section

Reference content.

Lists

A block

Some text

Some more text

""" chunks = splitter.split_text(code) assert chunks == [ "

Sample Document

\n

Section

", '

Reference content.

', "

Lists

\n ", "

A block

", '
', "

Some text

", "

Some more text

\n
", ] def test_md_header_text_splitter_1() -> None: """Test markdown splitter by header: Case 1.""" markdown_document = ( "# Foo\n\n" " ## Bar\n\n" "Hi this is Jim\n\n" "Hi this is Joe\n\n" " ## Baz\n\n" " Hi this is Molly" ) headers_to_split_on = [ ("#", "Header 1"), ("##", "Header 2"), ] markdown_splitter = MarkdownHeaderTextSplitter( headers_to_split_on=headers_to_split_on, ) output = markdown_splitter.split_text(markdown_document) expected_output = [ Document( page_content="Hi this is Jim \nHi this is Joe", metadata={"Header 1": "Foo", "Header 2": "Bar"}, ), Document( page_content="Hi this is Molly", metadata={"Header 1": "Foo", "Header 2": "Baz"}, ), ] assert output == expected_output def test_md_header_text_splitter_2() -> None: """Test markdown splitter by header: Case 2.""" markdown_document = ( "# Foo\n\n" " ## Bar\n\n" "Hi this is Jim\n\n" "Hi this is Joe\n\n" " ### Boo \n\n" " Hi this is Lance \n\n" " ## Baz\n\n" " Hi this is Molly" ) headers_to_split_on = [ ("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3"), ] markdown_splitter = MarkdownHeaderTextSplitter( headers_to_split_on=headers_to_split_on, ) output = markdown_splitter.split_text(markdown_document) expected_output = [ Document( page_content="Hi this is Jim \nHi this is Joe", metadata={"Header 1": "Foo", "Header 2": "Bar"}, ), Document( page_content="Hi this is Lance", metadata={"Header 1": "Foo", "Header 2": "Bar", "Header 3": "Boo"}, ), Document( page_content="Hi this is Molly", metadata={"Header 1": "Foo", "Header 2": "Baz"}, ), ] assert output == expected_output def test_md_header_text_splitter_3() -> None: """Test markdown splitter by header: Case 3.""" markdown_document = ( "# Foo\n\n" " ## Bar\n\n" "Hi this is Jim\n\n" "Hi this is Joe\n\n" " ### Boo \n\n" " Hi this is Lance \n\n" " #### Bim \n\n" " Hi this is John \n\n" " ## Baz\n\n" " Hi this is Molly" ) headers_to_split_on = [ ("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3"), ("####", "Header 4"), ] markdown_splitter = MarkdownHeaderTextSplitter( headers_to_split_on=headers_to_split_on, ) output = markdown_splitter.split_text(markdown_document) expected_output = [ Document( page_content="Hi this is Jim \nHi this is Joe", metadata={"Header 1": "Foo", "Header 2": "Bar"}, ), Document( page_content="Hi this is Lance", metadata={"Header 1": "Foo", "Header 2": "Bar", "Header 3": "Boo"}, ), Document( page_content="Hi this is John", metadata={ "Header 1": "Foo", "Header 2": "Bar", "Header 3": "Boo", "Header 4": "Bim", }, ), Document( page_content="Hi this is Molly", metadata={"Header 1": "Foo", "Header 2": "Baz"}, ), ] assert output == expected_output def test_solidity_code_splitter() -> None: splitter = RecursiveCharacterTextSplitter.from_language( Language.SOL, chunk_size=CHUNK_SIZE, chunk_overlap=0 ) code = """pragma solidity ^0.8.20; contract HelloWorld { function add(uint a, uint b) pure public returns(uint) { return a + b; } } """ chunks = splitter.split_text(code) assert chunks == [ "pragma solidity", "^0.8.20;", "contract", "HelloWorld {", "function", "add(uint a,", "uint b) pure", "public", "returns(uint) {", "return a", "+ b;", "}\n }", ]