From 8e0dcb37d2ffe01b83ed58f1028dab449a258379 Mon Sep 17 00:00:00 2001 From: Praveen Venkateswaran Date: Mon, 6 Nov 2023 14:13:59 -0500 Subject: [PATCH] Add SecretStr for Symbl.ai Nebula API (#12896) Description: This PR masks API key secrets for the Nebula model from Symbl.ai Issue: #12165 Maintainer: @eyurtsev --------- Co-authored-by: Praveen Venkateswaran --- .../langchain/llms/symblai_nebula.py | 16 ++++++---- .../unit_tests/llms/test_symblai_nebula.py | 29 +++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py diff --git a/libs/langchain/langchain/llms/symblai_nebula.py b/libs/langchain/langchain/llms/symblai_nebula.py index 8d33e1a42c..cefedd4cc6 100644 --- a/libs/langchain/langchain/llms/symblai_nebula.py +++ b/libs/langchain/langchain/llms/symblai_nebula.py @@ -15,8 +15,9 @@ from tenacity import ( from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, root_validator -from langchain.utils import get_from_dict_or_env +from langchain.pydantic_v1 import Extra, SecretStr, root_validator +from langchain.utils import convert_to_secret_str +from langchain.utils.env import get_from_dict_or_env DEFAULT_NEBULA_SERVICE_URL = "https://api-nebula.symbl.ai" DEFAULT_NEBULA_SERVICE_PATH = "/v1/model/generate" @@ -50,7 +51,7 @@ class Nebula(LLM): nebula_service_url: Optional[str] = None nebula_service_path: Optional[str] = None - nebula_api_key: Optional[str] = None + nebula_api_key: Optional[SecretStr] = None model: Optional[str] = None max_new_tokens: Optional[int] = 128 temperature: Optional[float] = 0.6 @@ -81,8 +82,8 @@ class Nebula(LLM): "NEBULA_SERVICE_PATH", DEFAULT_NEBULA_SERVICE_PATH, ) - nebula_api_key = get_from_dict_or_env( - values, "nebula_api_key", "NEBULA_API_KEY", None + nebula_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "nebula_api_key", "NEBULA_API_KEY", None) ) if nebula_service_url.endswith("/"): @@ -187,9 +188,12 @@ def make_request( ) -> Any: """Generate text from the model.""" params = params or {} + api_key = None + if self.nebula_api_key is not None: + api_key = self.nebula_api_key.get_secret_value() headers = { "Content-Type": "application/json", - "ApiKey": f"{self.nebula_api_key}", + "ApiKey": f"{api_key}", } body = { diff --git a/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py b/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py new file mode 100644 index 0000000000..ab53377a77 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_symblai_nebula.py @@ -0,0 +1,29 @@ +"""Test the Nebula model by Symbl.ai""" + +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.symblai_nebula import Nebula +from langchain.pydantic_v1 import SecretStr + + +def test_api_key_is_secret_string() -> None: + llm = Nebula(nebula_api_key="secret-api-key") + assert isinstance(llm.nebula_api_key, SecretStr) + assert llm.nebula_api_key.get_secret_value() == "secret-api-key" + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + monkeypatch.setenv("NEBULA_API_KEY", "secret-api-key") + llm = Nebula() + print(llm.nebula_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: + llm = Nebula(nebula_api_key="secret-api-key") + print(llm.nebula_api_key, end="") + captured = capsys.readouterr() + assert captured.out == "**********"