diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 0cbe9e4ed8..37de47f0fc 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -34,6 +34,23 @@ logger = logging.getLogger(__name__)
 TS = TypeVar("TS", bound="TextSplitter")
 
 
+def _make_spacy_pipeline_for_splitting(pipeline: str) -> Any:  # avoid importing spacy
+    try:
+        import spacy
+    except ImportError:
+        raise ImportError(
+            "Spacy is not installed, please install it with `pip install spacy`."
+        )
+    if pipeline == "sentencizer":
+        from spacy.lang.en import English
+
+        sentencizer = English()
+        sentencizer.add_pipe("sentencizer")
+    else:
+        sentencizer = spacy.load(pipeline, disable=["ner"])
+    return sentencizer
+
+
 def _split_text_with_regex(
     text: str, separator: str, keep_separator: bool
 ) -> List[str]:
@@ -1010,25 +1027,23 @@ class NLTKTextSplitter(TextSplitter):
 
 
 class SpacyTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at sentences using Spacy."""
+    """Implementation of splitting text that looks at sentences using Spacy.
+
+    By default, Spacy's `en_core_web_sm` model is used. For faster, but
+    potentially less accurate, splitting you can use `pipeline='sentencizer'`.
+    """
 
     def __init__(
         self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
     ) -> None:
         """Initialize the spacy text splitter."""
         super().__init__(**kwargs)
-        try:
-            import spacy
-        except ImportError:
-            raise ImportError(
-                "Spacy is not installed, please install it with `pip install spacy`."
-            )
-        self._tokenizer = spacy.load(pipeline)
+        self._tokenizer = _make_spacy_pipeline_for_splitting(pipeline)
         self._separator = separator
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
-        splits = (str(s) for s in self._tokenizer(text).sents)
+        splits = (s.text for s in self._tokenizer(text).sents)
         return self._merge_splits(splits, self._separator)
 
 
diff --git a/tests/integration_tests/test_nlp_text_splitters.py b/tests/integration_tests/test_nlp_text_splitters.py
index 4837fe20ad..0f2809ede4 100644
--- a/tests/integration_tests/test_nlp_text_splitters.py
+++ b/tests/integration_tests/test_nlp_text_splitters.py
@@ -26,11 +26,12 @@ def test_nltk_text_splitter() -> None:
     assert output == expected_output
 
 
-def test_spacy_text_splitter() -> None:
+@pytest.mark.parametrize("pipeline", ["sentencizer", "en_core_web_sm"])
+def test_spacy_text_splitter(pipeline: str) -> None:
     """Test splitting by sentence using Spacy."""
     text = "This is sentence one. And this is sentence two."
     separator = "|||"
-    splitter = SpacyTextSplitter(separator=separator)
+    splitter = SpacyTextSplitter(separator=separator, pipeline=pipeline)
     output = splitter.split_text(text)
     expected_output = [f"This is sentence one.{separator}And this is sentence two."]
     assert output == expected_output
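
Not part of the patch itself — a minimal usage sketch of the new `pipeline` option, mirroring the parametrized integration test. It assumes spacy is installed and, for the default pipeline, that the `en_core_web_sm` model has been downloaded (`python -m spacy download en_core_web_sm`); the sentencizer branch needs no model download.

```python
# Usage sketch (not part of the diff): exercising both pipeline options.
from langchain.text_splitter import SpacyTextSplitter

text = "This is sentence one. And this is sentence two."

# Rule-based sentencizer: fast, no statistical model download required.
fast = SpacyTextSplitter(separator="|||", pipeline="sentencizer")
print(fast.split_text(text))  # ["This is sentence one.|||And this is sentence two."]

# Default: the full `en_core_web_sm` pipeline, with NER disabled by the new helper.
accurate = SpacyTextSplitter(separator="|||")
print(accurate.split_text(text))  # same output; slower but potentially more accurate
```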