|
|
|
@ -13,6 +13,7 @@ from langchain.text_splitter import (
|
|
|
|
|
MarkdownHeaderTextSplitter,
|
|
|
|
|
PythonCodeTextSplitter,
|
|
|
|
|
RecursiveCharacterTextSplitter,
|
|
|
|
|
TextSplitter,
|
|
|
|
|
Tokenizer,
|
|
|
|
|
split_text_on_tokens,
|
|
|
|
|
)
|
|
|
|
@ -169,19 +170,47 @@ def test_create_documents_with_metadata() -> None:
|
|
|
|
|
assert docs == expected_docs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "splitter, text, expected_docs",
    [
        # Space-separated splitting with overlap: consecutive chunks share a
        # word, so each start index advances by one word plus its separator.
        (
            CharacterTextSplitter(
                separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
            ),
            "foo bar baz 123",
            [
                Document(page_content="foo bar", metadata={"start_index": 0}),
                Document(page_content="bar baz", metadata={"start_index": 4}),
                Document(page_content="baz 123", metadata={"start_index": 8}),
            ],
        ),
        # Repeated identical chunks: every chunk except the last has the same
        # content, so the expected indices must be distinct offsets into the
        # text rather than the position of the first match.
        (
            RecursiveCharacterTextSplitter(
                chunk_size=6,
                chunk_overlap=0,
                separators=["\n\n", "\n", " ", ""],
                add_start_index=True,
            ),
            "w1 w1 w1 w1 w1 w1 w1 w1 w1",
            [
                Document(page_content="w1 w1", metadata={"start_index": 0}),
                Document(page_content="w1 w1", metadata={"start_index": 6}),
                Document(page_content="w1 w1", metadata={"start_index": 12}),
                Document(page_content="w1 w1", metadata={"start_index": 18}),
                Document(page_content="w1", metadata={"start_index": 24}),
            ],
        ),
    ],
)
def test_create_documents_with_start_index(
    splitter: TextSplitter, text: str, expected_docs: List[Document]
) -> None:
    """Test create documents method with ``add_start_index=True``.

    Args:
        splitter: Splitter under test, constructed with ``add_start_index=True``.
        text: The single input text handed to ``create_documents``.
        expected_docs: Documents (content plus ``start_index`` metadata) the
            splitter is expected to produce for ``text``.
    """
    # Use the parametrized arguments directly; rebinding `splitter` or
    # `expected_docs` here would make every case exercise the same input.
    docs = splitter.create_documents([text])
    assert docs == expected_docs
    # Independently of the expected list, each reported start index must point
    # at the chunk's actual position in the original text.
    for doc in docs:
        s_i = doc.metadata["start_index"]
        assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_metadata_not_shallow() -> None:
|
|
|
|
|