mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, List
|
|
|
|
from langchain_text_splitters.base import TextSplitter
|
|
|
|
|
|
class NLTKTextSplitter(TextSplitter):
    """Splitting text using NLTK package."""

    def __init__(
        self, separator: str = "\n\n", language: str = "english", **kwargs: Any
    ) -> None:
        """Initialize the NLTK splitter.

        Args:
            separator: String used to join sentences back together when
                merging splits into chunks.
            language: Language name forwarded to NLTK's ``sent_tokenize``.
            **kwargs: Additional arguments forwarded to ``TextSplitter``
                (e.g. chunk size / overlap settings).

        Raises:
            ImportError: If the ``nltk`` package is not installed.
        """
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError as e:
            # Chain the original exception so the real cause of the import
            # failure (e.g. a broken nltk install) stays visible in tracebacks.
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            ) from e
        self._separator = separator
        self._language = language

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text, language=self._language)
        return self._merge_splits(splits, self._separator)
|