mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
c5a46e7435
<!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> ## Add Solidity programming language support for the code splitter. Twitter: @0xjord4n_ <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that do not rely on network access. 2. an example notebook showing its use. See contribution guidelines for more information on how to write tests, lint, etc.: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #### Who can review? Tag maintainers/contributors who might be interested: @hwchase17 <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @hwchase17 VectorStores / Retrievers / Memory - @dev2049 -->
1048 lines
36 KiB
Python
1048 lines
36 KiB
Python
"""Functionality for splitting text."""
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import logging
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import (
|
|
AbstractSet,
|
|
Any,
|
|
Callable,
|
|
Collection,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
TypedDict,
|
|
TypeVar,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from langchain.docstore.document import Document
|
|
from langchain.schema import BaseDocumentTransformer
|
|
|
|
# Module-level logger named after this module; used by
# ``TextSplitter._merge_splits`` to warn about oversized chunks.
logger = logging.getLogger(__name__)

# Type variable bound to ``TextSplitter`` so that
# ``TextSplitter.from_tiktoken_encoder`` is typed as returning the class it
# was called on.
TS = TypeVar("TS", bound="TextSplitter")
|
|
|
|
|
|
def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    """Split ``text`` on the regular expression ``separator``.

    Args:
        text: The text to split.
        separator: Regular-expression pattern to split on. An empty string
            splits the text into individual characters.
        keep_separator: If ``True``, every matched separator is kept and
            prepended to the piece that follows it; otherwise the matched
            separator is dropped.

    Returns:
        The non-empty pieces of ``text`` in their original order.
    """
    if not separator:
        # No separator left to try: fall back to one piece per character.
        splits = list(text)
    elif keep_separator:
        # The parentheses in the pattern keep the delimiters in the result,
        # which then looks like [head, sep1, piece1, sep2, piece2, ...].
        _splits = re.split(f"({separator})", text)
        # Glue each delimiter onto the piece that follows it.
        splits = [_splits[0]] + [
            _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
        ]
    else:
        # The separator is a regex in the branch above and in the callers
        # that pick it with ``re.search``, so it must be a regex here as
        # well; a literal ``str.split`` would silently fail to split on
        # patterns such as "\n=+\n".
        splits = re.split(separator, text)
    return [s for s in splits if s != ""]
|
|
|
|
|
|
class TextSplitter(BaseDocumentTransformer, ABC):
    """Abstract base class for objects that cut text into bounded chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Store the chunking configuration.

        Args:
            chunk_size: Upper bound on the measured length of a chunk.
            chunk_overlap: Measured length shared by neighbouring chunks.
            length_function: Callable used to measure a piece of text.
            keep_separator: Whether separators stay in the produced chunks.
            add_start_index: If `True`, ``create_documents`` records each
                chunk's start index in its metadata.

        Raises:
            ValueError: If ``chunk_overlap`` exceeds ``chunk_size``.
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._length_function = length_function
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._add_start_index = add_start_index
        self._keep_separator = keep_separator

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Document]:
        """Split every text and wrap each chunk in a ``Document``.

        Each document gets a deep copy of the metadata belonging to the text
        it came from (an empty dict when ``metadatas`` is not given).
        """
        per_text_metadata = metadatas or [{}] * len(texts)
        result = []
        for position, text in enumerate(texts):
            # Search cursor so repeated chunks resolve to successive offsets.
            cursor = -1
            for chunk in self.split_text(text):
                chunk_metadata = copy.deepcopy(per_text_metadata[position])
                if self._add_start_index:
                    cursor = text.find(chunk, cursor + 1)
                    chunk_metadata["start_index"] = cursor
                result.append(Document(page_content=chunk, metadata=chunk_metadata))
        return result

    def split_documents(self, documents: Iterable[Document]) -> List[Document]:
        """Split the contents of ``documents``, carrying their metadata along."""
        materialized = list(documents)
        texts = [doc.page_content for doc in materialized]
        metadatas = [doc.metadata for doc in materialized]
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
        """Join ``docs`` with ``separator`` and strip; ``None`` if nothing is left."""
        joined = separator.join(docs).strip()
        return None if joined == "" else joined

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Pack small pieces into chunks no longer than the chunk size.

        Pieces are accumulated in a sliding window. When the next piece would
        overflow the window, the window is emitted as a chunk and its oldest
        pieces are dropped until what remains fits within the configured
        overlap and leaves room for the new piece.
        """
        sep_size = self._length_function(separator)

        chunks = []
        window: List[str] = []
        window_size = 0
        for piece in splits:
            piece_size = self._length_function(piece)
            joiner_size = sep_size if len(window) > 0 else 0
            if window_size + piece_size + joiner_size > self._chunk_size:
                if window_size > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {window_size}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(window) > 0:
                    chunk = self._join_docs(window, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Drop leading pieces while the window is still bigger
                    # than the overlap, or while it is non-empty and the new
                    # piece would not fit next to it.
                    while window_size > self._chunk_overlap or (
                        window_size
                        + piece_size
                        + (sep_size if len(window) > 0 else 0)
                        > self._chunk_size
                        and window_size > 0
                    ):
                        window_size -= self._length_function(window[0]) + (
                            sep_size if len(window) > 1 else 0
                        )
                        window = window[1:]
            window.append(piece)
            window_size += piece_size + (sep_size if len(window) > 1 else 0)
        tail = self._join_docs(window, separator)
        if tail is not None:
            chunks.append(tail)
        return chunks

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Build a splitter whose lengths are HuggingFace token counts.

        Raises:
            ValueError: If ``transformers`` cannot be imported or ``tokenizer``
                is not a ``PreTrainedTokenizerBase``.
        """
        try:
            from transformers import PreTrainedTokenizerBase
        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )

        if not isinstance(tokenizer, PreTrainedTokenizerBase):
            raise ValueError(
                "Tokenizer received was not an instance of PreTrainedTokenizerBase"
            )

        def _token_count(text: str) -> int:
            return len(tokenizer.encode(text))

        return cls(length_function=_token_count, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: Type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Build a splitter whose lengths are tiktoken token counts.

        The encoding is looked up by ``model_name`` when given, otherwise by
        ``encoding_name``. For ``TokenTextSplitter`` subclasses the encoder
        settings are also forwarded to the constructor.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate max_tokens_for_prompt. "
                "Please install it with `pip install tiktoken`."
            )

        enc = (
            tiktoken.encoding_for_model(model_name)
            if model_name is not None
            else tiktoken.get_encoding(encoding_name)
        )

        def _token_count(text: str) -> int:
            token_ids = enc.encode(
                text,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )
            return len(token_ids)

        if issubclass(cls, TokenTextSplitter):
            # The encoder settings take precedence over caller kwargs.
            kwargs = dict(kwargs)
            kwargs.update(
                encoding_name=encoding_name,
                model_name=model_name,
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        return cls(length_function=_token_count, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform a sequence of documents by splitting them."""
        as_list = list(documents)
        return self.split_documents(as_list)

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronous variant of ``transform_documents``; not implemented."""
        raise NotImplementedError
|
|
|
|
|
|
class CharacterTextSplitter(TextSplitter):
    """Text splitter that cuts on a single separator and re-merges the pieces."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Remember ``separator``; other options go to ``TextSplitter``."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Cut ``text`` on the separator, then merge pieces into chunks."""
        pieces = _split_text_with_regex(text, self._separator, self._keep_separator)
        # Kept separators are already part of the pieces, so nothing needs to
        # be re-inserted when merging in that case.
        if self._keep_separator:
            joiner = ""
        else:
            joiner = self._separator
        return self._merge_splits(pieces, joiner)
|
|
|
|
|
|
class LineType(TypedDict):
    """A piece of Markdown text plus the header metadata attached to it.

    Produced and consumed by ``MarkdownHeaderTextSplitter``.
    """

    # Maps a header name (from ``headers_to_split_on``) to that header's text.
    metadata: Dict[str, str]
    # The text itself; several source lines are joined with newlines.
    content: str
|
|
|
|
|
|
class HeaderType(TypedDict):
    """One entry of the header stack kept by ``MarkdownHeaderTextSplitter``."""

    # Number of "#" characters in the header marker.
    level: int
    # Metadata key configured for this header marker.
    name: str
    # Header text with the marker and surrounding whitespace removed.
    data: str
|
|
|
|
|
|
class MarkdownHeaderTextSplitter:
    """Implementation of splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: List[Tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: ``(marker, name)`` pairs, e.g. ``("##", "H2")``;
                ``name`` is the metadata key under which the header text is stored.
            return_each_line: If ``True``, return each block of lines with its
                headers instead of merging neighbours that share the same headers.
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Try the longest marker first so that "##" is not mistaken for "#".
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: List[LineType]) -> List[LineType]:
        """Combine consecutive lines with identical metadata into chunks.

        The entries of ``lines`` are never modified; merged content is built
        on fresh dictionaries.

        Args:
            lines: Lines of text with their associated header metadata.

        Returns:
            The aggregated chunks, in input order.
        """
        aggregated_chunks: List[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # Same headers as the previous chunk: extend that chunk.
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Start a new chunk on a copy, so that later merging into it
                # does not rewrite the caller's dictionary.
                aggregated_chunks.append(
                    {"content": line["content"], "metadata": line["metadata"]}
                )
        return aggregated_chunks

    def split_text(self, text: str) -> List[LineType]:
        """Split a markdown document along the configured headers.

        Args:
            text: Markdown document.

        Returns:
            Blocks of text, each with the headers that were in effect for it;
            merged by common headers unless ``return_each_line`` is set.
        """
        # Final output
        lines_with_metadata: List[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: List[str] = []
        current_metadata: Dict[str, str] = {}
        # Keep track of the nested header structure
        header_stack: List[HeaderType] = []
        initial_metadata: Dict[str, str] = {}

        for line in text.split("\n"):
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # A header is the marker alone or the marker followed by a space.
                if stripped_line.startswith(sep) and (
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        current_header_level = sep.count("#")

                        # A new header closes every open header of the same
                        # or a deeper level.
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            popped_header = header_stack.pop()
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])

                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep) :].strip(),
                        }
                        header_stack.append(header)
                        initial_metadata[name] = header["data"]

                    # Flush text gathered before this header; it still
                    # belongs to the previous set of headers.
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()

                    break
            else:
                # Not a header line.
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    # A blank line ends the current block of text.
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            # Headers seen on this line apply from the next line onwards.
            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return lines_with_metadata
|
|
|
|
|
|
# NOTE: on Python 3.10+ this could be declared as
# ``@dataclass(frozen=True, kw_only=True, slots=True)``.
@dataclass(frozen=True)
class Tokenizer:
    """Immutable bundle of the settings and callables used for token chunking."""

    # Number of tokens shared by two consecutive chunks.
    chunk_overlap: int
    # Maximum number of tokens in one chunk.
    tokens_per_chunk: int
    # Converts a list of token ids back into text.
    decode: Callable[[List[int]], str]
    # Converts text into a list of token ids.
    encode: Callable[[str], List[int]]
|
|
|
|
|
|
def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
    """Split ``text`` into overlapping chunks of tokens.

    The text is encoded once; windows of ``tokenizer.tokens_per_chunk`` token
    ids are then decoded, each window starting
    ``tokens_per_chunk - chunk_overlap`` tokens after the previous one.

    Args:
        text: The text to split.
        tokenizer: Chunk sizes plus the ``encode``/``decode`` callables.

    Returns:
        The decoded chunks in order; an empty list for text without tokens.

    Raises:
        ValueError: If there is something to split but ``tokens_per_chunk``
            is not greater than ``chunk_overlap``; the window could never
            advance in that case.
    """
    input_ids = tokenizer.encode(text)
    if len(input_ids) == 0:
        return []

    step = tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
    if step <= 0:
        # Without this guard the window start would never move forward and
        # the loop would append chunks forever.
        raise ValueError(
            f"tokens_per_chunk ({tokenizer.tokens_per_chunk}) must be greater "
            f"than chunk_overlap ({tokenizer.chunk_overlap})."
        )

    # Slicing clamps the window end to the end of the token list.
    return [
        tokenizer.decode(input_ids[start : start + tokenizer.tokens_per_chunk])
        for start in range(0, len(input_ids), step)
    ]
|
|
|
|
|
|
class TokenTextSplitter(TextSplitter):
    """Text splitter that measures and cuts text in tiktoken tokens."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Load the tiktoken encoding and remember the special-token policy.

        The encoding is looked up by ``model_name`` when given, otherwise by
        ``encoding_name``. Remaining keyword arguments go to ``TextSplitter``.

        Raises:
            ImportError: If ``tiktoken`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        self._tokenizer = (
            tiktoken.encoding_for_model(model_name)
            if model_name is not None
            else tiktoken.get_encoding(encoding_name)
        )
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks of at most ``chunk_size`` tokens."""

        def _to_token_ids(_text: str) -> List[int]:
            # Apply the special-token policy chosen at construction time.
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        settings = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_to_token_ids,
        )
        return split_text_on_tokens(text=text, tokenizer=settings)
|
|
|
|
|
|
class SentenceTransformersTokenTextSplitter(TextSplitter):
    """Text splitter that cuts on the tokens of a sentence-transformers model."""

    # ``max_length`` handed to the tokenizer in ``_encode``.
    _max_length_equal_32_bit_integer = 2**32

    def __init__(
        self,
        chunk_overlap: int = 50,
        model_name: str = "sentence-transformers/all-mpnet-base-v2",
        tokens_per_chunk: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        """Load the model and set up the per-chunk token budget.

        Raises:
            ImportError: If ``sentence_transformers`` is not installed.
            ValueError: If ``tokens_per_chunk`` exceeds the model's
                ``max_seq_length``.
        """
        super().__init__(**kwargs, chunk_overlap=chunk_overlap)

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "Could not import sentence_transformer python package. "
                "This is needed in order to for SentenceTransformersTokenTextSplitter. "
                "Please install it with `pip install sentence-transformers`."
            )

        self.model_name = model_name
        self._model = SentenceTransformer(self.model_name)
        self.tokenizer = self._model.tokenizer
        self._initialize_chunk_configuration(tokens_per_chunk=tokens_per_chunk)

    def _initialize_chunk_configuration(
        self, *, tokens_per_chunk: Optional[int]
    ) -> None:
        """Set ``tokens_per_chunk``, defaulting to the model's ``max_seq_length``."""
        self.maximum_tokens_per_chunk = cast(int, self._model.max_seq_length)

        self.tokens_per_chunk = (
            self.maximum_tokens_per_chunk
            if tokens_per_chunk is None
            else tokens_per_chunk
        )

        if self.tokens_per_chunk > self.maximum_tokens_per_chunk:
            raise ValueError(
                f"The token limit of the models '{self.model_name}'"
                f" is: {self.maximum_tokens_per_chunk}."
                f" Argument tokens_per_chunk={self.tokens_per_chunk}"
                f" > maximum token limit."
            )

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks of at most ``tokens_per_chunk`` tokens."""

        def _inner_token_ids(text: str) -> List[int]:
            # Drop the first and last ids (the start and stop tokens).
            return self._encode(text)[1:-1]

        settings = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self.tokens_per_chunk,
            decode=self.tokenizer.decode,
            encode=_inner_token_ids,
        )
        return split_text_on_tokens(text=text, tokenizer=settings)

    def count_tokens(self, *, text: str) -> int:
        """Return the number of token ids ``_encode`` produces for ``text``."""
        token_ids = self._encode(text)
        return len(token_ids)

    def _encode(self, text: str) -> List[int]:
        """Encode ``text`` without truncation, start/end token ids included."""
        return self.tokenizer.encode(
            text,
            max_length=self._max_length_equal_32_bit_integer,
            truncation="do_not_truncate",
        )
|
|
|
|
|
|
class Language(str, Enum):
    """Languages with separator presets in ``RecursiveCharacterTextSplitter``.

    Members are also plain strings (``str`` mixin).
    """

    CPP = "cpp"
    GO = "go"
    JAVA = "java"
    JS = "js"
    PHP = "php"
    PROTO = "proto"
    PYTHON = "python"
    RST = "rst"
    RUBY = "ruby"
    RUST = "rust"
    SCALA = "scala"
    SWIFT = "swift"
    MARKDOWN = "markdown"
    LATEX = "latex"
    HTML = "html"
    # Solidity
    SOL = "sol"
|
|
|
|
|
|
class RecursiveCharacterTextSplitter(TextSplitter):
    """Implementation of splitting text that looks at characters.

    Recursively tries to split by different separators, from the coarsest to
    the finest, until the pieces are small enough. Separators are interpreted
    as regular expressions.
    """

    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            separators: Separator patterns ordered from coarsest to finest;
                defaults to paragraph, line, word and character splitting.
            keep_separator: Whether matched separators stay in the chunks.
            **kwargs: Forwarded to ``TextSplitter``.
        """
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split ``text`` with the first separator that occurs in it.

        Pieces that are still too long are split again with the remaining,
        finer separators; short pieces are merged back into chunks.
        """
        final_chunks = []
        # Use the first separator found in the text; the last one is the
        # fallback when none is found.
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                # The empty separator always applies (character splitting)
                # and leaves nothing finer to recurse with.
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1 :]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                # Emit the short pieces gathered so far before handling the
                # oversized one, so the output order follows the text.
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    # Nothing finer to split with: keep the piece as it is.
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> List[str]:
        """Split ``text`` into chunks using the configured separators."""
        return self._split_text(text, self._separators)

    @classmethod
    def from_language(
        cls, language: Language, **kwargs: Any
    ) -> RecursiveCharacterTextSplitter:
        """Create a splitter preconfigured with the separators of ``language``."""
        separators = cls.get_separators_for_language(language)
        return cls(separators=separators, **kwargs)

    @staticmethod
    def get_separators_for_language(language: Language) -> List[str]:
        """Return the separator patterns for ``language``, coarsest first.

        Raises:
            ValueError: If ``language`` has no preset.
        """
        if language == Language.CPP:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along function definitions
                "\nvoid ", "\nint ", "\nfloat ", "\ndouble ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.GO:
            return [
                # Split along function definitions
                "\nfunc ", "\nvar ", "\nconst ", "\ntype ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.JAVA:
            return [
                # Split along class definitions
                "\nclass ",
                # Split along method definitions
                "\npublic ", "\nprotected ", "\nprivate ", "\nstatic ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.JS:
            return [
                # Split along function definitions
                "\nfunction ", "\nconst ", "\nlet ", "\nvar ", "\nclass ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nswitch ", "\ncase ", "\ndefault ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PHP:
            return [
                # Split along function definitions
                "\nfunction ",
                # Split along class definitions
                "\nclass ",
                # Split along control flow statements
                "\nif ", "\nforeach ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PROTO:
            return [
                # Split along message definitions
                "\nmessage ",
                # Split along service definitions
                "\nservice ",
                # Split along enum definitions
                "\nenum ",
                # Split along option definitions
                "\noption ",
                # Split along import statements
                "\nimport ",
                # Split along syntax declarations
                "\nsyntax ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.PYTHON:
            return [
                # First, try to split along class and function definitions
                "\nclass ", "\ndef ", "\n\tdef ",
                # Now split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RST:
            return [
                # Split along section titles
                "\n=+\n", "\n-+\n", "\n\\*+\n",
                # Split along directive markers
                "\n\n.. *\n\n",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RUBY:
            return [
                # Split along method definitions
                "\ndef ", "\nclass ",
                # Split along control flow statements
                "\nif ", "\nunless ", "\nwhile ", "\nfor ", "\ndo ", "\nbegin ",
                "\nrescue ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.RUST:
            return [
                # Split along function definitions
                "\nfn ", "\nconst ", "\nlet ",
                # Split along control flow statements
                "\nif ", "\nwhile ", "\nfor ", "\nloop ", "\nmatch ", "\nconst ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.SCALA:
            return [
                # Split along class definitions
                "\nclass ", "\nobject ",
                # Split along method definitions
                "\ndef ", "\nval ", "\nvar ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\nmatch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.SWIFT:
            return [
                # Split along function definitions
                "\nfunc ",
                # Split along class definitions
                "\nclass ", "\nstruct ", "\nenum ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\ndo ", "\nswitch ", "\ncase ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.MARKDOWN:
            return [
                # First, try to split along Markdown headings (levels 1-6).
                # The alternative "underlined" heading syntax is not handled:
                #   Heading level 2
                #   ---------------
                "\n#{1,6} ",
                # End of code block
                "```\n",
                # Horizontal lines made of three or more *, - or _ characters
                "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        elif language == Language.LATEX:
            # Each pattern needs a regex-escaped backslash, i.e. four
            # backslashes in a normal string literal. With only three,
            # "\\\b" ends in the backspace escape and can never match
            # "\begin{...}".
            return [
                # First, try to split along Latex sections
                "\n\\\\chapter{",
                "\n\\\\section{",
                "\n\\\\subsection{",
                "\n\\\\subsubsection{",
                # Now split by environments
                "\n\\\\begin{enumerate}",
                "\n\\\\begin{itemize}",
                "\n\\\\begin{description}",
                "\n\\\\begin{list}",
                "\n\\\\begin{quote}",
                "\n\\\\begin{quotation}",
                "\n\\\\begin{verse}",
                "\n\\\\begin{verbatim}",
                # Now split by math environments
                "\n\\\\begin{align}",
                # NOTE(review): where these two are used as regexes, "$" is
                # the end-of-string anchor rather than a literal dollar sign.
                # Left unchanged here; confirm the intended behaviour.
                "$$",
                "$",
                # Now split by the normal type of lines
                " ",
                "",
            ]
        elif language == Language.HTML:
            return [
                # First, try to split along HTML tags
                "<body", "<div", "<p", "<br", "<li",
                "<h1", "<h2", "<h3", "<h4", "<h5", "<h6",
                "<span", "<table", "<tr", "<td", "<th", "<ul", "<ol",
                "<header", "<footer", "<nav",
                # Head
                "<head", "<style", "<script", "<meta", "<title",
                "",
            ]
        elif language == Language.SOL:
            return [
                # Split along compiler information definitions
                "\npragma ", "\nusing ",
                # Split along contract definitions
                "\ncontract ", "\ninterface ", "\nlibrary ",
                # Split along method definitions
                "\nconstructor ", "\ntype ", "\nfunction ", "\nevent ",
                "\nmodifier ", "\nerror ", "\nstruct ", "\nenum ",
                # Split along control flow statements
                "\nif ", "\nfor ", "\nwhile ", "\ndo while ", "\nassembly ",
                # Split by the normal type of lines
                "\n\n", "\n", " ", "",
            ]
        else:
            raise ValueError(
                f"Language {language} is not supported! "
                f"Please choose from {list(Language)}"
            )
|
|
|
|
|
|
class NLTKTextSplitter(TextSplitter):
    """Text splitter that cuts text into sentences with NLTK."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Set up NLTK's ``sent_tokenize`` as the sentence tokenizer.

        Raises:
            ImportError: If ``nltk`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            from nltk.tokenize import sent_tokenize
        except ImportError:
            raise ImportError(
                "NLTK is not installed, please install it with `pip install nltk`."
            )
        self._tokenizer = sent_tokenize
        self._separator = separator

    def split_text(self, text: str) -> List[str]:
        """Tokenize ``text`` into sentences and merge them into chunks."""
        sentences = self._tokenizer(text)
        return self._merge_splits(sentences, self._separator)
|
|
|
|
|
|
class SpacyTextSplitter(TextSplitter):
    """Text splitter that cuts text into sentences with a spaCy pipeline."""

    def __init__(
        self, separator: str = "\n\n", pipeline: str = "en_core_web_sm", **kwargs: Any
    ) -> None:
        """Load the spaCy pipeline named ``pipeline``.

        Raises:
            ImportError: If ``spacy`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import spacy
        except ImportError:
            raise ImportError(
                "Spacy is not installed, please install it with `pip install spacy`."
            )
        self._separator = separator
        self._tokenizer = spacy.load(pipeline)

    def split_text(self, text: str) -> List[str]:
        """Run the pipeline on ``text`` and merge its sentences into chunks."""
        parsed = self._tokenizer(text)
        # Lazily convert each sentence span to plain text.
        sentences = (str(sentence) for sentence in parsed.sents)
        return self._merge_splits(sentences, self._separator)
|
|
|
|
|
|
# For backwards compatibility
|
|
class PythonCodeTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter preconfigured with the Python separators."""

    def __init__(self, **kwargs: Any) -> None:
        """Forward ``kwargs`` together with the ``Language.PYTHON`` separators."""
        super().__init__(
            separators=self.get_separators_for_language(Language.PYTHON),
            **kwargs,
        )
|
|
|
|
|
|
class MarkdownTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter preconfigured with the Markdown separators."""

    def __init__(self, **kwargs: Any) -> None:
        """Forward ``kwargs`` together with the ``Language.MARKDOWN`` separators."""
        super().__init__(
            separators=self.get_separators_for_language(Language.MARKDOWN),
            **kwargs,
        )
|
|
|
|
|
|
class LatexTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive splitter preconfigured with the LaTeX separators."""

    def __init__(self, **kwargs: Any) -> None:
        """Forward ``kwargs`` together with the ``Language.LATEX`` separators."""
        super().__init__(
            separators=self.get_separators_for_language(Language.LATEX),
            **kwargs,
        )
|