From c9a362e4821b5fcdc81366ca2230df77f0a113fc Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 18 May 2023 09:12:23 -0700 Subject: [PATCH] add alias for model (#4553) Co-authored-by: Dev 2049 --- langchain/base_language.py | 11 ++++++++++- langchain/chat_models/openai.py | 9 ++++----- langchain/llms/openai.py | 9 ++++----- tests/integration_tests/chat_models/test_openai.py | 8 ++++++++ tests/integration_tests/llms/test_openai.py | 7 +++++++ 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/langchain/base_language.py b/langchain/base_language.py index 86353670..99201f28 100644 --- a/langchain/base_language.py +++ b/langchain/base_language.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Set from pydantic import BaseModel @@ -68,3 +68,12 @@ class BaseLanguageModel(BaseModel, ABC): def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: """Get the number of tokens in the message.""" return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages]) + + @classmethod + def all_required_field_names(cls) -> Set: + all_required_field_names = set() + for field in cls.__fields__.values(): + all_required_field_names.add(field.name) + if field.has_alias: + all_required_field_names.add(field.alias) + return all_required_field_names diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 2eefb091..0b46818c 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -112,7 +112,7 @@ class ChatOpenAI(BaseChatModel): """ client: Any #: :meta private: - model_name: str = "gpt-3.5-turbo" + model_name: str = Field(default="gpt-3.5-turbo", alias="model") """Model name to use.""" temperature: float = 0.7 """What sampling temperature to use.""" @@ -138,12 +138,12 @@ class ChatOpenAI(BaseChatModel): """Configuration for this pydantic object.""" extra = Extra.ignore + allow_population_by_field_name = True @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} - + all_required_field_names = cls.all_required_field_names() extra = values.get("model_kwargs", {}) for field_name in list(values): if field_name in extra: @@ -156,8 +156,7 @@ class ChatOpenAI(BaseChatModel): ) extra[field_name] = values.pop(field_name) - disallowed_model_kwargs = all_required_field_names | {"model"} - invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index bfca7fe3..1216868c 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -124,7 +124,7 @@ class BaseOpenAI(BaseLLM): """Wrapper around OpenAI large language models.""" client: Any #: :meta private: - model_name: str = "text-davinci-003" + model_name: str = Field("text-davinci-003", alias="model") """Model name to use.""" temperature: float = 0.7 """What sampling temperature to use.""" @@ -178,12 +178,12 @@ class BaseOpenAI(BaseLLM): """Configuration for this pydantic object.""" extra = Extra.ignore + allow_population_by_field_name = True @root_validator(pre=True) def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Build extra kwargs from additional params that were passed in.""" - all_required_field_names = {field.alias for field in cls.__fields__.values()} - + all_required_field_names = cls.all_required_field_names() extra = values.get("model_kwargs", {}) for field_name in list(values): if field_name in extra: @@ -196,8 +196,7 @@ class BaseOpenAI(BaseLLM): ) extra[field_name] = values.pop(field_name) - disallowed_model_kwargs = all_required_field_names | {"model"} - invalid_model_kwargs = disallowed_model_kwargs.intersection(extra.keys()) + invalid_model_kwargs = all_required_field_names.intersection(extra.keys()) if invalid_model_kwargs: raise ValueError( f"Parameters {invalid_model_kwargs} should be specified explicitly. " diff --git a/tests/integration_tests/chat_models/test_openai.py b/tests/integration_tests/chat_models/test_openai.py index 679cb06c..568a9222 100644 --- a/tests/integration_tests/chat_models/test_openai.py +++ b/tests/integration_tests/chat_models/test_openai.py @@ -25,6 +25,14 @@ def test_chat_openai() -> None: assert isinstance(response.content, str) +def test_chat_openai_model() -> None: + """Test ChatOpenAI wrapper handles model_name.""" + chat = ChatOpenAI(model="foo") + assert chat.model_name == "foo" + chat = ChatOpenAI(model_name="bar") + assert chat.model_name == "bar" + + def test_chat_openai_system_message() -> None: """Test ChatOpenAI wrapper with system message.""" chat = ChatOpenAI(max_tokens=10) diff --git a/tests/integration_tests/llms/test_openai.py b/tests/integration_tests/llms/test_openai.py index 5132e3bb..e8c5a3d1 100644 --- a/tests/integration_tests/llms/test_openai.py +++ b/tests/integration_tests/llms/test_openai.py @@ -19,6 +19,13 @@ def test_openai_call() -> None: assert isinstance(output, str) +def test_openai_model_param() -> None: + llm = OpenAI(model="foo") + assert llm.model_name == "foo" + llm = OpenAI(model_name="foo") + assert llm.model_name == "foo" + + def test_openai_extra_kwargs() -> None: """Test extra kwargs to openai.""" # Check that foo is saved in extra_kwargs.