langchain[patch]: inconsistent results with `RecursiveCharacterTextSplitter`'s `add_start_index=True` (#16583)

This PR fixes issue #16579
pull/16581/head^2
Antonio Lanza 5 months ago committed by GitHub
parent 42db96477f
commit 08d3fd7f2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -141,12 +141,15 @@ class TextSplitter(BaseDocumentTransformer, ABC):
_metadatas = metadatas or [{}] * len(texts)
documents = []
for i, text in enumerate(texts):
index = -1
index = 0
previous_chunk_len = 0
for chunk in self.split_text(text):
metadata = copy.deepcopy(_metadatas[i])
if self._add_start_index:
index = text.find(chunk, index + 1)
offset = index + previous_chunk_len - self._chunk_overlap
index = text.find(chunk, max(0, offset))
metadata["start_index"] = index
previous_chunk_len = len(chunk)
new_doc = Document(page_content=chunk, metadata=metadata)
documents.append(new_doc)
return documents

@ -13,6 +13,7 @@ from langchain.text_splitter import (
MarkdownHeaderTextSplitter,
PythonCodeTextSplitter,
RecursiveCharacterTextSplitter,
TextSplitter,
Tokenizer,
split_text_on_tokens,
)
@ -169,19 +170,47 @@ def test_create_documents_with_metadata() -> None:
assert docs == expected_docs
def test_create_documents_with_start_index() -> None:
@pytest.mark.parametrize(
"splitter, text, expected_docs",
[
(
CharacterTextSplitter(
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
),
"foo bar baz 123",
[
Document(page_content="foo bar", metadata={"start_index": 0}),
Document(page_content="bar baz", metadata={"start_index": 4}),
Document(page_content="baz 123", metadata={"start_index": 8}),
],
),
(
RecursiveCharacterTextSplitter(
chunk_size=6,
chunk_overlap=0,
separators=["\n\n", "\n", " ", ""],
add_start_index=True,
),
"w1 w1 w1 w1 w1 w1 w1 w1 w1",
[
Document(page_content="w1 w1", metadata={"start_index": 0}),
Document(page_content="w1 w1", metadata={"start_index": 6}),
Document(page_content="w1 w1", metadata={"start_index": 12}),
Document(page_content="w1 w1", metadata={"start_index": 18}),
Document(page_content="w1", metadata={"start_index": 24}),
],
),
],
)
def test_create_documents_with_start_index(
splitter: TextSplitter, text: str, expected_docs: List[Document]
) -> None:
"""Test create documents method."""
texts = ["foo bar baz 123"]
splitter = CharacterTextSplitter(
separator=" ", chunk_size=7, chunk_overlap=3, add_start_index=True
)
docs = splitter.create_documents(texts)
expected_docs = [
Document(page_content="foo bar", metadata={"start_index": 0}),
Document(page_content="bar baz", metadata={"start_index": 4}),
Document(page_content="baz 123", metadata={"start_index": 8}),
]
docs = splitter.create_documents([text])
assert docs == expected_docs
for doc in docs:
s_i = doc.metadata["start_index"]
assert text[s_i : s_i + len(doc.page_content)] == doc.page_content
def test_metadata_not_shallow() -> None:

Loading…
Cancel
Save