diff --git a/libs/community/langchain_community/embeddings/edenai.py b/libs/community/langchain_community/embeddings/edenai.py index 9d12376fc2..3d6b1ec16d 100644 --- a/libs/community/langchain_community/embeddings/edenai.py +++ b/libs/community/langchain_community/embeddings/edenai.py @@ -1,8 +1,14 @@ from typing import Any, Dict, List, Optional from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import ( + BaseModel, + Extra, + Field, + SecretStr, + root_validator, +) +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.utilities.requests import Requests @@ -13,7 +19,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings): it as a named parameter. """ - edenai_api_key: Optional[str] = Field(None, description="EdenAI API Token") + edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token") provider: str = "openai" """embedding provider to use (eg: openai,google etc.)""" @@ -32,8 +38,8 @@ class EdenAiEmbeddings(BaseModel, Embeddings): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - values["edenai_api_key"] = get_from_dict_or_env( - values, "edenai_api_key", "EDENAI_API_KEY" + values["edenai_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY") ) return values @@ -50,7 +56,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings): headers = { "accept": "application/json", "content-type": "application/json", - "authorization": f"Bearer {self.edenai_api_key}", + "authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", "User-Agent": self.get_user_agent(), } diff --git a/libs/community/tests/unit_tests/embeddings/test_edenai.py b/libs/community/tests/unit_tests/embeddings/test_edenai.py new file mode 100644 index 0000000000..6616f3ef13 --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_edenai.py @@ -0,0 +1,21 @@ +"""Test EdenAiEmbeddings embeddings""" + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_community.embeddings import EdenAiEmbeddings + + +def test_api_key_is_string() -> None: + llm = EdenAiEmbeddings(edenai_api_key="secret-api-key") + assert isinstance(llm.edenai_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = EdenAiEmbeddings(edenai_api_key="secret-api-key") + print(llm.edenai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********"