diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py index 23332f29..7c875682 100644 --- a/langchain/text_splitter.py +++ b/langchain/text_splitter.py @@ -20,6 +20,7 @@ from typing import ( Type, TypeVar, Union, + cast, ) from langchain.docstore.document import Document @@ -59,7 +60,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): length_function: Callable[[str], int] = len, keep_separator: bool = False, add_start_index: bool = False, - ): + ) -> None: """Create a new TextSplitter. Args: @@ -240,7 +241,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): class CharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters.""" - def __init__(self, separator: str = "\n\n", **kwargs: Any): + def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator @@ -265,7 +266,7 @@ class Tokenizer: def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]: """Split incoming text and return chunks.""" - splits = [] + splits: List[str] = [] input_ids = tokenizer.encode(text) start_idx = 0 cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids)) @@ -288,7 +289,7 @@ class TokenTextSplitter(TextSplitter): allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, - ): + ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs) try: @@ -335,19 +336,28 @@ class SentenceTransformersTokenTextSplitter(TextSplitter): model_name: str = "sentence-transformers/all-mpnet-base-v2", tokens_per_chunk: Optional[int] = None, **kwargs: Any, - ): + ) -> None: """Create a new TextSplitter.""" super().__init__(**kwargs, chunk_overlap=chunk_overlap) - from transformers import AutoTokenizer + + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ImportError( + "Could not import 
sentence_transformers python package. " "This is needed in order to use SentenceTransformersTokenTextSplitter. " "Please install it with `pip install sentence-transformers`." ) self.model_name = model_name - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) + self._model = SentenceTransformer(self.model_name) + self.tokenizer = self._model.tokenizer self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk) def _initialize_chunk_configuration( self, *, tokens_per_chunk: Optional[int] ) -> None: - self.maximum_tokens_per_chunk = self.tokenizer.max_len_single_sentence + self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length) if tokens_per_chunk is None: self.tokens_per_chunk = self.maximum_tokens_per_chunk @@ -419,7 +429,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): separators: Optional[List[str]] = None, keep_separator: bool = True, **kwargs: Any, - ): + ) -> None: """Create a new TextSplitter.""" super().__init__(keep_separator=keep_separator, **kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] @@ -785,7 +795,7 @@ class RecursiveCharacterTextSplitter(TextSplitter): class NLTKTextSplitter(TextSplitter): """Implementation of splitting text that looks at sentences using NLTK.""" - def __init__(self, separator: str = "\n\n", **kwargs: Any): + def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: """Initialize the NLTK splitter.""" super().__init__(**kwargs) try: @@ -810,7 +820,7 @@ class SpacyTextSplitter(TextSplitter): def __init__( self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any - ): + ) -> None: """Initialize the spacy text splitter.""" super().__init__(**kwargs) try: @@ -832,7 +842,7 @@ class SpacyTextSplitter(TextSplitter): class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Python syntax.""" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: """Initialize a 
PythonCodeTextSplitter.""" separators = self.get_separators_for_language(Language.PYTHON) super().__init__(separators=separators, **kwargs) @@ -841,7 +851,7 @@ class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): class MarkdownTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Markdown-formatted headings.""" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: """Initialize a MarkdownTextSplitter.""" separators = self.get_separators_for_language(Language.MARKDOWN) super().__init__(separators=separators, **kwargs) @@ -850,7 +860,7 @@ class MarkdownTextSplitter(RecursiveCharacterTextSplitter): class LatexTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Latex-formatted layout elements.""" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs: Any) -> None: """Initialize a LatexTextSplitter.""" separators = self.get_separators_for_language(Language.LATEX) super().__init__(separators=separators, **kwargs)