From a9c86774daa784f71b8daf849ae950b8ff5a5a27 Mon Sep 17 00:00:00 2001 From: Joshua Sundance Bailey <84336755+joshuasundance-swca@users.noreply.github.com> Date: Wed, 23 Aug 2023 21:23:21 -0400 Subject: [PATCH] Anthropic: Allow the use of kwargs consistent with ChatOpenAI. (#9515) - Description: ~~Creates a new root_validator in `_AnthropicCommon` that allows the use of `model_name` and `max_tokens` keyword arguments.~~ Adds pydantic field aliases to support `model_name` and `max_tokens` as keyword arguments. Ultimately, this makes `ChatAnthropic` more consistent with `ChatOpenAI`, making the two classes more interchangeable for the developer. - Issue: https://github.com/langchain-ai/langchain/issues/9510 --------- Co-authored-by: Bagatur --- libs/langchain/langchain/chat_models/anthropic.py | 6 ++++++ libs/langchain/langchain/llms/anthropic.py | 11 +++++++++-- .../chat_models/test_anthropic_2.py | 12 ++++++++++++ .../tests/integration_tests/llms/test_anthropic.py | 14 +++++++++++++- .../tests/unit_tests/chat_models/test_openai.py | 12 +++++++++--- 5 files changed, 49 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py index ef1da63196..4d00eae4df 100644 --- a/libs/langchain/langchain/chat_models/anthropic.py +++ b/libs/langchain/langchain/chat_models/anthropic.py @@ -36,6 +36,12 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): model = ChatAnthropic(model="", anthropic_api_key="my-api-key") """ + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + @property def lc_secrets(self) -> Dict[str, str]: return {"anthropic_api_key": "ANTHROPIC_API_KEY"} diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index afaea04b51..63664e07af 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -21,10 +21,10 @@ from langchain.utils.utils import build_extra_kwargs class _AnthropicCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: - model: str = "claude-2" + model: str = Field(default="claude-2", alias="model_name") """Model name to use.""" - max_tokens_to_sample: int = 256 + max_tokens_to_sample: int = Field(default=256, alias="max_tokens") """Denotes the number of tokens to predict per generation.""" temperature: Optional[float] = None @@ -144,6 +144,7 @@ class Anthropic(LLM, _AnthropicCommon): import anthropic from langchain.llms import Anthropic + model = Anthropic(model="", anthropic_api_key="my-api-key") # Simplest invocation, automatically wrapped with HUMAN_PROMPT @@ -157,6 +158,12 @@ class Anthropic(LLM, _AnthropicCommon): response = model(prompt) """ + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + @root_validator() def raise_warning(cls, values: Dict) -> Dict: """Raise warning that this class is deprecated.""" diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic_2.py b/libs/langchain/tests/integration_tests/chat_models/test_anthropic_2.py index 7447ec03e4..54b6045272 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic_2.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_anthropic_2.py @@ -8,6 +8,18 @@ from langchain.chat_models import ChatAnthropic os.environ["ANTHROPIC_API_KEY"] = "foo" +@pytest.mark.requires("anthropic") +def test_anthropic_model_name_param() -> None: + llm = ChatAnthropic(model_name="foo") + assert llm.model == "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_param() -> None: + llm = ChatAnthropic(model="foo") + assert llm.model == "foo" + + @pytest.mark.requires("anthropic") def test_anthropic_model_kwargs() -> None: llm = ChatAnthropic(model_kwargs={"foo": "bar"}) diff --git a/libs/langchain/tests/integration_tests/llms/test_anthropic.py b/libs/langchain/tests/integration_tests/llms/test_anthropic.py index 3604f61969..f68053b2aa 100644 --- a/libs/langchain/tests/integration_tests/llms/test_anthropic.py +++ b/libs/langchain/tests/integration_tests/llms/test_anthropic.py @@ -9,6 +9,18 @@ from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +@pytest.mark.requires("anthropic") +def test_anthropic_model_name_param() -> None: + llm = Anthropic(model_name="foo") + assert llm.model == "foo" + + +@pytest.mark.requires("anthropic") +def test_anthropic_model_param() -> None: + llm = Anthropic(model="foo") + assert llm.model == "foo" + + def test_anthropic_call() -> None: """Test valid call to anthropic.""" llm = Anthropic(model="claude-instant-1") @@ -24,7 +36,7 @@ def test_anthropic_streaming() -> None: assert isinstance(generator, Generator) for token in generator: - assert isinstance(token["completion"], str) + assert isinstance(token, str) def test_anthropic_streaming_callback() -> None: diff --git a/libs/langchain/tests/unit_tests/chat_models/test_openai.py b/libs/langchain/tests/unit_tests/chat_models/test_openai.py index b417d82e94..c233724724 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_openai.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_openai.py @@ -6,9 +6,7 @@ from unittest.mock import MagicMock, patch import pytest from langchain.adapters.openai import convert_dict_to_message -from langchain.chat_models.openai import ( - ChatOpenAI, -) +from langchain.chat_models.openai import ChatOpenAI from langchain.schema.messages import ( AIMessage, FunctionMessage, @@ -17,6 +15,14 @@ from langchain.schema.messages import ( ) +@pytest.mark.requires("openai") +def test_openai_model_param() -> None: + llm = ChatOpenAI(model="foo") + assert llm.model_name == "foo" + llm = ChatOpenAI(model_name="foo") + assert llm.model_name == "foo" + + def test_function_message_dict_to_function_message() -> None: content = json.dumps({"result": "Example #1"}) name = "test_function"