mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
pinecone[patch]: Upgrade @root_validators to be consistent with pydantic 2 (#25453)
Upgrade root validators for pydantic 2 migration
This commit is contained in:
parent
b297af5482
commit
34da8be60b
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
@ -10,7 +9,7 @@ from langchain_core.pydantic_v1 import (
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import convert_to_secret_str
|
||||
from langchain_core.utils import secret_from_env
|
||||
from pinecone import Pinecone as PineconeClient # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -45,10 +44,21 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
||||
dimension: Optional[int] = None
|
||||
#
|
||||
show_progress_bar: bool = False
|
||||
pinecone_api_key: Optional[SecretStr] = None
|
||||
pinecone_api_key: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env(
|
||||
"PINECONE_API_KEY",
|
||||
error_message="Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||
"environment variable or pass it via `pinecone_api_key`.",
|
||||
),
|
||||
alias="api_key",
|
||||
)
|
||||
"""Pinecone API key.
|
||||
|
||||
If not provided, will look for the PINECONE_API_KEY environment variable."""
|
||||
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_default_config(cls, values: dict) -> dict:
|
||||
@ -69,25 +79,10 @@ class PineconeEmbeddings(BaseModel, Embeddings):
|
||||
values[key] = value
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=False, skip_on_failure=True)
|
||||
def validate_environment(cls, values: dict) -> dict:
|
||||
"""Validate that Pinecone version and credentials exist in environment."""
|
||||
|
||||
pinecone_api_key = values.get("pinecone_api_key") or os.getenv(
|
||||
"PINECONE_API_KEY", None
|
||||
)
|
||||
if pinecone_api_key:
|
||||
api_key_secretstr = convert_to_secret_str(pinecone_api_key)
|
||||
values["pinecone_api_key"] = api_key_secretstr
|
||||
|
||||
api_key_str = api_key_secretstr.get_secret_value()
|
||||
else:
|
||||
api_key_str = None
|
||||
if api_key_str is None:
|
||||
raise ValueError(
|
||||
"Pinecone API key not found. Please set the PINECONE_API_KEY "
|
||||
"environment variable or pass it via `pinecone_api_key`."
|
||||
)
|
||||
api_key_str = values["pinecone_api_key"].get_secret_value()
|
||||
client = PineconeClient(api_key=api_key_str, source_tag="langchain")
|
||||
values["_client"] = client
|
||||
|
||||
|
@ -7,10 +7,22 @@ MODEL_NAME = "multilingual-e5-large"
|
||||
|
||||
|
||||
def test_default_config() -> None:
|
||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME)
|
||||
e = PineconeEmbeddings(
|
||||
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||
model=MODEL_NAME,
|
||||
)
|
||||
assert e.batch_size == 96
|
||||
|
||||
|
||||
def test_default_config_with_api_key() -> None:
|
||||
e = PineconeEmbeddings(api_key=API_KEY, model=MODEL_NAME)
|
||||
assert e.batch_size == 96
|
||||
|
||||
|
||||
def test_custom_config() -> None:
|
||||
e = PineconeEmbeddings(pinecone_api_key=API_KEY, model=MODEL_NAME, batch_size=128)
|
||||
e = PineconeEmbeddings(
|
||||
pinecone_api_key=API_KEY, # type: ignore[call-arg]
|
||||
model=MODEL_NAME,
|
||||
batch_size=128,
|
||||
)
|
||||
assert e.batch_size == 128
|
||||
|
Loading…
Reference in New Issue
Block a user