mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
5ce74b5958
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
512 lines
13 KiB
Python
512 lines
13 KiB
Python
"""Test text splitting functionality."""
|
|
import pytest
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.text_splitter import (
|
|
CharacterTextSplitter,
|
|
Language,
|
|
PythonCodeTextSplitter,
|
|
RecursiveCharacterTextSplitter,
|
|
)
|
|
|
|
FAKE_PYTHON_TEXT = """
|
|
class Foo:
|
|
|
|
def bar():
|
|
|
|
|
|
def foo():
|
|
|
|
def testing_func():
|
|
|
|
def bar():
|
|
"""
|
|
|
|
|
|
def test_character_text_splitter() -> None:
|
|
"""Test splitting by character count."""
|
|
text = "foo bar baz 123"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo bar", "bar baz", "baz 123"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_empty_doc() -> None:
|
|
"""Test splitting by character count doesn't create empty documents."""
|
|
text = "foo bar"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_separtor_empty_doc() -> None:
|
|
"""Test edge cases are separators."""
|
|
text = "f b"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["f", "b"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_long() -> None:
|
|
"""Test splitting by character count on long words."""
|
|
text = "foo bar baz a a"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar", "baz", "a a"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_short_words_first() -> None:
|
|
"""Test splitting by character count when shorter words are first."""
|
|
text = "a a foo bar baz"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["a a", "foo", "bar", "baz"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitter_longer_words() -> None:
|
|
"""Test splitting by characters when splits not found easily."""
|
|
text = "foo bar baz 123"
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=1, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = ["foo", "bar", "baz", "123"]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_character_text_splitting_args() -> None:
|
|
"""Test invalid arguments."""
|
|
with pytest.raises(ValueError):
|
|
CharacterTextSplitter(chunk_size=2, chunk_overlap=4)
|
|
|
|
|
|
def test_merge_splits() -> None:
|
|
"""Test merging splits with a given separator."""
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=9, chunk_overlap=2)
|
|
splits = ["foo", "bar", "baz"]
|
|
expected_output = ["foo bar", "baz"]
|
|
output = splitter._merge_splits(splits, separator=" ")
|
|
assert output == expected_output
|
|
|
|
|
|
def test_create_documents() -> None:
|
|
"""Test create documents method."""
|
|
texts = ["foo bar", "baz"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts)
|
|
expected_docs = [
|
|
Document(page_content="foo"),
|
|
Document(page_content="bar"),
|
|
Document(page_content="baz"),
|
|
]
|
|
assert docs == expected_docs
|
|
|
|
|
|
def test_create_documents_with_metadata() -> None:
|
|
"""Test create documents with metadata method."""
|
|
texts = ["foo bar", "baz"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts, [{"source": "1"}, {"source": "2"}])
|
|
expected_docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "1"}),
|
|
Document(page_content="baz", metadata={"source": "2"}),
|
|
]
|
|
assert docs == expected_docs
|
|
|
|
|
|
def test_metadata_not_shallow() -> None:
|
|
"""Test that metadatas are not shallow."""
|
|
texts = ["foo bar"]
|
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=0)
|
|
docs = splitter.create_documents(texts, [{"source": "1"}])
|
|
expected_docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "1"}),
|
|
]
|
|
assert docs == expected_docs
|
|
docs[0].metadata["foo"] = 1
|
|
assert docs[0].metadata == {"source": "1", "foo": 1}
|
|
assert docs[1].metadata == {"source": "1"}
|
|
|
|
|
|
def test_iterative_text_splitter() -> None:
|
|
"""Test iterative text splitter."""
|
|
text = """Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.
|
|
This is a weird text to write, but gotta test the splittingggg some how.
|
|
|
|
Bye!\n\n-H."""
|
|
splitter = RecursiveCharacterTextSplitter(chunk_size=10, chunk_overlap=1)
|
|
output = splitter.split_text(text)
|
|
expected_output = [
|
|
"Hi.",
|
|
"I'm",
|
|
"Harrison.",
|
|
"How? Are?",
|
|
"You?",
|
|
"Okay then",
|
|
"f f f f.",
|
|
"This is a",
|
|
"weird",
|
|
"text to",
|
|
"write,",
|
|
"but gotta",
|
|
"test the",
|
|
"splitting",
|
|
"gggg",
|
|
"some how.",
|
|
"Bye!",
|
|
"-H.",
|
|
]
|
|
assert output == expected_output
|
|
|
|
|
|
def test_split_documents() -> None:
|
|
"""Test split_documents."""
|
|
splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0)
|
|
docs = [
|
|
Document(page_content="foo", metadata={"source": "1"}),
|
|
Document(page_content="bar", metadata={"source": "2"}),
|
|
Document(page_content="baz", metadata={"source": "1"}),
|
|
]
|
|
expected_output = [
|
|
Document(page_content="f", metadata={"source": "1"}),
|
|
Document(page_content="o", metadata={"source": "1"}),
|
|
Document(page_content="o", metadata={"source": "1"}),
|
|
Document(page_content="b", metadata={"source": "2"}),
|
|
Document(page_content="a", metadata={"source": "2"}),
|
|
Document(page_content="r", metadata={"source": "2"}),
|
|
Document(page_content="b", metadata={"source": "1"}),
|
|
Document(page_content="a", metadata={"source": "1"}),
|
|
Document(page_content="z", metadata={"source": "1"}),
|
|
]
|
|
assert splitter.split_documents(docs) == expected_output
|
|
|
|
|
|
def test_python_text_splitter() -> None:
|
|
splitter = PythonCodeTextSplitter(chunk_size=30, chunk_overlap=0)
|
|
splits = splitter.split_text(FAKE_PYTHON_TEXT)
|
|
split_0 = """class Foo:\n\n def bar():"""
|
|
split_1 = """def foo():"""
|
|
split_2 = """def testing_func():"""
|
|
split_3 = """def bar():"""
|
|
expected_splits = [split_0, split_1, split_2, split_3]
|
|
assert splits == expected_splits
|
|
|
|
|
|
CHUNK_SIZE = 16
|
|
|
|
|
|
def test_python_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
def hello_world():
|
|
print("Hello, World!")
|
|
|
|
# Call the function
|
|
hello_world()
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"def",
|
|
"hello_world():",
|
|
'print("Hello,',
|
|
'World!")',
|
|
"# Call the",
|
|
"function",
|
|
"hello_world()",
|
|
]
|
|
|
|
|
|
def test_golang_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.GO, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
package main
|
|
|
|
import "fmt"
|
|
|
|
func helloWorld() {
|
|
fmt.Println("Hello, World!")
|
|
}
|
|
|
|
func main() {
|
|
helloWorld()
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"package main",
|
|
'import "fmt"',
|
|
"func",
|
|
"helloWorld() {",
|
|
'fmt.Println("He',
|
|
"llo,",
|
|
'World!")',
|
|
"}",
|
|
"func main() {",
|
|
"helloWorld()",
|
|
"}",
|
|
]
|
|
|
|
|
|
def test_rst_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.RST, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
Sample Document
|
|
===============
|
|
|
|
Section
|
|
-------
|
|
|
|
This is the content of the section.
|
|
|
|
Lists
|
|
-----
|
|
|
|
- Item 1
|
|
- Item 2
|
|
- Item 3
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"Sample Document",
|
|
"===============",
|
|
"Section",
|
|
"-------",
|
|
"This is the",
|
|
"content of the",
|
|
"section.",
|
|
"Lists\n-----",
|
|
"- Item 1",
|
|
"- Item 2",
|
|
"- Item 3",
|
|
]
|
|
|
|
|
|
def test_proto_file_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.PROTO, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
syntax = "proto3";
|
|
|
|
package example;
|
|
|
|
message Person {
|
|
string name = 1;
|
|
int32 age = 2;
|
|
repeated string hobbies = 3;
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"syntax =",
|
|
'"proto3";',
|
|
"package",
|
|
"example;",
|
|
"message Person",
|
|
"{",
|
|
"string name",
|
|
"= 1;",
|
|
"int32 age =",
|
|
"2;",
|
|
"repeated",
|
|
"string hobbies",
|
|
"= 3;",
|
|
"}",
|
|
]
|
|
|
|
|
|
def test_javascript_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.JS, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
function helloWorld() {
|
|
console.log("Hello, World!");
|
|
}
|
|
|
|
// Call the function
|
|
helloWorld();
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"function",
|
|
"helloWorld() {",
|
|
'console.log("He',
|
|
"llo,",
|
|
'World!");',
|
|
"}",
|
|
"// Call the",
|
|
"function",
|
|
"helloWorld();",
|
|
]
|
|
|
|
|
|
def test_java_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.JAVA, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
public class HelloWorld {
|
|
public static void main(String[] args) {
|
|
System.out.println("Hello, World!");
|
|
}
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"public class",
|
|
"HelloWorld {",
|
|
"public",
|
|
"static void",
|
|
"main(String[]",
|
|
"args) {",
|
|
"System.out.prin",
|
|
'tln("Hello,',
|
|
'World!");',
|
|
"}\n}",
|
|
]
|
|
|
|
|
|
def test_cpp_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.CPP, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
#include <iostream>
|
|
|
|
int main() {
|
|
std::cout << "Hello, World!" << std::endl;
|
|
return 0;
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"#include",
|
|
"<iostream>",
|
|
"int main() {",
|
|
"std::cout",
|
|
'<< "Hello,',
|
|
'World!" <<',
|
|
"std::endl;",
|
|
"return 0;\n}",
|
|
]
|
|
|
|
|
|
def test_scala_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.SCALA, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
object HelloWorld {
|
|
def main(args: Array[String]): Unit = {
|
|
println("Hello, World!")
|
|
}
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"object",
|
|
"HelloWorld {",
|
|
"def",
|
|
"main(args:",
|
|
"Array[String]):",
|
|
"Unit = {",
|
|
'println("Hello,',
|
|
'World!")',
|
|
"}\n}",
|
|
]
|
|
|
|
|
|
def test_ruby_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.RUBY, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
def hello_world
|
|
puts "Hello, World!"
|
|
end
|
|
|
|
hello_world
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"def hello_world",
|
|
'puts "Hello,',
|
|
'World!"',
|
|
"end",
|
|
"hello_world",
|
|
]
|
|
|
|
|
|
def test_php_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.PHP, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
<?php
|
|
function hello_world() {
|
|
echo "Hello, World!";
|
|
}
|
|
|
|
hello_world();
|
|
?>
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"<?php",
|
|
"function",
|
|
"hello_world() {",
|
|
"echo",
|
|
'"Hello,',
|
|
'World!";',
|
|
"}",
|
|
"hello_world();",
|
|
"?>",
|
|
]
|
|
|
|
|
|
def test_swift_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.SWIFT, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
func helloWorld() {
|
|
print("Hello, World!")
|
|
}
|
|
|
|
helloWorld()
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == [
|
|
"func",
|
|
"helloWorld() {",
|
|
'print("Hello,',
|
|
'World!")',
|
|
"}",
|
|
"helloWorld()",
|
|
]
|
|
|
|
|
|
def test_rust_code_splitter() -> None:
|
|
splitter = RecursiveCharacterTextSplitter.from_language(
|
|
Language.RUST, chunk_size=CHUNK_SIZE, chunk_overlap=0
|
|
)
|
|
code = """
|
|
fn main() {
|
|
println!("Hello, World!");
|
|
}
|
|
"""
|
|
chunks = splitter.split_text(code)
|
|
assert chunks == ["fn main() {", 'println!("Hello', ",", 'World!");', "}"]
|