"""Functionality for splitting text.""" from __future__ import annotations import logging from abc import ABC, abstractmethod from typing import ( AbstractSet, Any, Callable, Collection, Iterable, List, Literal, Optional, Union, ) from langchain.docstore.document import Document logger = logging.getLogger() class TextSplitter(ABC): """Interface for splitting text into chunks.""" def __init__( self, chunk_size: int = 4000, chunk_overlap: int = 200, length_function: Callable[[str], int] = len, ): """Create a new TextSplitter.""" if chunk_overlap > chunk_size: raise ValueError( f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " f"({chunk_size}), should be smaller." ) self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._length_function = length_function @abstractmethod def split_text(self, text: str) -> List[str]: """Split text into multiple components.""" def create_documents( self, texts: List[str], metadatas: Optional[List[dict]] = None ) -> List[Document]: """Create documents from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) documents = [] for i, text in enumerate(texts): for chunk in self.split_text(text): documents.append(Document(page_content=chunk, metadata=_metadatas[i])) return documents def split_documents(self, documents: List[Document]) -> List[Document]: """Split documents.""" texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] return self.create_documents(texts, metadatas) def _join_docs(self, docs: List[str], separator: str) -> Optional[str]: text = separator.join(docs) text = text.strip() if text == "": return None else: return text def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: # We now want to combine these smaller pieces into medium size # chunks to send to the LLM. docs = [] current_doc: List[str] = [] total = 0 for d in splits: _len = self._length_function(d) if total + _len >= self._chunk_size: if total > self._chunk_size: logger.warning( f"Created a chunk of size {total}, " f"which is longer than the specified {self._chunk_size}" ) if len(current_doc) > 0: doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) # Keep on popping if: # - we have a larger chunk than in the chunk overlap # - or if we still have any chunks and the length is long while total > self._chunk_overlap or ( total + _len > self._chunk_size and total > 0 ): total -= self._length_function(current_doc[0]) current_doc = current_doc[1:] current_doc.append(d) total += _len doc = self._join_docs(current_doc, separator) if doc is not None: docs.append(doc) return docs @classmethod def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter: """Text splitter that uses HuggingFace tokenizer to count length.""" try: from transformers import PreTrainedTokenizerBase if not isinstance(tokenizer, PreTrainedTokenizerBase): raise ValueError( "Tokenizer received was not an instance of PreTrainedTokenizerBase" ) def _huggingface_tokenizer_length(text: str) -> int: return len(tokenizer.encode(text)) except ImportError: raise ValueError( "Could not import transformers python package. " "Please it install it with `pip install transformers`." ) return cls(length_function=_huggingface_tokenizer_length, **kwargs) @classmethod def from_tiktoken_encoder( cls, encoding_name: str = "gpt2", allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ) -> TextSplitter: """Text splitter that uses tiktoken encoder to count length.""" try: import tiktoken except ImportError: raise ValueError( "Could not import tiktoken python package. " "This is needed in order to calculate max_tokens_for_prompt. " "Please it install it with `pip install tiktoken`." ) # create a GPT-3 encoder instance enc = tiktoken.get_encoding(encoding_name) def _tiktoken_encoder(text: str, **kwargs: Any) -> int: return len( enc.encode( text, allowed_special=allowed_special, disallowed_special=disallowed_special, **kwargs, ) ) return cls(length_function=_tiktoken_encoder, **kwargs) class CharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters.""" def __init__(self, separator: str = "\n\n", **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._separator = separator def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. if self._separator: splits = text.split(self._separator) else: splits = list(text) return self._merge_splits(splits, self._separator) class TokenTextSplitter(TextSplitter): """Implementation of splitting text that looks at tokens.""" def __init__( self, encoding_name: str = "gpt2", allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ): """Create a new TextSplitter.""" super().__init__(**kwargs) try: import tiktoken except ImportError: raise ValueError( "Could not import tiktoken python package. " "This is needed in order to for TokenTextSplitter. " "Please it install it with `pip install tiktoken`." ) # create a GPT-3 encoder instance self._tokenizer = tiktoken.get_encoding(encoding_name) self._allowed_special = allowed_special self._disallowed_special = disallowed_special def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" splits = [] input_ids = self._tokenizer.encode( text, allowed_special=self._allowed_special, disallowed_special=self._disallowed_special, ) start_idx = 0 cur_idx = min(start_idx + self._chunk_size, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] while start_idx < len(input_ids): splits.append(self._tokenizer.decode(chunk_ids)) start_idx += self._chunk_size - self._chunk_overlap cur_idx = min(start_idx + self._chunk_size, len(input_ids)) chunk_ids = input_ids[start_idx:cur_idx] return splits class RecursiveCharacterTextSplitter(TextSplitter): """Implementation of splitting text that looks at characters. Recursively tries to split by different characters to find one that works. """ def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._separators = separators or ["\n\n", "\n", " ", ""] def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use separator = self._separators[-1] for _s in self._separators: if _s == "": separator = _s break if _s in text: separator = _s break # Now that we have the separator, split the text if separator: splits = text.split(separator) else: splits = list(text) # Now go merging things, recursively splitting longer texts. _good_splits = [] for s in splits: if self._length_function(s) < self._chunk_size: _good_splits.append(s) else: if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) _good_splits = [] other_info = self.split_text(s) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, separator) final_chunks.extend(merged_text) return final_chunks class NLTKTextSplitter(TextSplitter): """Implementation of splitting text that looks at sentences using NLTK.""" def __init__(self, separator: str = "\n\n", **kwargs: Any): """Initialize the NLTK splitter.""" super().__init__(**kwargs) try: from nltk.tokenize import sent_tokenize self._tokenizer = sent_tokenize except ImportError: raise ImportError( "NLTK is not installed, please install it with `pip install nltk`." ) self._separator = separator def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" # First we naively split the large input into a bunch of smaller ones. splits = self._tokenizer(text) return self._merge_splits(splits, self._separator) class SpacyTextSplitter(TextSplitter): """Implementation of splitting text that looks at sentences using Spacy.""" def __init__( self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any ): """Initialize the spacy text splitter.""" super().__init__(**kwargs) try: import spacy except ImportError: raise ImportError( "Spacy is not installed, please install it with `pip install spacy`." ) self._tokenizer = spacy.load(pipeline) self._separator = separator def split_text(self, text: str) -> List[str]: """Split incoming text and return chunks.""" splits = (str(s) for s in self._tokenizer(text).sents) return self._merge_splits(splits, self._separator) class MarkdownTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Markdown-formatted headings.""" def __init__(self, **kwargs: Any): """Initialize a MarkdownTextSplitter.""" separators = [ # First, try to split along Markdown headings (starting with level 2) "\n## ", "\n### ", "\n#### ", "\n##### ", "\n###### ", # Note the alternative syntax for headings (below) is not handled here # Heading level 2 # --------------- # End of code block "```\n\n", # Horizontal lines "\n\n***\n\n", "\n\n---\n\n", "\n\n___\n\n", # Note that this splitter doesn't handle horizontal lines defined # by *three or more* of ***, ---, or ___, but this is not handled "\n\n", "\n", " ", "", ] super().__init__(separators=separators, **kwargs) class PythonCodeTextSplitter(RecursiveCharacterTextSplitter): """Attempts to split the text along Python syntax.""" def __init__(self, **kwargs: Any): """Initialize a MarkdownTextSplitter.""" separators = [ # First, try to split along class definitions "\nclass ", "\ndef ", "\n\tdef ", # Now split by the normal type of lines "\n\n", "\n", " ", "", ] super().__init__(separators=separators, **kwargs)