"""Functionality for splitting text."""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import (
|
|
AbstractSet,
|
|
Any,
|
|
Callable,
|
|
Collection,
|
|
Iterable,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Union,
|
|
)
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
logger = logging.getLogger()
|
|
|
|
|
|
class TextSplitter(ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
    ):
        """Create a new TextSplitter."""
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            for chunk in self.split_text(text):
                documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
        return documents

    def split_documents(self, documents: List[Document]) -> List[Document]:
        """Split documents."""
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        return self.create_documents(texts, metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        docs = []
        current_doc: List[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if total + _len >= self._chunk_size:
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len > self._chunk_size and total > 0
                    ):
                        total -= self._length_function(current_doc[0])
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs
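
    # Illustrative walk-through (a sketch, not part of the original source):
    # with chunk_size=10, chunk_overlap=3, length_function=len and separator=" ",
    # merging the splits ["foo", "bar", "baz", "qux"] yields
    # ["foo bar baz", "baz qux"]: once adding "qux" would reach the chunk size,
    # the accumulated pieces are joined into a chunk and pieces are popped from
    # the front until the running total is within chunk_overlap, so "baz" is
    # carried over into the next chunk. Note the running total counts only the
    # pieces themselves, not the separators, so a joined chunk can slightly
    # exceed chunk_size.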

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
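
    # Illustrative usage (a sketch, not part of the original module; assumes
    # the `transformers` package is installed and the method is called on a
    # concrete subclass such as CharacterTextSplitter):
    #
    #     from transformers import GPT2TokenizerFast
    #
    #     tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    #     splitter = CharacterTextSplitter.from_huggingface_tokenizer(
    #         tokenizer, chunk_size=100, chunk_overlap=0
    #     )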

    @classmethod
    def from_tiktoken_encoder(
        cls,
        encoding_name: str = "gpt2",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TextSplitter:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate the token length. "
                "Please install it with `pip install tiktoken`."
            )

        # create a GPT-3 encoder instance
        enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                    **kwargs,
                )
            )

        return cls(length_function=_tiktoken_encoder, **kwargs)
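
    # Illustrative usage (a sketch, not part of the original module; assumes
    # the `tiktoken` package is installed and the method is called on a
    # concrete subclass such as CharacterTextSplitter, so chunk lengths are
    # measured in tokens rather than characters):
    #
    #     splitter = CharacterTextSplitter.from_tiktoken_encoder(
    #         encoding_name="gpt2", chunk_size=100, chunk_overlap=0
    #     )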


class CharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        if self._separator:
            splits = text.split(self._separator)
        else:
            splits = list(text)
        return self._merge_splits(splits, self._separator)
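
# Illustrative usage (a sketch, not part of the original module; the input
# text and metadata values are hypothetical):
#
#     splitter = CharacterTextSplitter(separator="\n\n", chunk_size=1000, chunk_overlap=100)
#     docs = splitter.create_documents(
#         [long_text], metadatas=[{"source": "example.txt"}]
#     )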


class TokenTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at tokens."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )
        # create a GPT-3 encoder instance
        self._tokenizer = tiktoken.get_encoding(encoding_name)
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = []
        input_ids = self._tokenizer.encode(
            text,
            allowed_special=self._allowed_special,
            disallowed_special=self._disallowed_special,
        )
        start_idx = 0
        cur_idx = min(start_idx + self._chunk_size, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
        while start_idx < len(input_ids):
            splits.append(self._tokenizer.decode(chunk_ids))
            start_idx += self._chunk_size - self._chunk_overlap
            cur_idx = min(start_idx + self._chunk_size, len(input_ids))
            chunk_ids = input_ids[start_idx:cur_idx]
        return splits
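
# Illustrative usage (a sketch, not part of the original module): here
# chunk_size and chunk_overlap are measured in tokens of the chosen encoding,
# not in characters.
#
#     splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=512, chunk_overlap=64)
#     chunks = splitter.split_text(long_text)  # `long_text` is a hypothetical input string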


class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(self, separators: Optional[List[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = self._separators[-1]
        for _s in self._separators:
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                break
        # Now that we have the separator, split the text
        if separator:
            splits = text.split(separator)
        else:
            splits = list(text)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        for s in splits:
            if len(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                other_info = self.split_text(s)
                final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, separator)
            final_chunks.extend(merged_text)
        return final_chunks
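
# Illustrative usage (a sketch, not part of the original module): with the
# default separators the splitter tries "\n\n" first, then "\n", then " ", and
# finally falls back to splitting on individual characters for any piece that
# is still longer than chunk_size.
#
#     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
#     chunks = splitter.split_text(long_text)  # `long_text` is a hypothetical input string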


class NLTKTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at sentences using NLTK."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any):
        """Initialize the NLTK splitter."""
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize

            self._tokenizer = sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = self._tokenizer(text)
        return self._merge_splits(splits, self._separator)
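
# Illustrative usage (a sketch, not part of the original module): NLTK's
# sentence tokenizer needs the "punkt" data, e.g. via `nltk.download("punkt")`,
# before the first call.
#
#     splitter = NLTKTextSplitter(chunk_size=1000)
#     chunks = splitter.split_text(long_text)  # `long_text` is a hypothetical input string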


class SpacyTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at sentences using Spacy."""

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ):
        """Initialize the spacy text splitter."""
        super().__init__(**kwargs)
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        self._tokenizer = spacy.load(pipeline)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Split incoming text and return chunks."""
        splits = (str(s) for s in self._tokenizer(text).sents)
        return self._merge_splits(splits, self._separator)
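
# Illustrative usage (a sketch, not part of the original module): the spaCy
# pipeline must be available locally, e.g. via
# `python -m spacy download en_core_web_sm`.
#
#     splitter = SpacyTextSplitter(pipeline="en_core_web_sm", chunk_size=1000)
#     chunks = splitter.split_text(long_text)  # `long_text` is a hypothetical input string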