add alias for model (#4553)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
docker
Harrison Chase 1 year ago committed by GitHub
parent 7642f2159c
commit c9a362e482
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Set
from pydantic import BaseModel
@ -68,3 +68,12 @@ class BaseLanguageModel(BaseModel, ABC):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def all_required_field_names(cls) -> Set:
all_required_field_names = set()
for field in cls.__fields__.values():
all_required_field_names.add(field.name)
if field.has_alias:
all_required_field_names.add(field.alias)
return all_required_field_names

@ -112,7 +112,7 @@ class ChatOpenAI(BaseChatModel):
"""
client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo"
model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
@ -138,12 +138,12 @@ class ChatOpenAI(BaseChatModel):
"""Configuration for this pydantic object."""
extra = Extra.ignore
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = cls.all_required_field_names()
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
@ -156,8 +156,7 @@ class ChatOpenAI(BaseChatModel):
)
extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"}
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "

@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM):
"""Wrapper around OpenAI large language models."""
client: Any #: :meta private:
model_name: str = "text-davinci-003"
model_name: str = Field("text-davinci-003", alias="model")
"""Model name to use."""
temperature: float = 0.7
"""What sampling temperature to use."""
@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM):
"""Configuration for this pydantic object."""
extra = Extra.ignore
allow_population_by_field_name = True
@root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()}
all_required_field_names = cls.all_required_field_names()
extra = values.get("model_kwargs", {})
for field_name in list(values):
if field_name in extra:
@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM):
)
extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"}
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
if invalid_model_kwargs:
raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. "

@ -25,6 +25,14 @@ def test_chat_openai() -> None:
assert isinstance(response.content, str)
def test_chat_openai_model() -> None:
"""Test ChatOpenAI wrapper handles model_name."""
chat = ChatOpenAI(model="foo")
assert chat.model_name == "foo"
chat = ChatOpenAI(model_name="bar")
assert chat.model_name == "bar"
def test_chat_openai_system_message() -> None:
"""Test ChatOpenAI wrapper with system message."""
chat = ChatOpenAI(max_tokens=10)

@ -19,6 +19,13 @@ def test_openai_call() -> None:
assert isinstance(output, str)
def test_openai_model_param() -> None:
llm = OpenAI(model="foo")
assert llm.model_name == "foo"
llm = OpenAI(model_name="foo")
assert llm.model_name == "foo"
def test_openai_extra_kwargs() -> None:
"""Test extra kwargs to openai."""
# Check that foo is saved in extra_kwargs.

Loading…
Cancel
Save