Anthropic: Allow the use of kwargs consistent with ChatOpenAI. (#9515)

- Description: ~~Creates a new root_validator in `_AnthropicCommon` that
allows the use of `model_name` and `max_tokens` keyword arguments.~~
Adds pydantic field aliases to support `model_name` and `max_tokens` as
keyword arguments. Ultimately, this makes `ChatAnthropic` more
consistent with `ChatOpenAI`, making the two classes more
interchangeable for the developer.
  - Issue: https://github.com/langchain-ai/langchain/issues/9510

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/9672/head
Joshua Sundance Bailey 11 months ago committed by GitHub
parent a8c916955f
commit a9c86774da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,6 +36,12 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
"""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@property
def lc_secrets(self) -> Dict[str, str]:
return {"anthropic_api_key": "ANTHROPIC_API_KEY"}

@ -21,10 +21,10 @@ from langchain.utils.utils import build_extra_kwargs
class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = "claude-2"
model: str = Field(default="claude-2", alias="model_name")
"""Model name to use."""
max_tokens_to_sample: int = 256
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
@ -144,6 +144,7 @@ class Anthropic(LLM, _AnthropicCommon):
import anthropic
from langchain.llms import Anthropic
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
@ -157,6 +158,12 @@ class Anthropic(LLM, _AnthropicCommon):
response = model(prompt)
"""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""

@ -8,6 +8,18 @@ from langchain.chat_models import ChatAnthropic
os.environ["ANTHROPIC_API_KEY"] = "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = ChatAnthropic(model_name="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
llm = ChatAnthropic(model="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
llm = ChatAnthropic(model_kwargs={"foo": "bar"})

@ -9,6 +9,18 @@ from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = Anthropic(model_name="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
llm = Anthropic(model="foo")
assert llm.model == "foo"
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
llm = Anthropic(model="claude-instant-1")
@ -24,7 +36,7 @@ def test_anthropic_streaming() -> None:
assert isinstance(generator, Generator)
for token in generator:
assert isinstance(token["completion"], str)
assert isinstance(token, str)
def test_anthropic_streaming_callback() -> None:

@ -6,9 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from langchain.adapters.openai import convert_dict_to_message
from langchain.chat_models.openai import (
ChatOpenAI,
)
from langchain.chat_models.openai import ChatOpenAI
from langchain.schema.messages import (
AIMessage,
FunctionMessage,
@ -17,6 +15,14 @@ from langchain.schema.messages import (
)
@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
llm = ChatOpenAI(model="foo")
assert llm.model_name == "foo"
llm = ChatOpenAI(model_name="foo")
assert llm.model_name == "foo"
def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"

Loading…
Cancel
Save