You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

829 lines
21 KiB

"""Test text splitting functionality."""
from typing import List
import pytest
from langchain.docstore.document import Document
from langchain.text_splitter import (
class Foo:
def bar():
def foo():
def testing_func():
def bar():
def test_character_text_splitter() -> None:
    """Split on spaces with a 7-character chunk size and a 3-character overlap."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=7, chunk_overlap=3
    ).split_text("foo bar baz 123")
    # Consecutive chunks share one word because of the overlap.
    assert chunks == ["foo bar", "bar baz", "baz 123"]
def test_character_text_splitter_empty_doc() -> None:
    """Splitting with a chunk size smaller than each word yields no empty chunks."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=2, chunk_overlap=0
    ).split_text("foo bar")
    assert chunks == ["foo", "bar"]
def test_character_text_splitter_separtor_empty_doc() -> None:
    """Single-character words around a separator come back without the separator."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=2, chunk_overlap=0
    ).split_text("f b")
    assert chunks == ["f", "b"]
def test_character_text_splitter_long() -> None:
    """Long words come first; the trailing short words are merged into one chunk."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=3, chunk_overlap=1
    ).split_text("foo bar baz a a")
    assert chunks == ["foo", "bar", "baz", "a a"]
def test_character_text_splitter_short_words_first() -> None:
    """Short words come first and are merged; the longer words stay separate."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=3, chunk_overlap=1
    ).split_text("a a foo bar baz")
    assert chunks == ["a a", "foo", "bar", "baz"]
def test_character_text_splitter_longer_words() -> None:
    """Every word exceeds the chunk size, so each word becomes its own chunk."""
    chunks = CharacterTextSplitter(
        separator=" ", chunk_size=1, chunk_overlap=1
    ).split_text("foo bar baz 123")
    assert chunks == ["foo", "bar", "baz", "123"]
def test_character_text_splitting_args() -> None:
    """An overlap larger than the chunk size must be rejected with ValueError."""
    invalid_kwargs = {"chunk_size": 2, "chunk_overlap": 4}
    with pytest.raises(ValueError):
        CharacterTextSplitter(**invalid_kwargs)
def test_merge_splits() -> None:
    """Merge already-split pieces back into chunks joined by a space."""
    pieces = ["foo", "bar", "baz"]
    merged = CharacterTextSplitter(
        separator=" ", chunk_size=9, chunk_overlap=2
    )._merge_splits(pieces, separator=" ")
    assert merged == ["foo bar", "baz"]
# Builds documents from two texts with a space separator and chunk_size=3,
# then compares them against `expected_docs`.
# NOTE(review): the entries of `expected_docs` and its closing `]` are not
# visible in this view, so the list literal below is unterminated. Restore
# them from the original file before relying on this test.
def test_create_documents() -> None:
"""Test create documents method."""
texts = ["foo bar", "baz"]
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
docs = splitter.create_documents(texts)
expected_docs = [
assert docs == expected_docs
def test_create_documents_with_metadata() -> None:
    """Test that create_documents attaches each text's metadata to its chunks.

    The first text is split into two chunks that both carry ``{"source": "1"}``;
    the second text yields one chunk carrying ``{"source": "2"}``.
    """
    texts = ["foo bar", "baz"]
    splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
    docs = splitter.create_documents(texts, [{"source": "1"}, {"source": "2"}])
    expected_docs = [
        Document(page_content="foo", metadata={"source": "1"}),
        Document(page_content="bar", metadata={"source": "1"}),
        Document(page_content="baz", metadata={"source": "2"}),
    ]  # closing bracket restored: the list literal was left open
    assert docs == expected_docs
def test_create_documents_with_start_index() -> None:
    """Test that add_start_index=True records each chunk's start offset.

    With overlapping 7-character chunks of ``"foo bar baz 123"``, the expected
    ``start_index`` values are the positions of each chunk in the source text.
    """
    texts = ["foo bar baz 123"]
    # Closing parenthesis restored: the constructor call was left open.
    splitter = CharacterTextSplitter(
        separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
    )
    docs = splitter.create_documents(texts)
    expected_docs = [
        Document(page_content="foo bar", metadata={"start_index": 0}),
        Document(page_content="bar baz", metadata={"start_index": 4}),
        Document(page_content="baz 123", metadata={"start_index": 8}),
    ]  # closing bracket restored: the list literal was left open
    assert docs == expected_docs
def test_metadata_not_shallow() -> None:
    """Test that chunks of one text do not share a single metadata dict.

    After the documents are created, mutating the first chunk's metadata must
    leave the second chunk's metadata unchanged.
    """
    texts = ["foo bar"]
    splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
    docs = splitter.create_documents(texts, [{"source": "1"}])
    expected_docs = [
        Document(page_content="foo", metadata={"source": "1"}),
        Document(page_content="bar", metadata={"source": "1"}),
    ]  # closing bracket restored: the list literal was left open
    assert docs == expected_docs
    # Mutate one chunk's metadata and check that the sibling chunk is unaffected.
    docs[0].metadata["foo"] = 1
    assert docs[0].metadata == {"source": "1", "foo": 1}
    assert docs[1].metadata == {"source": "1"}
# Runs the shared __test_iterative_text_splitter helper with keep_separator=True
# and chunk_size=5, then compares the returned chunks to an expected list.
# NOTE(review): the expected list after `assert output == [` is not visible in
# this view. Restore it from the original file.
def test_iterative_text_splitter_keep_separator() -> None:
chunk_size = 5
output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=True)
assert output == [
# Runs the shared __test_iterative_text_splitter helper with keep_separator=False
# and chunk_size=5, then compares the returned chunks to an expected list.
# NOTE(review): the expected list after `assert output == [` is not visible in
# this view. Restore it from the original file.
def test_iterative_text_splitter_discard_separator() -> None:
chunk_size = 5
output = __test_iterative_text_splitter(chunk_size=chunk_size, keep_separator=False)
assert output == [
# Helper shared by the keep/discard separator tests: splits a fixed text on the
# separators "X" and "Y", asserts no chunk exceeds the size limit, and returns
# the chunks.
# NOTE(review): only `separators=[...]` is visible inside the
# RecursiveCharacterTextSplitter(...) call; the remaining arguments and the
# closing `)` are missing from this view. Restore them from the original file.
def __test_iterative_text_splitter(chunk_size: int, keep_separator: bool) -> List[str]:
# Allow one extra character per chunk when keep_separator is true.
chunk_size += 1 if keep_separator else 0
splitter = RecursiveCharacterTextSplitter(
separators=["X", "Y"],
text = "....5X..3Y...4X....5Y..."
output = splitter.split_text(text)
for chunk in output:
assert len(chunk) <= chunk_size, f"Chunk is larger than {chunk_size}"
return output
# Splits a multi-paragraph text with RecursiveCharacterTextSplitter
# (chunk_size=10, chunk_overlap=1) and compares against `expected_output`.
# NOTE(review): in this view the triple-quoted `text` literal is never closed
# and `expected_output` is unclosed and appears to be missing entries. Restore
# them from the original file before relying on this test.
def test_iterative_text_splitter() -> None:
"""Test iterative text splitter."""
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
This is a weird text to write, but gotta test the splittingggg some how.
splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
output = splitter.split_text(text)
expected_output = [
"How? Are?",
"Okay then",
"f f f f.",
"This is a",
"text to",
"but gotta",
"test the",
"some how.",
assert output == expected_output
def test_split_documents() -> None:
    """Test that split_documents splits each document and keeps its metadata.

    With an empty separator and ``chunk_size=1`` every character becomes its own
    document, carrying the metadata of the document it came from.
    """
    splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0)
    docs = [
        Document(page_content="foo", metadata={"source": "1"}),
        Document(page_content="bar", metadata={"source": "2"}),
        Document(page_content="baz", metadata={"source": "1"}),
    ]  # closing bracket restored: the list literal was left open
    expected_output = [
        Document(page_content="f", metadata={"source": "1"}),
        Document(page_content="o", metadata={"source": "1"}),
        Document(page_content="o", metadata={"source": "1"}),
        Document(page_content="b", metadata={"source": "2"}),
        Document(page_content="a", metadata={"source": "2"}),
        Document(page_content="r", metadata={"source": "2"}),
        Document(page_content="b", metadata={"source": "1"}),
        Document(page_content="a", metadata={"source": "1"}),
        Document(page_content="z", metadata={"source": "1"}),
    ]  # closing bracket restored: the list literal was left open
    assert splitter.split_documents(docs) == expected_output
def test_python_text_splitter() -> None:
    """Split FAKE_PYTHON_TEXT with PythonCodeTextSplitter and check each piece."""
    splits = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0).split_text(
        FAKE_PYTHON_TEXT
    )
    expected_splits = [
        """class Foo:\n\n def bar():""",
        """def foo():""",
        """def testing_func():""",
        """def bar():""",
    ]
    assert splits == expected_splits
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.PYTHON on a small Python snippet, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_python_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
def hello_world():
print("Hello, World!")
# Call the function
chunks = splitter.split_text(code)
assert chunks == [
"# Call the",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.GO on a small Go program, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_golang_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
package main
import "fmt"
func helloWorld() {
fmt.Println("Hello, World!")
func main() {
chunks = splitter.split_text(code)
assert chunks == [
"package main",
'import "fmt"',
"helloWorld() {",
"func main() {",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.RST on a small reStructuredText sample, using CHUNK_SIZE and no
# overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_rst_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
Sample Document
This is the content of the section.
- Item 1
- Item 2
- Item 3
Not a comment
.. This is a comment
chunks = splitter.split_text(code)
assert chunks == [
"Sample Document",
"This is the",
"content of the",
"- Item 1",
"- Item 2",
"- Item 3",
"Not a comment",
".. This is a",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.PROTO on a small proto3 definition, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_proto_file_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
syntax = "proto3";
package example;
message Person {
string name = 1;
int32 age = 2;
repeated string hobbies = 3;
chunks = splitter.split_text(code)
assert chunks == [
"syntax =",
"message Person",
"string name",
"= 1;",
"int32 age =",
"string hobbies",
"= 3;",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.JS on a small JavaScript snippet, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_javascript_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
function helloWorld() {
console.log("Hello, World!");
// Call the function
chunks = splitter.split_text(code)
assert chunks == [
"helloWorld() {",
"// Call the",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.JAVA on a small Java class, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_java_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
public class HelloWorld {
public static void main(String[] args) {
System.out.println("Hello, World!");
chunks = splitter.split_text(code)
assert chunks == [
"public class",
"HelloWorld {",
"static void",
"args) {",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.CPP on a small C++ program, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_cpp_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
#include <iostream>
int main() {
std::cout << "Hello, World!" << std::endl;
return 0;
chunks = splitter.split_text(code)
assert chunks == [
"int main() {",
'<< "Hello,',
'World!" <<',
"return 0;\n}",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.SCALA on a small Scala object, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_scala_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
object HelloWorld {
def main(args: Array[String]): Unit = {
println("Hello, World!")
chunks = splitter.split_text(code)
assert chunks == [
"HelloWorld {",
"Unit = {",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.RUBY on a small Ruby method, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_ruby_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
def hello_world
puts "Hello, World!"
chunks = splitter.split_text(code)
assert chunks == [
"def hello_world",
'puts "Hello,',
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.PHP on a small PHP function, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_php_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
function hello_world() {
echo "Hello, World!";
chunks = splitter.split_text(code)
assert chunks == [
"hello_world() {",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.SWIFT on a small Swift function, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_swift_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
func helloWorld() {
print("Hello, World!")
chunks = splitter.split_text(code)
assert chunks == [
"helloWorld() {",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.RUST on a small Rust `main`, using CHUNK_SIZE and no overlap.
# NOTE(review): in this view the from_language(...) call and the triple-quoted
# `code` literal are left unclosed. The expected list ends with "}", so the
# snippet presumably also lost its closing brace. Restore from the original
# file.
def test_rust_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
fn main() {
println!("Hello, World!");
chunks = splitter.split_text(code)
assert chunks == ["fn main() {", 'println!("Hello', ",", 'World!");', "}"]
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.MARKDOWN on a sample Markdown document, using CHUNK_SIZE and no
# overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed and appear
# to be missing lines. Restore them from the original file.
def test_markdown_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.MARKDOWN, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """
# Sample Document
## Section
This is the content of the section.
## Lists
- Item 1
- Item 2
- Item 3
### Horizontal lines
#### Code blocks
This is a code block
chunks = splitter.split_text(code)
assert chunks == [
"# Sample",
"## Section",
"This is the",
"content of the",
"## Lists",
"- Item 1",
"- Item 2",
"- Item 3",
"### Horizontal",
"#### Code",
"This is a code",
# Checks the chunks RecursiveCharacterTextSplitter.from_language produces for
# Language.HTML on a sample HTML fragment, using chunk_size=60 and no overlap.
# NOTE(review): in this view the from_language(...) call, the triple-quoted
# `code` literal and the expected-chunk list are all left unclosed. The
# expected chunks mention tags (<h2>, <ul>, </div>) that the visible snippet
# lacks, so snippet lines appear to be missing. Restore from the original file.
def test_html_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.HTML, chunk_size=60, chunk_overlap=0
code = """
<h1>Sample Document</h1>
<p id="1234">Reference content.</p>
<li>Item 1</li>
<li>Item 2</li>
<li>Item 3</li>
<h3>A block</h3>
<div class="amazing">
<p>Some text</p>
<p>Some more text</p>
chunks = splitter.split_text(code)
assert chunks == [
"<h1>Sample Document</h1>\n <h2>Section</h2>",
'<p id="1234">Reference content.</p>',
"<h2>Lists</h2>\n <ul>",
"<li>Item 1</li>\n <li>Item 2</li>",
"<li>Item 3</li>\n </ul>",
"<h3>A block</h3>",
'<div class="amazing">',
"<p>Some text</p>",
"<p>Some more text</p>\n </div>",
# Splits a Markdown document with two header levels ("#", "##") using
# MarkdownHeaderTextSplitter and compares the result to `expected_output`.
# NOTE(review): in this view the `markdown_document` tuple, the
# `headers_to_split_on` list, the MarkdownHeaderTextSplitter(...) call and the
# `expected_output` list are all unclosed, and the page_content/metadata pairs
# have lost whatever call wrapped them. Restore from the original file.
def test_md_header_text_splitter_1() -> None:
"""Test markdown splitter by header: Case 1."""
markdown_document = (
"# Foo\n\n"
" ## Bar\n\n"
"Hi this is Jim\n\n"
"Hi this is Joe\n\n"
" ## Baz\n\n"
" Hi this is Molly"
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
markdown_splitter = MarkdownHeaderTextSplitter(
output = markdown_splitter.split_text(markdown_document)
expected_output = [
page_content="Hi this is Jim \nHi this is Joe",
metadata={"Header 1": "Foo", "Header 2": "Bar"},
page_content="Hi this is Molly",
metadata={"Header 1": "Foo", "Header 2": "Baz"},
assert output == expected_output
# Splits a Markdown document with three header levels ("#", "##", "###") using
# MarkdownHeaderTextSplitter and compares the result to `expected_output`.
# NOTE(review): in this view the `markdown_document` tuple, the
# `headers_to_split_on` list, the MarkdownHeaderTextSplitter(...) call and the
# `expected_output` list are all unclosed, and the page_content/metadata pairs
# have lost whatever call wrapped them. Restore from the original file.
def test_md_header_text_splitter_2() -> None:
"""Test markdown splitter by header: Case 2."""
markdown_document = (
"# Foo\n\n"
" ## Bar\n\n"
"Hi this is Jim\n\n"
"Hi this is Joe\n\n"
" ### Boo \n\n"
" Hi this is Lance \n\n"
" ## Baz\n\n"
" Hi this is Molly"
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
markdown_splitter = MarkdownHeaderTextSplitter(
output = markdown_splitter.split_text(markdown_document)
expected_output = [
page_content="Hi this is Jim \nHi this is Joe",
metadata={"Header 1": "Foo", "Header 2": "Bar"},
page_content="Hi this is Lance",
metadata={"Header 1": "Foo", "Header 2": "Bar", "Header 3": "Boo"},
page_content="Hi this is Molly",
metadata={"Header 1": "Foo", "Header 2": "Baz"},
assert output == expected_output
# Splits a Markdown document with four header levels ("#" through "####") using
# MarkdownHeaderTextSplitter and compares the result to `expected_output`.
# NOTE(review): in this view the `markdown_document` tuple, the
# `headers_to_split_on` list, the MarkdownHeaderTextSplitter(...) call and the
# `expected_output` list are all unclosed. The page_content/metadata pairs have
# lost whatever call wrapped them, and the four "Header N" entries for the
# "Hi this is John" chunk have lost their enclosing mapping. Restore from the
# original file.
def test_md_header_text_splitter_3() -> None:
"""Test markdown splitter by header: Case 3."""
markdown_document = (
"# Foo\n\n"
" ## Bar\n\n"
"Hi this is Jim\n\n"
"Hi this is Joe\n\n"
" ### Boo \n\n"
" Hi this is Lance \n\n"
" #### Bim \n\n"
" Hi this is John \n\n"
" ## Baz\n\n"
" Hi this is Molly"
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
("####", "Header 4"),
markdown_splitter = MarkdownHeaderTextSplitter(
output = markdown_splitter.split_text(markdown_document)
expected_output = [
page_content="Hi this is Jim \nHi this is Joe",
metadata={"Header 1": "Foo", "Header 2": "Bar"},
page_content="Hi this is Lance",
metadata={"Header 1": "Foo", "Header 2": "Bar", "Header 3": "Boo"},
page_content="Hi this is John",
"Header 1": "Foo",
"Header 2": "Bar",
"Header 3": "Boo",
"Header 4": "Bim",
page_content="Hi this is Molly",
metadata={"Header 1": "Foo", "Header 2": "Baz"},
assert output == expected_output
def test_solidity_code_splitter() -> None:
splitter = RecursiveCharacterTextSplitter.from_language(
Language.SOL, chunk_size=CHUNK_SIZE, chunk_overlap=0
code = """pragma solidity ^0.8.20;
contract HelloWorld {
function add(uint a, uint b) pure public returns(uint) {
return a + b;
chunks = splitter.split_text(code)
assert chunks == [
"pragma solidity",
"HelloWorld {",
"add(uint a,",
"uint b) pure",
"returns(uint) {",
"return a",
"+ b;",
"}\n }",