voyageai[patch]: Upgrade root validators for pydantic 2 (#25455)

Update `@root_validator` usages to be consistent with pydantic 2 semantics.
This commit is contained in:
Eugene Yurtsev 2024-08-15 15:30:41 -04:00 committed by GitHub
parent 4cdaca67dc
commit b297af5482
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 17 deletions

View File

@ -1,5 +1,4 @@
import logging
import os
from typing import Iterable, List, Optional
import voyageai # type: ignore
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import secret_from_env
logger = logging.getLogger(__name__)
@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings):
# `batch_size` has no default on this line; the `default_values` pre-validator
# fills it in from the model name when the caller omits it.
batch_size: int
show_progress_bar: bool = False
truncation: Optional[bool] = None
# NOTE(review): `voyage_api_key` is declared twice below, once as an optional
# field defaulting to None and once as a SecretStr built via Field(...). This
# looks like the removed and added sides of the change shown together; confirm
# which declaration is current.
voyage_api_key: Optional[SecretStr] = None
# The key is also accepted under the alias `api_key`. Its default comes from
# `secret_from_env("VOYAGE_API_KEY", ...)`, presumably an environment lookup
# that fails with the message below when the variable is unset; verify against
# langchain_core.utils.secret_from_env.
voyage_api_key: SecretStr = Field(
alias="api_key",
default_factory=secret_from_env(
"VOYAGE_API_KEY",
error_message="Must set `VOYAGE_API_KEY` environment variable or "
"pass `api_key` to VoyageAIEmbeddings constructor.",
),
)
class Config:
    # Let callers populate aliased fields by their field name too
    # (e.g. `voyage_api_key` as well as `api_key`).
    allow_population_by_field_name = True
    # Unknown keyword arguments are rejected rather than ignored.
    extra = "forbid"
@root_validator(pre=True)
def default_values(cls, values: dict) -> dict:
    """Fill in ``batch_size`` when the caller did not provide one.

    Args:
        values: Incoming field values; mutated in place.

    Returns:
        The same ``values`` mapping. When ``batch_size`` was missing or
        ``None`` it is set to 72 for the "voyage-2" / "voyage-02" models
        and to 7 for any other model.
    """
    if values.get("batch_size") is not None:
        # An explicitly supplied batch size is left untouched.
        return values

    if values.get("model") in ["voyage-2", "voyage-02"]:
        values["batch_size"] = 72
    else:
        values["batch_size"] = 7
    return values
# NOTE(review): two `@root_validator` lines and two derivations of
# `api_key_str` appear in this block. This looks like the removed and added
# sides of the change shown together; confirm which lines are current.
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that VoyageAI credentials exist in environment."""
# First derivation: prefer the supplied `voyage_api_key`, fall back to the
# VOYAGE_API_KEY environment variable, and pass a found key through
# `convert_to_secret_str` before storing it back into `values`.
voyage_api_key = values.get("voyage_api_key") or os.getenv(
"VOYAGE_API_KEY", None
)
if voyage_api_key:
api_key_secretstr = convert_to_secret_str(voyage_api_key)
values["voyage_api_key"] = api_key_secretstr
api_key_str = api_key_secretstr.get_secret_value()
else:
api_key_str = None
# Second derivation: read the `voyage_api_key` entry directly and unwrap it.
# This raises if the entry is absent or has no `get_secret_value`.
api_key_str = values["voyage_api_key"].get_secret_value()
# Build the sync and async VoyageAI clients from the same key and store them
# under `_client` / `_aclient`.
values["_client"] = voyageai.Client(api_key=api_key_str)
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
return values

View File

@ -9,6 +9,17 @@ MODEL = "voyage-2"
def test_initialization_voyage_2() -> None:
    """Build the embeddings via the `api_key` alias and check voyage-2 defaults."""
    embeddings = VoyageAIEmbeddings(model=MODEL, api_key="NOT_A_VALID_KEY")

    # Correct type, model-derived batch size, model name kept as given.
    assert isinstance(embeddings, Embeddings)
    assert embeddings.batch_size == 72
    assert embeddings.model == MODEL
    # The validator must have attached a client object.
    assert embeddings._client is not None
def test_initialization_voyage_2_with_full_api_key_name() -> None:
"""Test embedding model initialization."""
# Testing that we can initialize the model using `voyage_api_key`
# instead of `api_key`
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
assert isinstance(emb, Embeddings)
# No batch_size is passed, so the model-based default for MODEL applies.
assert emb.batch_size == 72
# NOTE(review): the diff hunk ends here; any further assertions in this test
# are not visible in this view.
@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None:
def test_initialization_voyage_1() -> None:
"""Test embedding model initialization."""
# NOTE(review): the next two lines build the same object, first via
# `voyage_api_key` and then via `api_key`. This looks like the removed and
# added sides of the change shown together; confirm which line is current.
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01")
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01")
assert isinstance(emb, Embeddings)
# "voyage-01" is not one of the model names that get the 72 default, so the
# fallback batch size of 7 is expected.
assert emb.batch_size == 7
assert emb.model == "voyage-01"
@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None:
def test_initialization_voyage_1_batch_size() -> None:
"""Test embedding model initialization."""
emb = VoyageAIEmbeddings(
# NOTE(review): two argument lines follow with no comma between them, one
# using `voyage_api_key` and one using `api_key`. This looks like the removed
# and added sides of the change shown together; confirm which line is current.
voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
)
assert isinstance(emb, Embeddings)
# An explicitly passed batch_size is expected to be kept as given.
assert emb.batch_size == 15