huggingface tokenizer (#75)

harrison/ape
Harrison Chase committed 2 years ago (committed by GitHub)
parent b542941234
commit d87e73ddb1

@@ -0,0 +1,104 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e82c4685",
"metadata": {},
"outputs": [],
"source": [
"from langchain.text_splitter import HuggingFaceTokenizerSplitter"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a8ce51d5",
"metadata": {},
"outputs": [],
"source": [
"from transformers import GPT2TokenizerFast\n",
"\n",
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ca5e72c0",
"metadata": {},
"outputs": [],
"source": [
"with open('state_of_the_union.txt') as f:\n",
" state_of_the_union = f.read()\n",
"text_splitter = HuggingFaceTokenizerSplitter(tokenizer, chunk_size=1000, chunk_overlap=0)\n",
"texts = text_splitter.split_text(state_of_the_union)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "37cdfbeb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n",
"\n",
"Last year COVID-19 kept us apart. This year we are finally together again. \n",
"\n",
"Tonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n",
"\n",
"With a duty to one another to the American people to the Constitution. \n",
"\n",
"And with an unwavering resolve that freedom will always triumph over tyranny. \n",
"\n",
"Six days ago, Russias Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n",
"\n",
"He thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n",
"\n",
"He met the Ukrainian people. \n",
"\n",
"From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n",
"\n",
"Groups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. \n"
]
}
],
"source": [
"print(texts[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d214aec2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
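
For reference, the notebook above boils down to the following standalone script. This is a minimal sketch assuming transformers and langchain are installed and a state_of_the_union.txt file exists locally; the point is that chunk boundaries are measured in GPT-2 tokens rather than characters.

from transformers import GPT2TokenizerFast

from langchain.text_splitter import CharacterTextSplitter

# Any pretrained tokenizer works; GPT-2 is used here to match the notebook.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

with open("state_of_the_union.txt") as f:
    state_of_the_union = f.read()

# chunk_size counts tokens (via tokenizer.encode), not characters.
text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer, chunk_size=1000, chunk_overlap=0
)
texts = text_splitter.split_text(state_of_the_union)
print(texts[0])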

@@ -1,12 +1,18 @@
 """Functionality for splitting text."""
-from abc import abstractmethod
-from typing import Iterable, List
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Iterable, List
 
 
-class TextSplitter:
+class TextSplitter(ABC):
     """Interface for splitting text into chunks."""
 
-    def __init__(self, separator: str, chunk_size: int, chunk_overlap: int):
+    def __init__(
+        self,
+        separator: str = "\n\n",
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: Callable[[str], int] = len,
+    ):
         """Create a new TextSplitter."""
         if chunk_overlap > chunk_size:
             raise ValueError(
@@ -16,6 +22,7 @@ class TextSplitter:
         self._separator = separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
+        self._length_function = length_function
 
     @abstractmethod
     def split_text(self, text: str) -> List[str]:
@@ -28,29 +35,43 @@ class TextSplitter:
         current_doc: List[str] = []
         total = 0
         for d in splits:
-            if total > self._chunk_size:
+            if total >= self._chunk_size:
                 docs.append(self._separator.join(current_doc))
                 while total > self._chunk_overlap:
-                    total -= len(current_doc[0])
+                    total -= self._length_function(current_doc[0])
                     current_doc = current_doc[1:]
             current_doc.append(d)
-            total += len(d)
+            total += self._length_function(d)
         docs.append(self._separator.join(current_doc))
         return docs
 
+    @classmethod
+    def from_huggingface_tokenizer(
+        cls, tokenizer: Any, **kwargs: Any
+    ) -> "TextSplitter":
+        """Text splitter that uses a HuggingFace tokenizer to count length."""
+        try:
+            from transformers import PreTrainedTokenizerBase
+
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise ValueError(
+                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+                )
+
+            def _huggingface_tokenizer_length(text: str) -> int:
+                return len(tokenizer.encode(text))
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
 
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""
 
-    def __init__(
-        self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
-    ):
-        """Create a new CharacterTextSplitter."""
-        super(CharacterTextSplitter, self).__init__(
-            separator, chunk_size, chunk_overlap
-        )
-        self._separator = separator
-
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
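
The substantive change above is that the merge step now measures chunk size with a pluggable length_function instead of len, which is what lets a HuggingFace tokenizer's token count drive the chunking. Below is a self-contained sketch of the same merge-with-overlap idea; the name merge_splits and its exact signature are illustrative, not part of the library.

from typing import Callable, List


def merge_splits(
    splits: List[str],
    separator: str,
    chunk_size: int,
    chunk_overlap: int,
    length_function: Callable[[str], int] = len,
) -> List[str]:
    """Greedily pack splits into chunks of at most chunk_size length units."""
    docs: List[str] = []
    current: List[str] = []
    total = 0
    for piece in splits:
        # Once the running total reaches chunk_size, emit the current chunk,
        # then drop leading pieces until only chunk_overlap worth remains.
        if total >= chunk_size:
            docs.append(separator.join(current))
            while total > chunk_overlap:
                total -= length_function(current[0])
                current = current[1:]
        current.append(piece)
        total += length_function(piece)
    docs.append(separator.join(current))
    return docs


# Counting each piece as one unit: chunks of about 3 pieces, 1 piece of overlap.
print(merge_splits(list("abcdef"), " ", 3, 1, length_function=lambda s: 1))
# -> ['a b c', 'c d e', 'e f']

With one unit per piece, chunk_size=3 and chunk_overlap=1 yield ['a b c', 'c d e', 'e f']: each chunk carries one piece over from the previous chunk, exactly the overlap behavior the library's _merge_splits implements.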

@@ -10,6 +10,7 @@ wikipedia
 huggingface_hub
 faiss-cpu
 sentence_transformers
+transformers
 manifest-ml
 spacy
 nltk

@@ -0,0 +1,23 @@
+"""Test text splitters that require an integration."""
+import pytest
+
+from langchain.text_splitter import CharacterTextSplitter
+
+
+def test_huggingface_type_check() -> None:
+    """Test that type checks are done properly on input."""
+    with pytest.raises(ValueError):
+        CharacterTextSplitter.from_huggingface_tokenizer("foo")
+
+
+def test_huggingface_tokenizer() -> None:
+    """Test text splitter that uses a HuggingFace tokenizer."""
+    from transformers import GPT2TokenizerFast
+
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+
+    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
+        tokenizer, separator=" ", chunk_size=1, chunk_overlap=0
+    )
+    output = text_splitter.split_text("foo bar")
+    assert output == ["foo", "bar"]
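
The second test's expectation hinges on token counts: "foo" and "bar" should each encode to a single GPT-2 token, so with chunk_size=1 every word becomes its own chunk. A quick check, assuming transformers is installed:

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# Expected: 1 1 (each word should be a single token in the GPT-2 BPE vocabulary)
print(len(tokenizer.encode("foo")), len(tokenizer.encode("bar")))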