add alias for model (#4553)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
docker
Harrison Chase 1 year ago committed by GitHub
parent 7642f2159c
commit c9a362e482
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Sequence from typing import List, Optional, Sequence, Set
from pydantic import BaseModel from pydantic import BaseModel
@ -68,3 +68,12 @@ class BaseLanguageModel(BaseModel, ABC):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the message.""" """Get the number of tokens in the message."""
return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])
@classmethod
def all_required_field_names(cls) -> Set:
all_required_field_names = set()
for field in cls.__fields__.values():
all_required_field_names.add(field.name)
if field.has_alias:
all_required_field_names.add(field.alias)
return all_required_field_names

@ -112,7 +112,7 @@ class ChatOpenAI(BaseChatModel):
""" """
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "gpt-3.5-turbo" model_name: str = Field(default="gpt-3.5-turbo", alias="model")
"""Model name to use.""" """Model name to use."""
temperature: float = 0.7 temperature: float = 0.7
"""What sampling temperature to use.""" """What sampling temperature to use."""
@ -138,12 +138,12 @@ class ChatOpenAI(BaseChatModel):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.ignore extra = Extra.ignore
allow_population_by_field_name = True
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()} all_required_field_names = cls.all_required_field_names()
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
for field_name in list(values): for field_name in list(values):
if field_name in extra: if field_name in extra:
@ -156,8 +156,7 @@ class ChatOpenAI(BaseChatModel):
) )
extra[field_name] = values.pop(field_name) extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"} invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
if invalid_model_kwargs: if invalid_model_kwargs:
raise ValueError( raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Parameters {invalid_model_kwargs} should be specified explicitly. "

@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM):
"""Wrapper around OpenAI large language models.""" """Wrapper around OpenAI large language models."""
client: Any #: :meta private: client: Any #: :meta private:
model_name: str = "text-davinci-003" model_name: str = Field("text-davinci-003", alias="model")
"""Model name to use.""" """Model name to use."""
temperature: float = 0.7 temperature: float = 0.7
"""What sampling temperature to use.""" """What sampling temperature to use."""
@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM):
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
extra = Extra.ignore extra = Extra.ignore
allow_population_by_field_name = True
@root_validator(pre=True) @root_validator(pre=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in.""" """Build extra kwargs from additional params that were passed in."""
all_required_field_names = {field.alias for field in cls.__fields__.values()} all_required_field_names = cls.all_required_field_names()
extra = values.get("model_kwargs", {}) extra = values.get("model_kwargs", {})
for field_name in list(values): for field_name in list(values):
if field_name in extra: if field_name in extra:
@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM):
) )
extra[field_name] = values.pop(field_name) extra[field_name] = values.pop(field_name)
disallowed_model_kwargs = all_required_field_names | {"model"} invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys())
if invalid_model_kwargs: if invalid_model_kwargs:
raise ValueError( raise ValueError(
f"Parameters {invalid_model_kwargs} should be specified explicitly. " f"Parameters {invalid_model_kwargs} should be specified explicitly. "

@ -25,6 +25,14 @@ def test_chat_openai() -> None:
assert isinstance(response.content, str) assert isinstance(response.content, str)
def test_chat_openai_model() -> None:
"""Test ChatOpenAI wrapper handles model_name."""
chat = ChatOpenAI(model="foo")
assert chat.model_name == "foo"
chat = ChatOpenAI(model_name="bar")
assert chat.model_name == "bar"
def test_chat_openai_system_message() -> None: def test_chat_openai_system_message() -> None:
"""Test ChatOpenAI wrapper with system message.""" """Test ChatOpenAI wrapper with system message."""
chat = ChatOpenAI(max_tokens=10) chat = ChatOpenAI(max_tokens=10)

@ -19,6 +19,13 @@ def test_openai_call() -> None:
assert isinstance(output, str) assert isinstance(output, str)
def test_openai_model_param() -> None:
llm = OpenAI(model="foo")
assert llm.model_name == "foo"
llm = OpenAI(model_name="foo")
assert llm.model_name == "foo"
def test_openai_extra_kwargs() -> None: def test_openai_extra_kwargs() -> None:
"""Test extra kwargs to openai.""" """Test extra kwargs to openai."""
# Check that foo is saved in extra_kwargs. # Check that foo is saved in extra_kwargs.

Loading…
Cancel
Save