From 3cc1da2b38d3c468d82feb0920e4f8b4ddf2cfe2 Mon Sep 17 00:00:00 2001 From: chyroc Date: Wed, 27 Dec 2023 04:57:37 +0800 Subject: [PATCH] Refactor: use SecretStr for Petals llms (#15121) --- .../langchain_community/llms/petals.py | 12 ++++++------ .../integration_tests/llms/test_petals.py | 18 ++++++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/llms/petals.py b/libs/community/langchain_community/llms/petals.py index 1508d9d303..9112d18a1d 100644 --- a/libs/community/langchain_community/llms/petals.py +++ b/libs/community/langchain_community/llms/petals.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.llms import LLM -from langchain_core.pydantic_v1 import Extra, Field, root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from langchain_community.llms.utils import enforce_stop_tokens @@ -60,7 +60,7 @@ class Petals(LLM): """Holds any model parameters valid for `create` call not explicitly specified.""" - huggingface_api_key: Optional[str] = None + huggingface_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic config.""" @@ -89,8 +89,8 @@ class Petals(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - huggingface_api_key = get_from_dict_or_env( - values, "huggingface_api_key", "HUGGINGFACE_API_KEY" + huggingface_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "huggingface_api_key", "HUGGINGFACE_API_KEY") ) try: from petals import AutoDistributedModelForCausalLM @@ -101,7 +101,7 @@ class Petals(LLM): values["client"] = AutoDistributedModelForCausalLM.from_pretrained( model_name ) - values["huggingface_api_key"] = huggingface_api_key + values["huggingface_api_key"] = huggingface_api_key.get_secret_value() except ImportError: raise ImportError( diff --git a/libs/community/tests/integration_tests/llms/test_petals.py b/libs/community/tests/integration_tests/llms/test_petals.py index 97e86dec8b..774b56fdef 100644 --- a/libs/community/tests/integration_tests/llms/test_petals.py +++ b/libs/community/tests/integration_tests/llms/test_petals.py @@ -1,8 +1,26 @@ """Test Petals API wrapper.""" +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + from langchain_community.llms.petals import Petals +def test_api_key_is_string() -> None: + llm = Petals(huggingface_api_key="secret-api-key") + assert isinstance(llm.huggingface_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + llm = Petals(huggingface_api_key="secret-api-key") + print(llm.huggingface_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + def test_gooseai_call() -> None: """Test valid call to gooseai.""" llm = Petals(max_new_tokens=10)