core[minor], anthropic[patch]: Upgrade @root_validator usage to be consistent with pydantic 2 (#25457)

anthropic: Upgrade `@root_validator` usage to be consistent with
pydantic 2
core: support looking up multiple keys from env in from_env factory
This commit is contained in:
Eugene Yurtsev 2024-08-15 16:09:34 -04:00 committed by GitHub
parent 34da8be60b
commit e18511bb22
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 67 additions and 44 deletions

View File

@ -7,7 +7,7 @@ import importlib
import os
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload
from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Union, overload
from packaging.version import parse
from requests import HTTPError, Response
@ -280,13 +280,17 @@ def from_env(key: str, /) -> Callable[[], str]: ...
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: Sequence[str], /, *, default: str) -> Callable[[], str]: ...
@overload
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
@overload
def from_env(
key: str, /, *, default: str, error_message: Optional[str]
key: Union[str, Sequence[str]], /, *, default: str, error_message: Optional[str]
) -> Callable[[], str]: ...
@ -301,7 +305,7 @@ def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ...
def from_env(
key: str,
key: Union[str, Sequence[str]],
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
@ -310,7 +314,10 @@ def from_env(
"""Create a factory method that gets a value from an environment variable.
Args:
key: The environment variable to look up.
key: The environment variable to look up. If a list of keys is provided,
the first key found in the environment will be used.
If no key is found, the default value will be used if set,
otherwise an error will be raised.
default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found
and no default value is provided.
@ -319,9 +326,15 @@ def from_env(
def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable."""
if key in os.environ:
return os.environ[key]
elif isinstance(default, (str, type(None))):
if isinstance(key, (list, tuple)):
for k in key:
if k in os.environ:
return os.environ[k]
if isinstance(key, str):
if key in os.environ:
return os.environ[key]
if isinstance(default, (str, type(None))):
return default
else:
if error_message:

View File

@ -1,4 +1,3 @@
import os
import re
import warnings
from operator import itemgetter
@ -64,8 +63,9 @@ from langchain_core.runnables import (
from langchain_core.tools import BaseTool
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
from_env,
get_pydantic_field_names,
secret_from_env,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
@ -541,14 +541,26 @@ class ChatAnthropic(BaseChatModel):
stop_sequences: Optional[List[str]] = Field(None, alias="stop")
"""Default stop sequences."""
anthropic_api_url: Optional[str] = Field(None, alias="base_url")
anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
["ANTHROPIC_API_URL", "ANTHROPIC_BASE_URL"],
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator.
If a value isn't passed in and environment variable ANTHROPIC_BASE_URL is set, value
will be read from there.
If a value isn't passed in, will attempt to read the value first from
ANTHROPIC_API_URL and if that is not set, ANTHROPIC_BASE_URL.
If neither are set, the default value of 'https://api.anthropic.com' will
be used.
"""
anthropic_api_key: Optional[SecretStr] = Field(None, alias="api_key")
anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
default_headers: Optional[Mapping[str, str]] = None
@ -623,20 +635,10 @@ class ChatAnthropic(BaseChatModel):
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or ""
)
values["anthropic_api_key"] = anthropic_api_key
api_key = anthropic_api_key.get_secret_value()
api_url = (
values.get("anthropic_api_url")
or os.environ.get("ANTHROPIC_API_URL")
or os.environ.get("ANTHROPIC_BASE_URL")
or "https://api.anthropic.com"
)
values["anthropic_api_url"] = api_url
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
api_key = values["anthropic_api_key"].get_secret_value()
api_url = values["anthropic_api_url"]
client_params = {
"api_key": api_key,
"base_url": api_url,

View File

@ -23,10 +23,13 @@ from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
from langchain_core.utils.utils import (
build_extra_kwargs,
from_env,
secret_from_env,
)
class _AnthropicCommon(BaseLanguageModel):
@ -56,9 +59,25 @@ class _AnthropicCommon(BaseLanguageModel):
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
anthropic_api_url: Optional[str] = None
anthropic_api_url: Optional[str] = Field(
alias="base_url",
default_factory=from_env(
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
),
)
"""Base URL for API requests. Only specify if using a proxy or service emulator.
anthropic_api_key: Optional[SecretStr] = None
If a value isn't passed in, will attempt to read the value from
ANTHROPIC_API_URL. If not set, the default value of 'https://api.anthropic.com' will
be used.
"""
anthropic_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("ANTHROPIC_API_KEY", default=""),
)
"""Automatically read from env var `ANTHROPIC_API_KEY` if not provided."""
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
@ -74,20 +93,9 @@ class _AnthropicCommon(BaseLanguageModel):
)
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
values,
"anthropic_api_url",
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
)
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"].get_secret_value(),
@ -158,7 +166,7 @@ class AnthropicLLM(LLM, _AnthropicCommon):
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
@root_validator(pre=True)
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""
warnings.warn(