mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
together[patch]: Update @root_validator for pydantic 2 compatibility (#25423)
This PR updates usage of @root_validator to be compatible with pydantic 2.
This commit is contained in:
parent
a114255b82
commit
831708beb7
@ -1,6 +1,5 @@
|
||||
"""Wrapper around Together AI's Chat Completions API."""
|
||||
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
@ -12,8 +11,8 @@ import openai
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
from_env,
|
||||
secret_from_env,
|
||||
)
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
@ -311,13 +310,27 @@ class ChatTogether(BaseChatOpenAI):
|
||||
|
||||
model_name: str = Field(default="meta-llama/Llama-3-8b-chat-hf", alias="model")
|
||||
"""Model name to use."""
|
||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env are `TOGETHER_API_KEY` if not provided."""
|
||||
together_api_base: Optional[str] = Field(
|
||||
default="https://api.together.ai/v1/", alias="base_url"
|
||||
together_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
|
||||
)
|
||||
"""Together AI API key.
|
||||
|
||||
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||
"""
|
||||
together_api_base: str = Field(
|
||||
default_factory=from_env(
|
||||
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
|
||||
),
|
||||
alias="base_url",
|
||||
)
|
||||
|
||||
@root_validator()
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
if values["n"] < 1:
|
||||
@ -325,13 +338,6 @@ class ChatTogether(BaseChatOpenAI):
|
||||
if values["n"] > 1 and values["streaming"]:
|
||||
raise ValueError("n must be 1 when streaming.")
|
||||
|
||||
values["together_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
|
||||
)
|
||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
||||
"TOGETHER_API_BASE"
|
||||
)
|
||||
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["together_api_key"].get_secret_value()
|
||||
|
@ -1,7 +1,6 @@
|
||||
"""Wrapper around Together AI's Embeddings API."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
@ -25,9 +24,9 @@ from langchain_core.pydantic_v1 import (
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
secret_from_env,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -115,10 +114,19 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
||||
|
||||
Not yet supported.
|
||||
"""
|
||||
together_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""API Key for Solar API."""
|
||||
together_api_key: Optional[SecretStr] = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("TOGETHER_API_KEY", default=None),
|
||||
)
|
||||
"""Together AI API key.
|
||||
|
||||
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||
"""
|
||||
together_api_base: str = Field(
|
||||
default="https://api.together.ai/v1/", alias="base_url"
|
||||
default_factory=from_env(
|
||||
"TOGETHER_API_BASE", default="https://api.together.ai/v1/"
|
||||
),
|
||||
alias="base_url",
|
||||
)
|
||||
"""Endpoint URL to use."""
|
||||
embedding_ctx_length: int = 4096
|
||||
@ -198,18 +206,9 @@ class TogetherEmbeddings(BaseModel, Embeddings):
|
||||
values["model_kwargs"] = extra
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
together_api_key = get_from_dict_or_env(
|
||||
values, "together_api_key", "TOGETHER_API_KEY"
|
||||
)
|
||||
values["together_api_key"] = (
|
||||
convert_to_secret_str(together_api_key) if together_api_key else None
|
||||
)
|
||||
values["together_api_base"] = values["together_api_base"] or os.getenv(
|
||||
"TOGETHER_API_BASE"
|
||||
)
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def post_init(cls, values: Dict) -> Dict:
|
||||
"""Logic that will post Pydantic initialization."""
|
||||
client_params = {
|
||||
"api_key": (
|
||||
values["together_api_key"].get_secret_value()
|
||||
|
@ -11,8 +11,10 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.pydantic_v1 import SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
secret_from_env,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -36,8 +38,14 @@ class Together(LLM):
|
||||
|
||||
base_url: str = "https://api.together.ai/v1/completions"
|
||||
"""Base completions API URL."""
|
||||
together_api_key: SecretStr
|
||||
"""Together AI API key. Get it here: https://api.together.ai/settings/api-keys"""
|
||||
together_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("TOGETHER_API_KEY"),
|
||||
)
|
||||
"""Together AI API key.
|
||||
|
||||
Automatically read from env variable `TOGETHER_API_KEY` if not provided.
|
||||
"""
|
||||
model: str
|
||||
"""Model name. Available models listed here:
|
||||
Base Models: https://docs.together.ai/docs/inference-models#language-models
|
||||
@ -74,21 +82,11 @@ class Together(LLM):
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = "forbid"
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["together_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "together_api_key", "TOGETHER_API_KEY")
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_max_tokens(cls, values: Dict) -> Dict:
|
||||
"""The v1 completions endpoint, has max_tokens as required parameter.
|
||||
|
||||
Set a default value and warn if the parameter is missing.
|
||||
"""
|
||||
if values.get("max_tokens") is None:
|
||||
warnings.warn(
|
||||
"The completions endpoint, has 'max_tokens' as required argument. "
|
||||
|
@ -9,7 +9,7 @@ from langchain_together import Together
|
||||
def test_together_api_key_is_secret_string() -> None:
|
||||
"""Test that the API key is stored as a SecretStr."""
|
||||
llm = Together(
|
||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
||||
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
@ -38,7 +38,7 @@ def test_together_api_key_masked_when_passed_via_constructor(
|
||||
) -> None:
|
||||
"""Test that the API key is masked when passed via the constructor."""
|
||||
llm = Together(
|
||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
||||
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
@ -52,7 +52,18 @@ def test_together_api_key_masked_when_passed_via_constructor(
|
||||
def test_together_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
llm = Together(
|
||||
together_api_key="secret-api-key", # type: ignore[arg-type]
|
||||
together_api_key="secret-api-key", # type: ignore[call-arg]
|
||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
)
|
||||
assert cast(SecretStr, llm.together_api_key).get_secret_value() == "secret-api-key"
|
||||
|
||||
|
||||
def test_together_uses_actual_secret_value_from_secretstr_api_key() -> None:
|
||||
"""Test that the actual secret value is correctly retrieved."""
|
||||
llm = Together(
|
||||
api_key="secret-api-key", # type: ignore[arg-type]
|
||||
model="togethercomputer/RedPajama-INCITE-7B-Base",
|
||||
temperature=0.2,
|
||||
max_tokens=250,
|
||||
|
Loading…
Reference in New Issue
Block a user