diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 19b86e6069..5a6a69ab10 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -64,10 +64,12 @@ class TextSplitter(BaseDocumentTransformer, ABC): documents.append(new_doc) return documents - def split_documents(self, documents: List[Document]) -> List[Document]: + def split_documents(self, documents: Iterable[Document]) -> List[Document]: """Split documents.""" - texts = [doc.page_content for doc in documents] - metadatas = [doc.metadata for doc in documents] + texts, metadatas = [], [] + for doc in documents: + texts.append(doc.page_content) + metadatas.append(doc.metadata) return self.create_documents(texts, metadatas=metadatas) def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: diff --git a/tests/unit_tests/test_text_splitter.py b/tests/unit_tests/test_text_splitter.py index 40f3c2bcc7..31736a9155 100644 --- a/tests/unit_tests/test_text_splitter.py +++ b/tests/unit_tests/test_text_splitter.py @@ -146,3 +146,25 @@ Bye!\n\n-H.""" "Bye!\n\n-H.", ] assert output == expected_output + + +def test_split_documents() -> None: + """Test split_documents.""" + splitter = CharacterTextSplitter(separator="", chunk_size=1, chunk_overlap=0) + docs = [ + Document(page_content="foo", metadata={"source": "1"}), + Document(page_content="bar", metadata={"source": "2"}), + Document(page_content="baz", metadata={"source": "1"}), + ] + expected_output = [ + Document(page_content="f", metadata={"source": "1"}), + Document(page_content="o", metadata={"source": "1"}), + Document(page_content="o", metadata={"source": "1"}), + Document(page_content="b", metadata={"source": "2"}), + Document(page_content="a", metadata={"source": "2"}), + Document(page_content="r", metadata={"source": "2"}), + Document(page_content="b", metadata={"source": "1"}), + Document(page_content="a", metadata={"source": "1"}), + Document(page_content="z", metadata={"source": "1"}), + ] + assert splitter.split_documents(docs) == expected_output