From d54c88aa2140f27c36fa18375f942e5b239799ee Mon Sep 17 00:00:00 2001 From: Carmen Sam <53148658+samcarmen@users.noreply.github.com> Date: Wed, 19 Apr 2023 00:34:08 +0800 Subject: [PATCH] Add allowed and disallowed special arguments to BaseOpenAI (#3012) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Background This PR fixes this error when there are special tokens when querying the chain: ``` Encountered text corresponding to disallowed special token '<|endofprompt|>'. If you want this text to be encoded as a special token, pass it to `allowed_special`, e.g. `allowed_special={'<|endofprompt|>', ...}`. If you want this text to be encoded as normal text, disable the check for this token by passing `disallowed_special=(enc.special_tokens_set - {'<|endofprompt|>'})`. To disable this check for all special tokens, pass `disallowed_special=()`. ``` Refer to the code snippet below, it breaks in the chain line. ``` chain = ConversationalRetrievalChain.from_llm( ChatOpenAI(openai_api_key=OPENAI_API_KEY), retriever=vectorstore.as_retriever(), qa_prompt=prompt, condense_question_prompt=condense_prompt, ) answer = chain({"question": f"{question}"}) ``` However `ChatOpenAI` class is not accepting `allowed_special` and `disallowed_special` at the moment so they cannot be passed to the `encode()` in `get_num_tokens` method to avoid the errors. ## Change - Add `allowed_special` and `disallowed_special` attributes to `BaseOpenAI` class. - Pass in `allowed_special` and `disallowed_special` as arguments of `encode()` in tiktoken. 
--------- Co-authored-by: samcarmen <carmen.samkahman@gmail.com> --- langchain/llms/openai.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index f5dbafec..b98a7f5d 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -5,11 +5,14 @@ import logging import sys import warnings from typing import ( + AbstractSet, Any, Callable, + Collection, Dict, Generator, List, + Literal, Mapping, Optional, Set, @@ -150,6 +153,10 @@ class BaseOpenAI(BaseLLM): """Maximum number of retries to make when generating.""" streaming: bool = False """Whether to stream the results or not.""" + allowed_special: Union[Literal["all"], AbstractSet[str]] = set() + """Set of special tokens that are allowed.""" + disallowed_special: Union[Literal["all"], Collection[str]] = "all" + """Set of special tokens that are not allowed.""" def __new__(cls, **data: Any) -> Union[OpenAIChat, BaseOpenAI]: # type: ignore """Initialize the OpenAI object.""" @@ -449,7 +456,11 @@ class BaseOpenAI(BaseLLM): enc = tiktoken.encoding_for_model(self.model_name) - tokenized_text = enc.encode(text) + tokenized_text = enc.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) # calculate the number of tokens in the encoded text return len(tokenized_text) @@ -602,6 +613,10 @@ class OpenAIChat(BaseLLM): """Series of messages for Chat input.""" streaming: bool = False """Whether to stream the results or not.""" + allowed_special: Union[Literal["all"], AbstractSet[str]] = set() + """Set of special tokens that are allowed.""" + disallowed_special: Union[Literal["all"], Collection[str]] = "all" + """Set of special tokens that are not allowed.""" class Config: """Configuration for this pydantic object.""" @@ -785,7 +800,11 @@ class OpenAIChat(BaseLLM): enc = tiktoken.encoding_for_model("gpt-3.5-turbo") # encode the text using the GPT-3.5-Turbo encoder - 
tokenized_text = enc.encode(text) + tokenized_text = enc.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) # calculate the number of tokens in the encoded text return len(tokenized_text)