Harrison/fix text splitter (#1511)

Co-authored-by: ajaysolanky <ajsolanky@gmail.com>
Co-authored-by: Ajay Solanky <ajaysolanky@saw-l14668307kd.myfiosgateway.com>
fix-searx
Harrison Chase 1 year ago committed by GitHub
parent e3354404ad
commit 064741db58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,12 +71,17 @@ class TextSplitter(ABC):
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
# We now want to combine these smaller pieces into medium size # We now want to combine these smaller pieces into medium size
# chunks to send to the LLM. # chunks to send to the LLM.
separator_len = self._length_function(separator)
docs = [] docs = []
current_doc: List[str] = [] current_doc: List[str] = []
total = 0 total = 0
for d in splits: for d in splits:
_len = self._length_function(d) _len = self._length_function(d)
if total + _len >= self._chunk_size: if (
total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
):
if total > self._chunk_size: if total > self._chunk_size:
logger.warning( logger.warning(
f"Created a chunk of size {total}, " f"Created a chunk of size {total}, "
@ -90,12 +95,16 @@ class TextSplitter(ABC):
# - we have a larger chunk than in the chunk overlap # - we have a larger chunk than in the chunk overlap
# - or if we still have any chunks and the length is long # - or if we still have any chunks and the length is long
while total > self._chunk_overlap or ( while total > self._chunk_overlap or (
total + _len > self._chunk_size and total > 0 total + _len + (separator_len if len(current_doc) > 0 else 0)
> self._chunk_size
and total > 0
): ):
total -= self._length_function(current_doc[0]) total -= self._length_function(current_doc[0]) + (
separator_len if len(current_doc) > 1 else 0
)
current_doc = current_doc[1:] current_doc = current_doc[1:]
current_doc.append(d) current_doc.append(d)
total += _len total += _len + (separator_len if len(current_doc) > 1 else 0)
doc = self._join_docs(current_doc, separator) doc = self._join_docs(current_doc, separator)
if doc is not None: if doc is not None:
docs.append(doc) docs.append(doc)

@ -26,6 +26,15 @@ def test_character_text_splitter_empty_doc() -> None:
assert output == expected_output assert output == expected_output
def test_character_text_splitter_separtor_empty_doc() -> None:
"""Test edge cases are separators."""
text = "f b"
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
output = splitter.split_text(text)
expected_output = ["f", "b"]
assert output == expected_output
def test_character_text_splitter_long() -> None: def test_character_text_splitter_long() -> None:
"""Test splitting by character count on long words.""" """Test splitting by character count on long words."""
text = "foo bar baz a a" text = "foo bar baz a a"
@ -99,7 +108,7 @@ Bye!\n\n-H."""
"Harrison.", "Harrison.",
"How? Are?", "How? Are?",
"You?", "You?",
"Okay then f", "Okay then",
"f f f f.", "f f f f.",
"This is a", "This is a",
"a weird", "a weird",
@ -107,8 +116,8 @@ Bye!\n\n-H."""
"write, but", "write, but",
"gotta test", "gotta test",
"the", "the",
"splitting", "splittingg",
"gggg", "ggg",
"some how.", "some how.",
"Bye!\n\n-H.", "Bye!\n\n-H.",
] ]

Loading…
Cancel
Save