diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 71f656e8..eb818436 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -3,7 +3,17 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, List, Optional
+from typing import (
+    AbstractSet,
+    Any,
+    Callable,
+    Collection,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+)
 
 from langchain.docstore.document import Document
 
@@ -114,7 +124,11 @@ class TextSplitter(ABC):
 
     @classmethod
     def from_tiktoken_encoder(
-        cls, encoding_name: str = "gpt2", **kwargs: Any
+        cls,
+        encoding_name: str = "gpt2",
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
     ) -> TextSplitter:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
@@ -125,11 +139,19 @@ class TextSplitter(ABC):
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please it install it with `pip install tiktoken`."
             )
+        # create a GPT-3 encoder instance
         enc = tiktoken.get_encoding(encoding_name)
 
-        def _tiktoken_encoder(text: str) -> int:
-            return len(enc.encode(text))
+        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+            return len(
+                enc.encode(
+                    text,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                    **kwargs,
+                )
+            )
 
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
 
@@ -155,7 +177,13 @@ class CharacterTextSplitter(TextSplitter):
 class TokenTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at tokens."""
 
-    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
+    def __init__(
+        self,
+        encoding_name: str = "gpt2",
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
+    ):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         try:
@@ -168,11 +196,17 @@ class TokenTextSplitter(TextSplitter):
             )
         # create a GPT-3 encoder instance
         self._tokenizer = tiktoken.get_encoding(encoding_name)
+        self._allowed_special = allowed_special
+        self._disallowed_special = disallowed_special
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         splits = []
-        input_ids = self._tokenizer.encode(text)
+        input_ids = self._tokenizer.encode(
+            text,
+            allowed_special=self._allowed_special,
+            disallowed_special=self._disallowed_special,
+        )
         start_idx = 0
         cur_idx = min(start_idx + self._chunk_size, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
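
For context, a minimal usage sketch of the new `allowed_special` / `disallowed_special` parameters. It assumes the base `TextSplitter` constructor arguments such as `chunk_size` (not shown in this diff) and uses `"<|endoftext|>"` purely as an illustrative special token; with tiktoken's defaults (`allowed_special=set()`, `disallowed_special="all"`), encoding text that contains such a token raises a ValueError, which these parameters now let callers control.

    from langchain.text_splitter import TokenTextSplitter

    # Illustrative input containing a tiktoken special token.
    text = "Some document text ... <|endoftext|> ... more text"

    # Allow the special token explicitly so it is encoded as its special id
    # instead of raising a ValueError.
    splitter = TokenTextSplitter(
        encoding_name="gpt2",
        chunk_size=100,  # assumed base-class argument, value is illustrative
        allowed_special={"<|endoftext|>"},
    )
    chunks = splitter.split_text(text)

    # Alternatively, disable the check entirely: with no disallowed tokens,
    # the special-token text is encoded as ordinary text.
    lenient_splitter = TokenTextSplitter(
        encoding_name="gpt2",
        chunk_size=100,
        disallowed_special=(),
    )
    chunks = lenient_splitter.split_text(text)

The defaults chosen in the diff mirror tiktoken's own `encode` defaults, so existing callers keep the safe behaviour of rejecting unexpected special tokens unless they opt out.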