diff --git a/libs/community/langchain_community/llms/bananadev.py b/libs/community/langchain_community/llms/bananadev.py index 88ab7f5e58..43ee44b83d 100644 --- a/libs/community/langchain_community/llms/bananadev.py +++ b/libs/community/langchain_community/llms/bananadev.py @@ -1,10 +1,10 @@ import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, cast from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.llms.utils import enforce_stop_tokens @@ -38,7 +38,7 @@ class Banana(LLM): """Holds any model parameters valid for `create` call not explicitly specified.""" - banana_api_key: Optional[str] = None + banana_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic config.""" @@ -66,8 +66,8 @@ class Banana(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - banana_api_key = get_from_dict_or_env( - values, "banana_api_key", "BANANA_API_KEY" + banana_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "banana_api_key", "BANANA_API_KEY") ) values["banana_api_key"] = banana_api_key return values @@ -103,7 +103,7 @@ class Banana(LLM): ) params = self.model_kwargs or {} params = {**params, **kwargs} - api_key = self.banana_api_key + api_key = cast(SecretStr, self.banana_api_key) model_key = self.model_key model_url_slug = self.model_url_slug model_inputs = { @@ -113,7 +113,7 @@ class Banana(LLM): } model = Client( # Found in main dashboard - api_key=api_key, + api_key=api_key.get_secret_value(), # Both found in model details page model_key=model_key, url=f"https://{model_url_slug}.run.banana.dev", diff --git a/libs/community/tests/unit_tests/llms/test_bananadev.py b/libs/community/tests/unit_tests/llms/test_bananadev.py new file mode 100644 index 0000000000..ea219cc815 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_bananadev.py @@ -0,0 +1,41 @@ +"""Test Banana llm""" +from typing import cast + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain_community.llms.bananadev import Banana + + +def test_api_key_is_secret_string() -> None: + llm = Banana(banana_api_key="secret-api-key") + assert isinstance(llm.banana_api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("BANANA_API_KEY", "secret-api-key") + llm = Banana() + print(llm.banana_api_key, end="") # noqa: T201 + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = Banana(banana_api_key="secret-api-key") + print(llm.banana_api_key, end="") # noqa: T201 + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = Banana(banana_api_key="secret-api-key") + assert cast(SecretStr, llm.banana_api_key).get_secret_value() == "secret-api-key"