|
|
|
@ -5,6 +5,8 @@ import logging
|
|
|
|
|
import sys
|
|
|
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
|
|
|
|
|
|
|
|
|
import openai
|
|
|
|
|
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel
|
|
|
|
|
from langchain.schema import (
|
|
|
|
|
AIMessage,
|
|
|
|
@ -26,6 +28,9 @@ from tenacity import (
|
|
|
|
|
wait_exponential,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
from env import settings
|
|
|
|
|
from ansi import ANSI, Color, Style
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
|
|
|
|
|
import openai
|
|
|
|
@ -98,6 +103,23 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
|
|
|
|
return ChatResult(generations=generations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ModelNotFoundException(Exception):
|
|
|
|
|
"""Exception raised when the model is not found."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, model_name: str):
|
|
|
|
|
self.model_name = model_name
|
|
|
|
|
super().__init__(
|
|
|
|
|
f"\n\nModel {ANSI(self.model_name).to(Color.red())} does not exist.\nMake sure if you have access to the model.\n"
|
|
|
|
|
+ f"You can the model name with the environment variable {ANSI('MODEL_NAME').to(Style.bold())} on {ANSI('.env').to(Style.bold())}.\n"
|
|
|
|
|
+ "\nex) MODEL_NAME=gpt-4\n"
|
|
|
|
|
+ ANSI(
|
|
|
|
|
"\nLooks like you don't have access to gpt-4 yet. Try using `gpt-3.5-turbo`."
|
|
|
|
|
if self.model_name == "gpt-4"
|
|
|
|
|
else ""
|
|
|
|
|
).to(Style.italic())
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
"""Wrapper around OpenAI Chat large language models.
|
|
|
|
|
|
|
|
|
@ -115,7 +137,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
client: Any #: :meta private:
|
|
|
|
|
model_name: str = "gpt-4"
|
|
|
|
|
model_name: str = settings["MODEL_NAME"]
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
|
|
|
@ -134,6 +156,14 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|
|
|
|
|
|
|
|
|
extra = Extra.ignore
|
|
|
|
|
|
|
|
|
|
def check_access(self) -> None:
|
|
|
|
|
"""Check that the user has access to the model."""
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
openai.Engine.retrieve(self.model_name)
|
|
|
|
|
except openai.error.InvalidRequestError:
|
|
|
|
|
raise ModelNotFoundException(self.model_name)
|
|
|
|
|
|
|
|
|
|
@root_validator(pre=True)
|
|
|
|
|
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
|
|
|
|
"""Build extra kwargs from additional params that were passed in."""
|
|
|
|
|