|
|
|
@ -14,6 +14,8 @@ from typing import (
|
|
|
|
|
Literal,
|
|
|
|
|
Optional,
|
|
|
|
|
Sequence,
|
|
|
|
|
Type,
|
|
|
|
|
TypeVar,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -22,6 +24,8 @@ from langchain.schema import BaseDocumentTransformer
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
TS = TypeVar("TS", bound="TextSplitter")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
"""Interface for splitting text into chunks."""
|
|
|
|
@ -139,13 +143,13 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_tiktoken_encoder(
|
|
|
|
|
cls,
|
|
|
|
|
cls: Type[TS],
|
|
|
|
|
encoding_name: str = "gpt2",
|
|
|
|
|
model_name: Optional[str] = None,
|
|
|
|
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
|
|
|
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> TextSplitter:
|
|
|
|
|
) -> TS:
|
|
|
|
|
"""Text splitter that uses tiktoken encoder to count length."""
|
|
|
|
|
try:
|
|
|
|
|
import tiktoken
|
|
|
|
@ -161,16 +165,24 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|
|
|
|
else:
|
|
|
|
|
enc = tiktoken.get_encoding(encoding_name)
|
|
|
|
|
|
|
|
|
|
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
|
|
|
|
def _tiktoken_encoder(text: str) -> int:
|
|
|
|
|
return len(
|
|
|
|
|
enc.encode(
|
|
|
|
|
text,
|
|
|
|
|
allowed_special=allowed_special,
|
|
|
|
|
disallowed_special=disallowed_special,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if issubclass(cls, TokenTextSplitter):
|
|
|
|
|
extra_kwargs = {
|
|
|
|
|
"encoding_name": encoding_name,
|
|
|
|
|
"model_name": model_name,
|
|
|
|
|
"allowed_special": allowed_special,
|
|
|
|
|
"disallowed_special": disallowed_special,
|
|
|
|
|
}
|
|
|
|
|
kwargs = {**kwargs, **extra_kwargs}
|
|
|
|
|
|
|
|
|
|
return cls(length_function=_tiktoken_encoder, **kwargs)
|
|
|
|
|
|
|
|
|
|
def transform_documents(
|
|
|
|
|