You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/text_splitter.py

440 lines
15 KiB
Python

"""Functionality for splitting text."""
from __future__ import annotations
import copy
import logging
from abc import ABC, abstractmethod
from typing import (
AbstractSet,
Any,
Callable,
Collection,
Iterable,
List,
Literal,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Type variable bound to TextSplitter; used by the classmethod constructors so
# they can be annotated as returning the class they were called on.
TS = TypeVar("TS", bound="TextSplitter")
class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks.

    Subclasses implement :meth:`split_text`. The base class supplies the
    shared machinery: wrapping chunks into ``Document`` objects, merging
    small pieces into chunks bounded by ``chunk_size`` (with
    ``chunk_overlap`` carried over between consecutive chunks), and
    alternate constructors that measure length in tokenizer tokens.
    """

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
    ):
        """Create a new TextSplitter.

        Args:
            chunk_size: Target maximum size of a chunk, in the units returned
                by ``length_function``.
            chunk_overlap: Amount of trailing content of one chunk that may be
                repeated at the start of the next, in the same units.
            length_function: Callable used to measure a string's length.

        Raises:
            ValueError: If ``chunk_overlap`` is larger than ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts.

        Args:
            texts: Texts to split; each is passed to :meth:`split_text`.
            metadatas: Optional metadata dicts, indexed in parallel with
                ``texts``. When omitted (or empty), empty metadata is used.

        Returns:
            One ``Document`` per chunk, in input order. Every document gets
            its own deep copy of the metadata of the text it came from.
        """
        # The shared ``{}`` placeholders are safe: each one is deep-copied
        # before being attached to a document.
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(text):
                new_doc = Document(
                    page_content=chunk, metadata=copy.deepcopy(_metadatas[i])
                )
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents, propagating each one's metadata to its chunks."""
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join ``docs`` with ``separator`` and strip surrounding whitespace.

        Returns:
            The joined text, or ``None`` when nothing but whitespace remains.
        """
        text = separator.join(docs).strip()
        return text if text != "" else None

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Combine small pieces into chunks of roughly ``chunk_size``.

        Pieces are accumulated (joined by ``separator``) until adding the next
        one would exceed ``chunk_size``; the accumulated chunk is then emitted
        and leading pieces are dropped until what remains is no longer than
        ``chunk_overlap`` and leaves room for the next piece.

        Args:
            splits: Pieces to merge; consumed in a single pass.
            separator: String placed between pieces when they are joined.

        Returns:
            The merged chunks. A single piece that is itself longer than
            ``chunk_size`` is kept whole, so chunks can exceed the limit.
        """
        separator_len = self._length_function(separator)

        docs = []
        current_doc: List[str] = []
        # Measured length of ``current_doc`` once joined with ``separator``.
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, "
                        "which is longer than the specified %s",
                        total,
                        self._chunk_size,
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        # Flush whatever is left after the last piece.
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(
        cls: Type[TS], tokenizer: Any, **kwargs: Any
    ) -> TS:
        """Text splitter that uses HuggingFace tokenizer to count length.

        Args:
            tokenizer: An instance of ``transformers.PreTrainedTokenizerBase``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of ``cls`` whose length function is the number of
            token ids produced by ``tokenizer.encode``.

        Raises:
            ValueError: If ``transformers`` cannot be imported, or if
                ``tokenizer`` is not a ``PreTrainedTokenizerBase``.
        """
        # Only the import is guarded, so unrelated errors are never mistaken
        # for a missing dependency.
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError as err:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            ) from err

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "Tokenizer received was not an instance of PreTrainedTokenizerBase"
            )

        def _huggingface_tokenizer_length(text: str) -> int:
            return len(tokenizer.encode(text))

        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length.

        Args:
            encoding_name: Name of the tiktoken encoding; used only when
                ``model_name`` is ``None``.
            model_name: If given, the encoding is looked up for this model.
            allowed_special: Passed through to ``Encoding.encode``.
            disallowed_special: Passed through to ``Encoding.encode``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            An instance of ``cls`` whose length function is the number of
            tiktoken tokens in a string.

        Raises:
            ValueError: If ``tiktoken`` cannot be imported.
        """
        try:
            import tiktoken
        except ImportError as err:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to count tokens when splitting text. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            # TokenTextSplitter builds its own encoder, so it must be given
            # the same encoder settings as the length function above.
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them.

        ``kwargs`` are accepted for interface compatibility and ignored.
        """
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them.

        Raises:
            NotImplementedError: Always; no async implementation is provided.
        """
        raise NotImplementedError
class CharacterTextSplitter(TextSplitter):
    """Splitter that cuts text on a single fixed separator string."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Store the separator; remaining kwargs go to ``TextSplitter``."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Cut ``text`` on the separator, then merge the pieces into chunks.

        An empty separator means the text is broken into single characters.
        """
        sep = self._separator
        pieces = list(text) if not sep else text.split(sep)
        return self._merge_splits(pieces, sep)
class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens.

    The text is encoded with a tiktoken encoding, cut into windows of
    ``chunk_size`` token ids that advance by ``chunk_size - chunk_overlap``
    ids, and each window is decoded back to a string.
    """

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter.

        Args:
            encoding_name: Name of the tiktoken encoding; used only when
                ``model_name`` is ``None``.
            model_name: If given, the encoding is looked up for this model.
            allowed_special: Passed through to ``Encoding.encode``.
            disallowed_special: Passed through to ``Encoding.encode``.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ValueError: If ``tiktoken`` cannot be imported.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError as err:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to use TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            ) from err

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        Returns:
            Decoded token windows in order. Splitting stops as soon as a
            window reaches the end of the input, so no extra chunk made only
            of already-emitted overlap tokens is produced.

        Raises:
            ValueError: If another window is needed but the window cannot
                advance because ``chunk_overlap`` is not smaller than
                ``chunk_size`` (this previously looped forever).
        """
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        n_tokens = len(input_ids)
        stride = self._chunk_size - self._chunk_overlap

        splits: List[str] = []
        start_idx = 0
        while start_idx < n_tokens:
            end_idx = min(start_idx + self._chunk_size, n_tokens)
            splits.append(self._tokenizer.decode(input_ids[start_idx:end_idx]))
            if end_idx == n_tokens:
                # Every token has been emitted; a further window would only
                # repeat the tail of this one.
                break
            if stride <= 0:
                # A non-positive stride would never move the window forward.
                raise ValueError(
                    f"Cannot advance through the tokens: chunk overlap "
                    f"({self._chunk_overlap}) must be smaller than chunk size "
                    f"({self._chunk_size})."
                )
            start_idx += stride
        return splits
class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        Args:
            separators: Separators to try, in priority order. Defaults to
                ``["\\n\\n", "\\n", " ", ""]`` when omitted or empty.
            **kwargs: Forwarded to ``TextSplitter.__init__``.
        """
        super().__init__(**kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks.

        The first separator that occurs in ``text`` is used (an empty
        separator always matches and splits into single characters; if none
        occurs, the last one is used). Pieces shorter than ``chunk_size`` are
        merged into chunks; longer pieces are split again recursively.

        A piece that cannot be reduced any further (splitting returned the
        input unchanged) is emitted as it is instead of being recursed on,
        which previously ended in a ``RecursionError``.
        """
        final_chunks: List[str] = []
        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "" or _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text
        splits = text.split(separator) if separator else list(text)

        # Now go merging things, recursively splitting longer texts.
        _good_splits: List[str] = []
        for s in splits:
            piece_len = self._length_function(s)
            if piece_len < self._chunk_size:
                _good_splits.append(s)
                continue
            # Flush the small pieces gathered so far to keep output ordered.
            if _good_splits:
                final_chunks.extend(self._merge_splits(_good_splits, separator))
                _good_splits = []
            if s == text:
                # No progress was made: recursing would repeat this exact
                # call forever. Emit the piece through _merge_splits so it is
                # stripped/filtered the same way as every other chunk.
                if piece_len > self._chunk_size:
                    logger.warning(
                        "Could not split a piece of size %s below the "
                        "specified chunk size %s",
                        piece_len,
                        self._chunk_size,
                    )
                final_chunks.extend(self._merge_splits([s], separator))
            else:
                final_chunks.extend(self.split_text(s))
        if _good_splits:
            final_chunks.extend(self._merge_splits(_good_splits, separator))
        return final_chunks
class NLTKTextSplitter(TextSplitter):
    """Sentence-based splitter backed by NLTK's ``sent_tokenize``."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Set up the splitter.

        Args:
            separator: String used to join sentences when forming chunks.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If ``nltk`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        else:
            self._tokenizer = sent_tokenize
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Tokenize ``text`` into sentences and merge them into chunks."""
        sentences = self._tokenizer(text)
        return self._merge_splits(sentences, self._separator)
class SpacyTextSplitter(TextSplitter):
    """Sentence-based splitter backed by a spaCy pipeline."""

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ):
        """Set up the splitter.

        Args:
            separator: String used to join sentences when forming chunks.
            pipeline: Name handed to ``spacy.load``.
            **kwargs: Forwarded to ``TextSplitter.__init__``.

        Raises:
            ImportError: If ``spacy`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        self._separator = separator
        self._tokenizer = spacy.load(pipeline)

    def split_text(self, text: str) -> List[str]:
        """Run the pipeline on ``text`` and merge its sentences into chunks."""
        parsed = self._tokenizer(text)
        # Lazily stringify each sentence span; _merge_splits iterates once.
        sentences = map(str, parsed.sents)
        return self._merge_splits(sentences, self._separator)
class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter whose separators follow Markdown structure."""

    def __init__(self, **kwargs: Any):
        """Build the Markdown separator list and hand it to the parent.

        Separators are tried in this priority order: headings, end of a code
        block, horizontal rules, then generic whitespace fallbacks.
        """
        # ATX headings from level 2 down to level 6. The alternative heading
        # syntax (a title underlined with dashes) is not handled.
        heading_breaks = [
            "\n## ",
            "\n### ",
            "\n#### ",
            "\n##### ",
            "\n###### ",
        ]
        # End of a fenced code block.
        code_block_end = ["```\n\n"]
        # Horizontal rules written with exactly three characters; rules made
        # of more than three of ***, --- or ___ are not matched.
        horizontal_rules = [
            "\n\n***\n\n",
            "\n\n---\n\n",
            "\n\n___\n\n",
        ]
        # Generic fallbacks: paragraph, line, word, single character.
        fallbacks = ["\n\n", "\n", " ", ""]

        super().__init__(
            separators=heading_breaks + code_block_end + horizontal_rules + fallbacks,
            **kwargs,
        )
class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter whose separators follow LaTeX layout elements."""

    def __init__(self, **kwargs: Any):
        """Build the LaTeX separator list and hand it to the parent.

        Separators are tried in this priority order: sectioning commands,
        list/quote/verbatim environments, math delimiters, then a space and
        finally single characters.
        """
        # Sectioning commands, outermost first.
        sectioning = [
            "\n\\chapter{",
            "\n\\section{",
            "\n\\subsection{",
            "\n\\subsubsection{",
        ]
        # Openings of block environments.
        environments = [
            "\n\\begin{enumerate}",
            "\n\\begin{itemize}",
            "\n\\begin{description}",
            "\n\\begin{list}",
            "\n\\begin{quote}",
            "\n\\begin{quotation}",
            "\n\\begin{verse}",
            "\n\\begin{verbatim}",
        ]
        # Math environments and inline/display math delimiters.
        math = [
            "\n\\begin{align}",
            "$$",
            "$",
        ]
        # Generic fallbacks: word, then single character.
        fallbacks = [" ", ""]

        super().__init__(
            separators=sectioning + environments + math + fallbacks,
            **kwargs,
        )
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter whose separators follow Python syntax."""

    def __init__(self, **kwargs: Any):
        """Build the Python separator list and hand it to the parent.

        Separators are tried in this priority order: top-level ``class``,
        top-level ``def``, tab-indented ``def``, then generic whitespace
        fallbacks.
        """
        # Definition boundaries: class, function, tab-indented function.
        definition_breaks = ["\nclass ", "\ndef ", "\n\tdef "]
        # Generic fallbacks: blank line, line, word, single character.
        fallbacks = ["\n\n", "\n", " ", ""]

        super().__init__(separators=definition_breaks + fallbacks, **kwargs)