From 30595bc9c65c1df2b03737a0c837052405c72d7b Mon Sep 17 00:00:00 2001 From: ChungHwan Han <51526347+hanchchch@users.noreply.github.com> Date: Fri, 7 Apr 2023 20:52:55 +0900 Subject: [PATCH] feat: model_name as an env, and check accessability to the model (#22) --- README.md | 4 ++++ core/agents/builder.py | 1 + core/agents/llm.py | 32 +++++++++++++++++++++++++++++++- env.py | 1 + 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6374bb9..98f2576 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,13 @@ https://user-images.githubusercontent.com/51526347/230061897-b3479405-8ebd-45ab- EVAL Making a UI for itself #### [EVAL-BOT](https://github.com/eval-bot) + EVAL's self-managed github account. EVAL does everything except for signup and bio setting. ### Examples + [Here](examples/) is an example. + ### EVAL's FEATURE 1. **Multimodal Conversation** @@ -81,6 +84,7 @@ Each optional env has default value, so you don't need to set unless you want to - `SERVER` - server address (default: http://localhost:8000) - `LOG_LEVEL` - INFO | DEBUG (default: INFO) - `BOT_NAME` - give it a name! (default: Orca) +- `MODEL_NAME` - model name for GPT (default: gpt-4) **For More Tools** diff --git a/core/agents/builder.py b/core/agents/builder.py index cf7c5fc..1f45aee 100644 --- a/core/agents/builder.py +++ b/core/agents/builder.py @@ -19,6 +19,7 @@ class AgentBuilder: def build_llm(self): self.llm = ChatOpenAI(temperature=0) + self.llm.check_access() def build_parser(self): self.parser = EvalOutputParser() diff --git a/core/agents/llm.py b/core/agents/llm.py index 040c1a0..1fd6ab1 100644 --- a/core/agents/llm.py +++ b/core/agents/llm.py @@ -5,6 +5,8 @@ import logging import sys from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple +import openai + from langchain.chat_models.base import BaseChatModel from langchain.schema import ( AIMessage, @@ -26,6 +28,9 @@ from tenacity import ( wait_exponential, ) +from env import settings +from ansi import ANSI, Color, Style + def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]: import openai @@ -98,6 +103,23 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult: return ChatResult(generations=generations) +class ModelNotFoundException(Exception): + """Exception raised when the model is not found.""" + + def __init__(self, model_name: str): + self.model_name = model_name + super().__init__( + f"\n\nModel {ANSI(self.model_name).to(Color.red())} does not exist.\nMake sure if you have access to the model.\n" + + f"You can the model name with the environment variable {ANSI('MODEL_NAME').to(Style.bold())} on {ANSI('.env').to(Style.bold())}.\n" + + "\nex) MODEL_NAME=gpt-4\n" + + ANSI( + "\nLooks like you don't have access to gpt-4 yet. Try using `gpt-3.5-turbo`." + if self.model_name == "gpt-4" + else "" + ).to(Style.italic()) + ) + + class ChatOpenAI(BaseChatModel, BaseModel): """Wrapper around OpenAI Chat large language models. @@ -115,7 +137,7 @@ class ChatOpenAI(BaseChatModel, BaseModel): """ client: Any #: :meta private: - model_name: str = "gpt-4" + model_name: str = settings["MODEL_NAME"] """Model name to use.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" @@ -134,6 +156,14 @@ class ChatOpenAI(BaseChatModel, BaseModel): extra = Extra.ignore + def check_access(self) -> None: + """Check that the user has access to the model.""" + + try: + openai.Engine.retrieve(self.model_name) + except openai.error.InvalidRequestError: + raise ModelNotFoundException(self.model_name) + @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" diff --git a/env.py b/env.py index 071ecef..061c286 100644 --- a/env.py +++ b/env.py @@ -30,6 +30,7 @@ class DotEnv(TypedDict): EVAL_PORT = int(os.getenv("EVAL_PORT", 8000)) settings: DotEnv = { "EVAL_PORT": EVAL_PORT, + "MODEL_NAME": os.getenv("MODEL_NAME", "gpt-4"), "SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"), "USE_GPU": os.getenv("USE_GPU", "False").lower() == "true", "PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"),