|
|
|
@ -3,7 +3,17 @@ from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
|
from typing import Any, Callable, Iterable, List, Optional
|
|
|
|
|
from typing import (
|
|
|
|
|
AbstractSet,
|
|
|
|
|
Any,
|
|
|
|
|
Callable,
|
|
|
|
|
Collection,
|
|
|
|
|
Iterable,
|
|
|
|
|
List,
|
|
|
|
|
Literal,
|
|
|
|
|
Optional,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
|
|
|
@ -114,7 +124,11 @@ class TextSplitter(ABC):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_tiktoken_encoder(
|
|
|
|
|
cls, encoding_name: str = "gpt2", **kwargs: Any
|
|
|
|
|
cls,
|
|
|
|
|
encoding_name: str = "gpt2",
|
|
|
|
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
|
|
|
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> TextSplitter:
|
|
|
|
|
"""Text splitter that uses tiktoken encoder to count length."""
|
|
|
|
|
try:
|
|
|
|
@ -125,11 +139,19 @@ class TextSplitter(ABC):
|
|
|
|
|
"This is needed in order to calculate max_tokens_for_prompt. "
|
|
|
|
|
"Please it install it with `pip install tiktoken`."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# create an encoder instance for the requested encoding_name (defaults to "gpt2")
|
|
|
|
|
enc = tiktoken.get_encoding(encoding_name)
|
|
|
|
|
|
|
|
|
|
def _tiktoken_encoder(text: str) -> int:
|
|
|
|
|
return len(enc.encode(text))
|
|
|
|
|
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
|
|
|
|
return len(
|
|
|
|
|
enc.encode(
|
|
|
|
|
text,
|
|
|
|
|
allowed_special=allowed_special,
|
|
|
|
|
disallowed_special=disallowed_special,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return cls(length_function=_tiktoken_encoder, **kwargs)
|
|
|
|
|
|
|
|
|
@ -155,7 +177,13 @@ class CharacterTextSplitter(TextSplitter):
|
|
|
|
|
class TokenTextSplitter(TextSplitter):
|
|
|
|
|
"""Implementation of splitting text that looks at tokens."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
encoding_name: str = "gpt2",
|
|
|
|
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
|
|
|
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
):
|
|
|
|
|
"""Create a new TextSplitter."""
|
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
|
try:
|
|
|
|
@ -168,11 +196,17 @@ class TokenTextSplitter(TextSplitter):
|
|
|
|
|
)
|
|
|
|
|
# create an encoder instance for the requested encoding_name (defaults to "gpt2")
|
|
|
|
|
self._tokenizer = tiktoken.get_encoding(encoding_name)
|
|
|
|
|
self._allowed_special = allowed_special
|
|
|
|
|
self._disallowed_special = disallowed_special
|
|
|
|
|
|
|
|
|
|
def split_text(self, text: str) -> List[str]:
|
|
|
|
|
"""Split incoming text and return chunks."""
|
|
|
|
|
splits = []
|
|
|
|
|
input_ids = self._tokenizer.encode(text)
|
|
|
|
|
input_ids = self._tokenizer.encode(
|
|
|
|
|
text,
|
|
|
|
|
allowed_special=self._allowed_special,
|
|
|
|
|
disallowed_special=self._disallowed_special,
|
|
|
|
|
)
|
|
|
|
|
start_idx = 0
|
|
|
|
|
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
|
|
|
|
chunk_ids = input_ids[start_idx:cur_idx]
|
|
|
|
|