feat: mask api key for cerebriumai llm (#14272)

- **Description:** Masking API key for CerebriumAI LLM to protect user
secrets.
 - **Issue:** #12165 
 - **Dependencies:** None
 - **Tag maintainer:** @eyurtsev

---------

Signed-off-by: Yuchen Liang <yuchenl3@andrew.cmu.edu>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
pull/14361/head
Yuchen Liang 6 months ago committed by GitHub
parent d4d64daa1e
commit ad6dfb6220
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,13 +1,13 @@
import logging
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, cast
import requests
from langchain_core.pydantic_v1 import Extra, Field, root_validator
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -15,8 +15,9 @@ logger = logging.getLogger(__name__)
class CerebriumAI(LLM):
"""CerebriumAI large language models.
To use, you should have the ``cerebrium`` python package installed, and the
environment variable ``CEREBRIUMAI_API_KEY`` set with your API key.
To use, you should have the ``cerebrium`` python package installed.
You should also have the environment variable ``CEREBRIUMAI_API_KEY``
set with your API key or pass it as a named argument in the constructor.
Any parameters that are valid to be passed to the call can be passed
in, even if not explicitly saved on this class.
@ -25,7 +26,7 @@ class CerebriumAI(LLM):
.. code-block:: python
from langchain.llms import CerebriumAI
cerebrium = CerebriumAI(endpoint_url="")
cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key")
"""
@ -36,7 +37,7 @@ class CerebriumAI(LLM):
"""Holds any model parameters valid for `create` call not
explicitly specified."""
cerebriumai_api_key: Optional[str] = None
cerebriumai_api_key: Optional[SecretStr] = None
class Config:
"""Configuration for this pydantic config."""
@ -64,8 +65,8 @@ class CerebriumAI(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
cerebriumai_api_key = get_from_dict_or_env(
values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY"
cerebriumai_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY")
)
values["cerebriumai_api_key"] = cerebriumai_api_key
return values
@ -91,7 +92,9 @@ class CerebriumAI(LLM):
**kwargs: Any,
) -> str:
headers: Dict = {
"Authorization": self.cerebriumai_api_key,
"Authorization": cast(
SecretStr, self.cerebriumai_api_key
).get_secret_value(),
"Content-Type": "application/json",
}
params = self.model_kwargs or {}

@ -0,0 +1,33 @@
"""Test CerebriumAI llm"""
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.cerebriumai import CerebriumAI
def test_api_key_is_secret_string() -> None:
llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key")
assert isinstance(llm.cerebriumai_api_key, SecretStr)
def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
llm = CerebriumAI(cerebriumai_api_key="secret-api-key")
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key")
llm = CerebriumAI()
print(llm.cerebriumai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"
Loading…
Cancel
Save