pinecone[patch]: Upgrade @root_validators to be consistent with pydantic 2 (#25453)

Upgrade root validators for pydantic 2 migration
This commit is contained in:
Eugene Yurtsev 2024-08-15 15:45:14 -04:00 committed by GitHub
parent b297af5482
commit 34da8be60b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 22 deletions

View File

@ -1,5 +1,4 @@
import logging
import os
from typing import Dict, Iterable, List, Optional
import aiohttp
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
SecretStr,
root_validator,
)
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import secret_from_env
from pinecone import Pinecone as PineconeClient # type: ignore
logger = logging.getLogger(__name__)
@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings):
dimension: Optional[int] = None
#
show_progress_bar: bool = False
pinecone_api_key: Optional[SecretStr] = None
pinecone_api_key: Optional[SecretStr] = Field(
default_factory=secret_from_env(
"PINECONE_API_KEY",
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
"environment variable or pass it via `pinecone_api_key`.",
),
alias="api_key",
)
"""Pinecone API key.
If not provided, will look for the PINECONE_API_KEY environment variable."""
class Config:
extra = "forbid"
allow_population_by_field_name = True
@root_validator(pre=True)
def set_default_config(cls, values: dict) -> dict:
@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings):
values[key] = value
return values
@root_validator()
@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: dict) -> dict:
"""Validate that Pinecone version and credentials exist in environment."""
pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
"PINECONE_API_KEY", None
)
if pinecone_api_key:
api_key_secretstr = convert_to_secret_str(pinecone_api_key)
values["pinecone_api_key"] = api_key_secretstr
api_key_str = api_key_secretstr.get_secret_value()
else:
api_key_str = None
if api_key_str is None:
raise ValueError(
"Pinecone API key not found. Please set the PINECONE_API_KEY "
"environment variable or pass it via `pinecone_api_key`."
)
api_key_str = values["pinecone_api_key"].get_secret_value()
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
values["_client"] = client

View File

@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large"
def test_default_config() -> None:
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME)
e = PineconeEmbeddings(
pinecone_api_key=API_KEY, # type: ignore[call-arg]
model=MODEL_NAME,
)
assert e.batch_size == 96
def test_default_config_with_api_key() -> None:
e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME)
assert e.batch_size == 96
def test_custom_config() -> None:
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128)
e = PineconeEmbeddings(
pinecone_api_key=API_KEY, # type: ignore[call-arg]
model=MODEL_NAME,
batch_size=128,
)
assert e.batch_size == 128