From d56313acba3a0082dda8be67df75ffdb6030b385 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Mon, 22 May 2023 23:00:24 -0400 Subject: [PATCH] Improve effeciency of TextSplitter.split_documents, iterate once (#5111) # Improve TextSplitter.split_documents, collect page_content and metadata in one iteration ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @eyurtsev In the case where documents is a generator that can only be iterated once making this change is a huge help. Otherwise a silent issue happens where metadata is empty for all documents when documents is a generator. So we expand the argument from `List[Document]` to `Union[Iterable[Document], Sequence[Document]]` --------- Co-authored-by: Steven Tartakovsky --- langchain/text_splitter.py | 8 +++++--- tests/unit_tests/test_text_splitter.py | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 19b86e6069..5a6a69ab10 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -64,10 +64,12 @@ class TextSplitter(BaseDocumentTransformer, ABC): documents.append(new_doc) return documents - def split_documents(self, documents: List[Document]) -> List[Document]: + def split_documents(self, documents: Iterable[Document]) -> List[Document]: """Split documents.""" - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 40f3c2bcc7..31736a9155 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -146,3 +146,25 @@ Bye!\n\n-H.""" "Bye!\n\n-H.", ] assert output == expected_output + + +def test_split_documents() -> None: + """Test split_documents.""" + splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0) + docs = [ + Document(page_content="foo", metadata={"source": "1"}), + Document(page_content="bar", metadata={"source": "2"}), + Document(page_content="baz", metadata={"source": "1"}), + ] + expected_output = [ + Document(page_content="f", metadata={"source": "1"}), + Document(page_content="o", metadata={"source": "1"}), + Document(page_content="o", metadata={"source": "1"}), + Document(page_content="b", metadata={"source": "2"}), + Document(page_content="a", metadata={"source": "2"}), + Document(page_content="r", metadata={"source": "2"}), + Document(page_content="b", metadata={"source": "1"}), + Document(page_content="a", metadata={"source": "1"}), + Document(page_content="z", metadata={"source": "1"}), + ] + assert splitter.split_documents(docs) == expected_output