diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index f5dbafec..b98a7f5d 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -5,11 +5,14 @@ import logging import sys import warnings from typing import ( + AbstractSet, Any, Callable, + Collection, Dict, Generator, List, + Literal, Mapping, Optional, Set, @@ -150,6 +153,10 @@ class BaseOpenAI(BaseLLM): """Maximum number of retries to make when generating.""" streaming: bool = False """Whether to stream the results or not.""" + allowed_special: Union[Literal["all"], AbstractSet[str]] = set() + """Set of special tokens that are allowed.""" + disallowed_special: Union[Literal["all"], Collection[str]] = "all" + """Set of special tokens that are not allowed.""" def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore """Initialize the OpenAI object.""" @@ -449,7 +456,11 @@ class BaseOpenAI(BaseLLM): enc = tiktoken.encoding_for_model(self.model_name) - tokenized_text = enc.encode(text) + tokenized_text = enc.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) # calculate the number of tokens in the encoded text return len(tokenized_text) @@ -602,6 +613,10 @@ class OpenAIChat(BaseLLM): """Series of messages for Chat input.""" streaming: bool = False """Whether to stream the results or not.""" + allowed_special: Union[Literal["all"], AbstractSet[str]] = set() + """Set of special tokens that are allowed.""" + disallowed_special: Union[Literal["all"], Collection[str]] = "all" + """Set of special tokens that are not allowed.""" class Config: """Configuration for this pydantic object.""" @@ -785,7 +800,11 @@ class OpenAIChat(BaseLLM): enc = tiktoken.encoding_for_model("gpt-3.5-turbo") # encode the text using the GPT-3.5-Turbo encoder - tokenized_text = enc.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) # calculate the number of tokens in the encoded text return len(tokenized_text)