mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Harrison/fix splitting (#563)
fix issue where text splitting could possibly create empty docs
This commit is contained in:
parent
1192cc0767
commit
1511606799
@ -44,6 +44,14 @@ class TextSplitter(ABC):
|
||||
documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
|
||||
return documents
|
||||
|
||||
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
|
||||
text = separator.join(docs)
|
||||
text = text.strip()
|
||||
if text == "":
|
||||
return None
|
||||
else:
|
||||
return text
|
||||
|
||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
||||
# We now want to combine these smaller pieces into medium size
|
||||
# chunks to send to the LLM.
|
||||
@ -59,7 +67,9 @@ class TextSplitter(ABC):
|
||||
f"which is longer than the specified {self._chunk_size}"
|
||||
)
|
||||
if len(current_doc) > 0:
|
||||
docs.append(separator.join(current_doc))
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
if doc is not None:
|
||||
docs.append(doc)
|
||||
# Keep on popping if:
|
||||
# - we have a larger chunk than in the chunk overlap
|
||||
# - or if we still have any chunks and the length is long
|
||||
@ -70,7 +80,9 @@ class TextSplitter(ABC):
|
||||
current_doc = current_doc[1:]
|
||||
current_doc.append(d)
|
||||
total += _len
|
||||
docs.append(separator.join(current_doc))
|
||||
doc = self._join_docs(current_doc, separator)
|
||||
if doc is not None:
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@classmethod
|
||||
|
@ -17,6 +17,15 @@ def test_character_text_splitter() -> None:
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_empty_doc() -> None:
|
||||
"""Test splitting by character count doesn't create empty documents."""
|
||||
text = "foo bar"
|
||||
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
||||
output = splitter.split_text(text)
|
||||
expected_output = ["foo", "bar"]
|
||||
assert output == expected_output
|
||||
|
||||
|
||||
def test_character_text_splitter_long() -> None:
|
||||
"""Test splitting by character count on long words."""
|
||||
text = "foo bar baz a a"
|
||||
|
Loading…
Reference in New Issue
Block a user