diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py index be189548c1..2e5f7021f3 100644 --- a/libs/langchain/langchain/text_splitter.py +++ b/libs/langchain/langchain/text_splitter.py @@ -1081,7 +1081,9 @@ class RecursiveCharacterTextSplitter(TextSplitter): class NLTKTextSplitter(TextSplitter): """Splitting text using NLTK package.""" - def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: + def __init__( + self, separator: str = "\n\n", language: str = "english", **kwargs: Any + ) -> None: """Initialize the NLTK splitter.""" super().__init__(**kwargs) try: @@ -1093,11 +1095,12 @@ class NLTKTextSplitter(TextSplitter): "NLTK is not installed, please install it with `pip install nltk`." ) self._separator = separator + self._language = language def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. - splits = self._tokenizer(text) + splits = self._tokenizer(text, language=self._language) return self._merge_splits(splits, self._separator)