forked from Archives/langchain
c7db9febb0
Trying to call `ChatOpenAI.get_num_tokens_from_messages` returns the following error for the newly announced models `gpt-3.5-turbo-0613` and `gpt-4-0613`: ``` NotImplementedError: get_num_tokens_from_messages() is not presently implemented for model gpt-3.5-turbo-0613.See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens. ``` This adds support for counting tokens for those models, by counting tokens the same way they're counted for the previous versions of `gpt-3.5-turbo` and `gpt-4`. #### reviewers - @hwchase17 - @agola11
495 lines
19 KiB
Python
495 lines
19 KiB
Python
"""OpenAI chat wrapper."""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sys
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
|
|
from pydantic import Extra, Field, root_validator
|
|
from tenacity import (
|
|
before_sleep_log,
|
|
retry,
|
|
retry_if_exception_type,
|
|
stop_after_attempt,
|
|
wait_exponential,
|
|
)
|
|
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain.chat_models.base import BaseChatModel
|
|
from langchain.schema import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
ChatGeneration,
|
|
ChatMessage,
|
|
ChatResult,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
)
|
|
from langchain.utils import get_from_dict_or_env
|
|
|
|
if TYPE_CHECKING:
|
|
import tiktoken
|
|
|
|
# Module-level logger: used for the model_kwargs transfer warning, the tiktoken
# fallback warning, and handed to tenacity's ``before_sleep_log`` (WARNING level)
# by the retry decorators.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _import_tiktoken() -> Any:
    """Import and return the ``tiktoken`` module.

    The import is deferred to call time so ``tiktoken`` is only required when
    token ids / token counts are actually requested.

    Returns:
        The imported ``tiktoken`` module.

    Raises:
        ValueError: If ``tiktoken`` cannot be imported. The underlying
            ``ImportError`` is chained as ``__cause__`` for easier debugging.
    """
    try:
        import tiktoken
    except ImportError as e:
        raise ValueError(
            "Could not import tiktoken python package. "
            "This is needed in order to calculate get_token_ids. "
            "Please install it with `pip install tiktoken`."
        ) from e
    return tiktoken
|
|
|
|
|
|
def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
    """Build the tenacity retry decorator used for OpenAI chat completion calls.

    The decorator retries on transient OpenAI errors (timeouts, generic API
    errors, connection errors, rate limits and service-unavailable responses),
    stops after ``llm.max_retries`` attempts, and re-raises the last exception.

    Args:
        llm: The chat model whose ``max_retries`` bounds the number of attempts.

    Returns:
        A decorator that wraps a (sync or async) callable with the retry policy.
    """
    import openai

    min_seconds = 1
    max_seconds = 60
    # Exponential backoff: the wait roughly doubles after each failed attempt
    # and is clamped to the [min_seconds, max_seconds] range (1s .. 60s).
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        # A single isinstance check against a tuple is equivalent to OR-ing one
        # retry_if_exception_type predicate per error class.
        retry=retry_if_exception_type(
            (
                openai.error.Timeout,
                openai.error.APIError,
                openai.error.APIConnectionError,
                openai.error.RateLimitError,
                openai.error.ServiceUnavailableError,
            )
        ),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
|
|
|
|
|
|
async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
    """Await ``llm.client.acreate(**kwargs)`` under the shared retry policy.

    Args:
        llm: Chat model supplying the async ``client`` and the retry settings.
        **kwargs: Forwarded unchanged to ``llm.client.acreate``.

    Returns:
        Whatever ``llm.client.acreate`` resolves to.
    """
    with_retry = _create_retry_decorator(llm)

    async def _acreate_once(**call_kwargs: Any) -> Any:
        # OpenAI's async API: https://github.com/openai/openai-python#async-api
        return await llm.client.acreate(**call_kwargs)

    retrying_acreate = with_retry(_acreate_once)
    return await retrying_acreate(**kwargs)
|
|
|
|
|
|
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Turn an OpenAI chat message mapping into a LangChain message.

    ``"user"`` maps to ``HumanMessage``, ``"assistant"`` to ``AIMessage``
    (carrying any ``function_call`` in ``additional_kwargs``), ``"system"`` to
    ``SystemMessage``; every other role becomes a ``ChatMessage`` with that role.
    """
    role = _dict["role"]
    if role == "assistant":
        # OpenAI returns None for tool invocations, so coerce to "".
        text = _dict["content"] or ""
        extras = (
            {"function_call": dict(_dict["function_call"])}
            if _dict.get("function_call")
            else {}
        )
        return AIMessage(content=text, additional_kwargs=extras)

    text = _dict["content"]
    if role == "user":
        return HumanMessage(content=text)
    if role == "system":
        return SystemMessage(content=text)
    return ChatMessage(content=text, role=role)
|
|
|
|
|
|
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialize a LangChain message into an OpenAI chat message dict.

    The dict always has ``role`` and ``content``; AI messages may add
    ``function_call``, function messages add ``name``, and a ``name`` in
    ``additional_kwargs`` (if any) is copied over last.

    Raises:
        ValueError: If the message type is not one of the supported classes.
    """
    extras: Dict[str, Any] = {}
    if isinstance(message, ChatMessage):
        role = message.role
    elif isinstance(message, HumanMessage):
        role = "user"
    elif isinstance(message, AIMessage):
        role = "assistant"
        if "function_call" in message.additional_kwargs:
            extras["function_call"] = message.additional_kwargs["function_call"]
    elif isinstance(message, SystemMessage):
        role = "system"
    elif isinstance(message, FunctionMessage):
        role = "function"
        extras["name"] = message.name
    else:
        raise ValueError(f"Got unknown type {message}")

    message_dict = {"role": role, "content": message.content, **extras}
    if "name" in message.additional_kwargs:
        message_dict["name"] = message.additional_kwargs["name"]
    return message_dict
|
|
|
|
|
|
class ChatOpenAI(BaseChatModel):
    """Wrapper around OpenAI Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatOpenAI
            openai = ChatOpenAI(model_name="gpt-3.5-turbo")
    """

    @property
    def lc_serializable(self) -> bool:
        return True

    client: Any  #: :meta private:
    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    openai_api_key: Optional[str] = None
    """OpenAI API key; read from the ``OPENAI_API_KEY`` env var if not provided."""
    openai_api_base: Optional[str] = None
    """Base URL path for API requests,
    leave blank if not using a proxy or service emulator."""
    openai_organization: Optional[str] = None
    """OpenAI organization; read from ``OPENAI_ORGANIZATION`` if not provided."""
    # to support explicit proxy for OpenAI
    openai_proxy: Optional[str] = None
    """Proxy used for both http and https requests to OpenAI; read from
    ``OPENAI_PROXY`` if not provided."""
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Whether to stream the results or not."""
    n: int = 1
    """Number of chat completions to generate for each prompt."""
    max_tokens: Optional[int] = None
    """Maximum number of tokens to generate."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Unknown constructor arguments are moved into ``model_kwargs`` (with a
        warning).

        Raises:
            ValueError: If a parameter is supplied both directly and in
                ``model_kwargs``, or a declared field is passed via
                ``model_kwargs``.
        """
        all_required_field_names = cls.all_required_field_names()
        # Copy so the caller's own model_kwargs dict is not mutated in place.
        extra = dict(values.get("model_kwargs", {}))
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the API key, organization, base URL and proxy from the given
        values or their ``OPENAI_*`` environment variables, binds
        ``openai.ChatCompletion`` as ``client`` and checks ``n``/``streaming``.

        Raises:
            ValueError: If ``openai`` is missing or too old, ``n < 1``, or
                ``n > 1`` while streaming.
        """
        values["openai_api_key"] = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        values["openai_organization"] = get_from_dict_or_env(
            values,
            "openai_organization",
            "OPENAI_ORGANIZATION",
            default="",
        )
        values["openai_api_base"] = get_from_dict_or_env(
            values,
            "openai_api_base",
            "OPENAI_API_BASE",
            default="",
        )
        values["openai_proxy"] = get_from_dict_or_env(
            values,
            "openai_proxy",
            "OPENAI_PROXY",
            default="",
        )
        try:
            import openai

        except ImportError as e:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from e
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError as e:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            ) from e
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            "model": self.model_name,
            "request_timeout": self.request_timeout,
            "max_tokens": self.max_tokens,
            "stream": self.streaming,
            "n": self.n,
            "temperature": self.temperature,
            **self.model_kwargs,
        }

    def _create_retry_decorator(self) -> Callable[[Any], Any]:
        """Return the retry decorator used by ``completion_with_retry``."""
        # Delegate to the module-level helper (method bodies resolve this name
        # to the module global) so the sync and async paths share one policy.
        return _create_retry_decorator(self)

    def completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = self._create_retry_decorator()

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        """Sum ``token_usage`` counters across several generation outputs."""
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model_name}

    @staticmethod
    def _merge_function_call_chunk(
        function_call: Optional[dict], chunk: Mapping[str, Any]
    ) -> dict:
        """Fold one streamed ``function_call`` delta into the accumulated call.

        The first chunk is copied (so the streamed response object is not
        mutated); later chunks append their ``arguments`` text.
        """
        if function_call is None:
            merged = dict(chunk)
            merged.setdefault("arguments", "")
            return merged
        function_call["arguments"] += chunk.get("arguments") or ""
        return function_call

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Run a chat completion, streaming tokens to ``run_manager`` if enabled."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            for stream_resp in self.completion_with_retry(
                messages=message_dicts, **params
            ):
                delta = stream_resp["choices"][0]["delta"]
                role = delta.get("role", role)
                # content may be None (e.g. on chunks carrying a function_call)
                token = delta.get("content") or ""
                inner_completion += token
                function_call_chunk = delta.get("function_call")
                if function_call_chunk:
                    function_call = self._merge_function_call_chunk(
                        function_call, function_call_chunk
                    )
                if run_manager:
                    run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        response = self.completion_with_retry(messages=message_dicts, **params)
        return self._create_chat_result(response)

    def _create_message_dicts(
        self, messages: List[BaseMessage], stop: Optional[List[str]]
    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
        """Convert messages to OpenAI dicts and build the request params.

        Raises:
            ValueError: If ``stop`` is given and the params already contain one.
        """
        params = dict(self._invocation_params)
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop
        message_dicts = [_convert_message_to_dict(m) for m in messages]
        return message_dicts, params

    def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
        """Wrap every returned choice in a ``ChatGeneration`` plus usage info."""
        generations = []
        for res in response["choices"]:
            message = _convert_dict_to_message(res["message"])
            gen = ChatGeneration(message=message)
            generations.append(gen)
        llm_output = {"token_usage": response["usage"], "model_name": self.model_name}
        return ChatResult(generations=generations, llm_output=llm_output)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate``; handles streaming the same way."""
        message_dicts, params = self._create_message_dicts(messages, stop)
        params = {**params, **kwargs}
        if self.streaming:
            inner_completion = ""
            role = "assistant"
            params["stream"] = True
            function_call: Optional[dict] = None
            async for stream_resp in await acompletion_with_retry(
                self, messages=message_dicts, **params
            ):
                delta = stream_resp["choices"][0]["delta"]
                role = delta.get("role", role)
                # content may be None (e.g. on chunks carrying a function_call);
                # .get("content", "") would then make `+=` raise TypeError.
                token = delta.get("content") or ""
                inner_completion += token
                function_call_chunk = delta.get("function_call")
                if function_call_chunk:
                    function_call = self._merge_function_call_chunk(
                        function_call, function_call_chunk
                    )
                if run_manager:
                    await run_manager.on_llm_new_token(token)
            message = _convert_dict_to_message(
                {
                    "content": inner_completion,
                    "role": role,
                    "function_call": function_call,
                }
            )
            return ChatResult(generations=[ChatGeneration(message=message)])
        else:
            response = await acompletion_with_retry(
                self, messages=message_dicts, **params
            )
            return self._create_chat_result(response)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _invocation_params(self) -> Mapping[str, Any]:
        """Get the parameters used to invoke the model.

        Note: when ``openai_proxy`` is set this also assigns the module-global
        ``openai.proxy``, which affects every user of the ``openai`` package in
        this process.
        """
        openai_creds: Dict[str, Any] = {
            "api_key": self.openai_api_key,
            "api_base": self.openai_api_base,
            "organization": self.openai_organization,
            "model": self.model_name,
        }
        if self.openai_proxy:
            import openai

            openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy}  # type: ignore[assignment] # noqa: E501
        return {**openai_creds, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "openai-chat"

    def _get_encoding_model(self) -> Tuple[str, tiktoken.Encoding]:
        """Resolve the model name used for token counting and its encoding.

        Undated ``gpt-3.5-turbo`` / ``gpt-4`` are pinned to a dated snapshot.

        Returns:
            ``(model, encoding)``; ``model`` is ``"cl100k_base"`` when tiktoken
            does not recognise the model name (a warning is logged).
        """
        tiktoken_ = _import_tiktoken()
        model = self.model_name
        if model == "gpt-3.5-turbo":
            # gpt-3.5-turbo may change over time.
            # Returning num tokens assuming gpt-3.5-turbo-0301.
            model = "gpt-3.5-turbo-0301"
        elif model == "gpt-4":
            # gpt-4 may change over time.
            # Returning num tokens assuming gpt-4-0314.
            model = "gpt-4-0314"
        # Look up the tokenizer for the resolved model name.
        try:
            encoding = tiktoken_.encoding_for_model(model)
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)
        return model, encoding

    def get_token_ids(self, text: str) -> List[int]:
        """Get the tokens present in the text with tiktoken package."""
        # tiktoken NOT supported for Python 3.7 or below
        if sys.version_info < (3, 8):
            return super().get_token_ids(text)
        _, encoding_model = self._get_encoding_model()
        return encoding_model.encode(text)

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
        See also: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_count_tokens_with_tiktoken.ipynb

        Raises:
            NotImplementedError: If the resolved model is not a gpt-3.5-turbo
                or gpt-4 variant.
        """
        if sys.version_info < (3, 8):
            return super().get_num_tokens_from_messages(messages)
        model, encoding = self._get_encoding_model()
        if model.startswith("gpt-3.5-turbo-0301"):
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            tokens_per_message = 4
            # if there's a name, the role is omitted
            tokens_per_name = -1
        elif model.startswith(("gpt-3.5-turbo", "gpt-4")):
            # gpt-3.5-turbo-0613 and later use the same framing as gpt-4.
            tokens_per_message = 3
            tokens_per_name = 1
        else:
            raise NotImplementedError(
                f"get_num_tokens_from_messages() is not presently implemented "
                f"for model {model}. "
                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
                "information on how messages are converted to tokens."
            )
        num_tokens = 0
        messages_dict = [_convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Values are usually str, but AI messages can carry a
                # ``function_call`` dict; str() gives an approximate count
                # instead of encode() raising TypeError.
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens
|