mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
voyageai[patch]: Upgrade root validators for pydantic 2 (#25455)
Update @root_validators to be consistent with pydantic 2 semantics
This commit is contained in:
parent
4cdaca67dc
commit
b297af5482
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import voyageai # type: ignore
|
||||
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.utils import secret_from_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -32,34 +31,32 @@ class VoyageAIEmbeddings(BaseModel, Embeddings):
|
||||
batch_size: int
|
||||
show_progress_bar: bool = False
|
||||
truncation: Optional[bool] = None
|
||||
voyage_api_key: Optional[SecretStr] = None
|
||||
voyage_api_key: SecretStr = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env(
|
||||
"VOYAGE_API_KEY",
|
||||
error_message="Must set `VOYAGE_API_KEY` environment variable or "
|
||||
"pass `api_key` to VoyageAIEmbeddings constructor.",
|
||||
),
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def default_values(cls, values: dict) -> dict:
|
||||
"""Set default batch size based on model"""
|
||||
|
||||
model = values.get("model")
|
||||
batch_size = values.get("batch_size")
|
||||
if batch_size is None:
|
||||
values["batch_size"] = 72 if model in ["voyage-2", "voyage-02"] else 7
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate that VoyageAI credentials exist in environment."""
|
||||
voyage_api_key = values.get("voyage_api_key") or os.getenv(
|
||||
"VOYAGE_API_KEY", None
|
||||
)
|
||||
if voyage_api_key:
|
||||
api_key_secretstr = convert_to_secret_str(voyage_api_key)
|
||||
values["voyage_api_key"] = api_key_secretstr
|
||||
|
||||
api_key_str = api_key_secretstr.get_secret_value()
|
||||
else:
|
||||
api_key_str = None
|
||||
api_key_str = values["voyage_api_key"].get_secret_value()
|
||||
values["_client"] = voyageai.Client(api_key=api_key_str)
|
||||
values["_aclient"] = voyageai.client_async.AsyncClient(api_key=api_key_str)
|
||||
return values
|
||||
|
@ -9,6 +9,17 @@ MODEL = "voyage-2"
|
||||
|
||||
def test_initialization_voyage_2() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model=MODEL)
|
||||
assert isinstance(emb, Embeddings)
|
||||
assert emb.batch_size == 72
|
||||
assert emb.model == MODEL
|
||||
assert emb._client is not None
|
||||
|
||||
|
||||
def test_initialization_voyage_2_with_full_api_key_name() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
# Testing that we can initialize the model using `voyage_api_key`
|
||||
# instead of `api_key`
|
||||
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model=MODEL)
|
||||
assert isinstance(emb, Embeddings)
|
||||
assert emb.batch_size == 72
|
||||
@ -18,7 +29,7 @@ def test_initialization_voyage_2() -> None:
|
||||
|
||||
def test_initialization_voyage_1() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
emb = VoyageAIEmbeddings(voyage_api_key="NOT_A_VALID_KEY", model="voyage-01")
|
||||
emb = VoyageAIEmbeddings(api_key="NOT_A_VALID_KEY", model="voyage-01")
|
||||
assert isinstance(emb, Embeddings)
|
||||
assert emb.batch_size == 7
|
||||
assert emb.model == "voyage-01"
|
||||
@ -28,7 +39,7 @@ def test_initialization_voyage_1() -> None:
|
||||
def test_initialization_voyage_1_batch_size() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
emb = VoyageAIEmbeddings(
|
||||
voyage_api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
|
||||
api_key="NOT_A_VALID_KEY", model="voyage-01", batch_size=15
|
||||
)
|
||||
assert isinstance(emb, Embeddings)
|
||||
assert emb.batch_size == 15
|
||||
|
Loading…
Reference in New Issue
Block a user