together[patch]: Update @root_validator for pydantic 2 compatibility (#25423)

This PR updates usage of @root_validator to be compatible with pydantic 2.
This commit is contained in:
Eugene Yurtsev 2024-08-15 11:27:42 -04:00 committed by GitHub
parent a114255b82
commit 831708beb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 65 additions and 51 deletions

View File

@ -1,6 +1,5 @@
"""Wrapper around Together AI's Chat Completions API."""
import os
from typing import (
Any,
Dict,
@ -12,8 +11,8 @@ import openai
from langchain_core.language_models.chat_models import LangSmithParams
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
from_env,
secret_from_env,
)
from langchain_openai.chat_models.base import BaseChatOpenAI
@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI):
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
"""Model name to use."""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
together_api_base: Optional[str] = Field(
default="https://api.together.ai/v1/", alias="base_url"
together_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field(
default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
)
@root_validator()
class Config:
"""Pydantic config."""
allow_population_by_field_name = True
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
if values["n"] < 1:
@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI):
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()

View File

@ -1,7 +1,6 @@
"""Wrapper around Together AI's Embeddings API."""
import logging
import os
import warnings
from typing import (
Any,
@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import (
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
from_env,
get_pydantic_field_names,
secret_from_env,
)
logger = logging.getLogger(__name__)
@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings):
Not yet supported.
"""
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""API Key for Solar API."""
together_api_key: Optional[SecretStr] = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
together_api_base: str = Field(
default="https://api.together.ai/v1/", alias="base_url"
default_factory=from_env(
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
),
alias="base_url",
)
"""Endpoint URL to use."""
embedding_ctx_length: int = 4096
@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings):
values["model_kwargs"] = extra
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
together_api_key = get_from_dict_or_env(
values, "together_api_key", "TOGETHER_API_KEY"
)
values["together_api_key"] = (
convert_to_secret_str(together_api_key) if together_api_key else None
)
values["together_api_base"] = values["together_api_base"] or os.getenv(
"TOGETHER_API_BASE"
)
@root_validator(pre=False, skip_on_failure=True)
def post_init(cls, values: Dict) -> Dict:
"""Logic that will post Pydantic initialization."""
client_params = {
"api_key": (
values["together_api_key"].get_secret_value()

View File

@ -11,8 +11,10 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
secret_from_env,
)
logger = logging.getLogger(__name__)
@ -36,8 +38,14 @@ class Together(LLM):
base_url: str = "https://api.together.ai/v1/completions"
"""Base completions API URL."""
together_api_key: SecretStr
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
together_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env("TOGETHER_API_KEY"),
)
"""Together AI API key.
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
"""
model: str
"""Model name. Available models listed here:
Base Models: https://docs.together.ai/docs/inference-models#language-models
@ -74,21 +82,11 @@ class Together(LLM):
"""Configuration for this pydantic object."""
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
values["together_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
)
return values
@root_validator()
def validate_max_tokens(cls, values: Dict) -> Dict:
"""The v1 completions endpoint, has max_tokens as required parameter.
Set a default value and warn if the parameter is missing.
"""
if values.get("max_tokens") is None:
warnings.warn(
"The completions endpoint, has 'max_tokens' as required argument. "

View File

@ -9,7 +9,7 @@ from langchain_together import Together
def test_together_api_key_is_secret_string() -> None:
"""Test that the API key is stored as a SecretStr."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor(
) -> None:
"""Test that the API key is masked when passed via the constructor."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor(
def test_together_uses_actual_secret_value_from_secretstr() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Together(
together_api_key="secret-api-key", # type: ignore[arg-type]
together_api_key="secret-api-key", # type: ignore[call-arg]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,
)
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
"""Test that the actual secret value is correctly retrieved."""
llm = Together(
api_key="secret-api-key", # type: ignore[arg-type]
model="togethercomputer/RedPajama-INCITE-7B-Base",
temperature=0.2,
max_tokens=250,