feat: model_name as an env, and check accessability to the model (#22)

This commit is contained in:
ChungHwan Han 2023-04-07 20:52:55 +09:00 committed by GitHub
parent d7e537dab7
commit 30595bc9c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 1 deletions

View File

@ -13,10 +13,13 @@ https://user-images.githubusercontent.com/51526347/230061897-b3479405-8ebd-45ab-
EVAL Making a UI for itself EVAL Making a UI for itself
#### [EVAL-BOT](https://github.com/eval-bot) #### [EVAL-BOT](https://github.com/eval-bot)
EVAL's self-managed github account. EVAL does everything except for signup and bio setting. EVAL's self-managed github account. EVAL does everything except for signup and bio setting.
### Examples ### Examples
[Here](examples/) is an example. [Here](examples/) is an example.
### EVAL's FEATURE ### EVAL's FEATURE
1. **Multimodal Conversation** 1. **Multimodal Conversation**
@ -81,6 +84,7 @@ Each optional env has default value, so you don't need to set unless you want to
- `SERVER` - server address (default: http://localhost:8000) - `SERVER` - server address (default: http://localhost:8000)
- `LOG_LEVEL` - INFO | DEBUG (default: INFO) - `LOG_LEVEL` - INFO | DEBUG (default: INFO)
- `BOT_NAME` - give it a name! (default: Orca) - `BOT_NAME` - give it a name! (default: Orca)
- `MODEL_NAME` - model name for GPT (default: gpt-4)
**For More Tools** **For More Tools**

View File

@ -19,6 +19,7 @@ class AgentBuilder:
def build_llm(self): def build_llm(self):
self.llm = ChatOpenAI(temperature=0) self.llm = ChatOpenAI(temperature=0)
self.llm.check_access()
def build_parser(self): def build_parser(self):
self.parser = EvalOutputParser() self.parser = EvalOutputParser()

View File

@ -5,6 +5,8 @@ import logging
import sys import sys
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
import openai
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema import ( from langchain.schema import (
AIMessage, AIMessage,
@ -26,6 +28,9 @@ from tenacity import (
wait_exponential, wait_exponential,
) )
from env import settings
from ansi import ANSI, Color, Style
def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]: def _create_retry_decorator(llm: ChatOpenAI) -> Callable[[Any], Any]:
import openai import openai
@ -98,6 +103,23 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
return ChatResult(generations=generations) return ChatResult(generations=generations)
class ModelNotFoundException(Exception):
    """Raised when the configured OpenAI model is not accessible to the API key.

    Builds a colorized, human-readable error message telling the user how to
    change the model via the ``MODEL_NAME`` environment variable. When the
    missing model is ``gpt-4``, an extra hint suggests falling back to
    ``gpt-3.5-turbo`` (gpt-4 access was invite-only at the time).
    """

    def __init__(self, model_name: str):
        # Keep the model name on the instance so callers can inspect it.
        self.model_name = model_name
        super().__init__(
            f"\n\nModel {ANSI(self.model_name).to(Color.red())} does not exist.\nMake sure if you have access to the model.\n"
            # Fixed message: original read "You can the model name" (missing verb).
            + f"You can set the model name with the environment variable {ANSI('MODEL_NAME').to(Style.bold())} on {ANSI('.env').to(Style.bold())}.\n"
            + "\nex) MODEL_NAME=gpt-4\n"
            + ANSI(
                "\nLooks like you don't have access to gpt-4 yet. Try using `gpt-3.5-turbo`."
                if self.model_name == "gpt-4"
                else ""
            ).to(Style.italic())
        )
class ChatOpenAI(BaseChatModel, BaseModel): class ChatOpenAI(BaseChatModel, BaseModel):
"""Wrapper around OpenAI Chat large language models. """Wrapper around OpenAI Chat large language models.
@ -115,7 +137,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
""" """
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "gpt-4" model_name: str = settings["MODEL_NAME"]
"""Model name to use.""" """Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict) model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified.""" """Holds any model parameters valid for `create` call not explicitly specified."""
@ -134,6 +156,14 @@ class ChatOpenAI(BaseChatModel, BaseModel):
extra = Extra.ignore extra = Extra.ignore
def check_access(self) -> None:
"""Check that the user has access to the model."""
try:
openai.Engine.retrieve(self.model_name)
except openai.error.InvalidRequestError:
raise ModelNotFoundException(self.model_name)
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""

1
env.py
View File

@ -30,6 +30,7 @@ class DotEnv(TypedDict):
EVAL_PORT = int(os.getenv("EVAL_PORT", 8000)) EVAL_PORT = int(os.getenv("EVAL_PORT", 8000))
settings: DotEnv = { settings: DotEnv = {
"EVAL_PORT": EVAL_PORT, "EVAL_PORT": EVAL_PORT,
"MODEL_NAME": os.getenv("MODEL_NAME", "gpt-4"),
"SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"), "SERVER": os.getenv("SERVER", f"http://localhost:{EVAL_PORT}"),
"USE_GPU": os.getenv("USE_GPU", "False").lower() == "true", "USE_GPU": os.getenv("USE_GPU", "False").lower() == "true",
"PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"), "PLAYGROUND_DIR": os.getenv("PLAYGROUND_DIR", "playground"),