chore: spedd up integration test by using smaller model (#6044)

Adds a new parameter `relative_chunk_overlap` for the
`SentenceTransformersTokenTextSplitter` constructor. The parameter sets
the chunk overlap using a relative factor, e.g. for a model where the
token limit is 100, a `relative_chunk_overlap=0.5` implies that
`chunk_overlap=50`

Tag maintainers/contributors who might be interested:

 @hwchase17, @dev2049
This commit is contained in:
Jens Madsen 2023-06-12 22:27:10 +02:00 committed by GitHub
parent 5922742d56
commit 2c91f0d750
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -52,14 +52,14 @@ def test_token_text_splitter_from_tiktoken() -> None:
def test_sentence_transformers_count_tokens() -> None:
splitter = SentenceTransformersTokenTextSplitter(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_name="sentence-transformers/paraphrase-albert-small-v2"
)
text = "Lorem ipsum"
token_count = splitter.count_tokens(text=text)
expected_start_stop_token_count = 2
expected_text_token_count = 2
expected_text_token_count = 5
expected_token_count = expected_start_stop_token_count + expected_text_token_count
assert expected_token_count == token_count
@ -67,9 +67,9 @@ def test_sentence_transformers_count_tokens() -> None:
def test_sentence_transformers_split_text() -> None:
splitter = SentenceTransformersTokenTextSplitter(
model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
model_name="sentence-transformers/paraphrase-albert-small-v2"
)
text = "Lorem ipsum"
text = "lorem ipsum"
text_chunks = splitter.split_text(text=text)
expected_text_chunks = [text]
assert expected_text_chunks == text_chunks
@ -79,14 +79,29 @@ def test_sentence_transformers_multiple_tokens() -> None:
splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0)
text = "Lorem "
text_token_count_including_start_and_stop_tokens = splitter.count_tokens(text=text)
count_start_and_end_tokens = 2
text_token_count = splitter.count_tokens(text=text) - count_start_and_end_tokens
token_multiplier = splitter.maximum_tokens_per_chunk // text_token_count + 1
text_chunks = splitter.split_text(text=text * token_multiplier)
token_multiplier = (
count_start_and_end_tokens
+ (splitter.maximum_tokens_per_chunk - count_start_and_end_tokens)
// (
text_token_count_including_start_and_stop_tokens
- count_start_and_end_tokens
)
+ 1
)
# `text_to_split` does not fit in a single chunk
text_to_embed = text * token_multiplier
text_chunks = splitter.split_text(text=text_to_embed)
expected_number_of_chunks = 2
assert expected_number_of_chunks == len(text_chunks)
actual = splitter.count_tokens(text=text_chunks[1]) - count_start_and_end_tokens
expected = token_multiplier * text_token_count - splitter.maximum_tokens_per_chunk
expected = (
token_multiplier * (text_token_count_including_start_and_stop_tokens - 2)
- splitter.maximum_tokens_per_chunk
)
assert expected == actual