Mask API key for Minimax LLM (#14309)

- **Description:** Added masking for the API key for Minimax LLM + tests
inspired by https://github.com/langchain-ai/langchain/pull/12418.
- **Issue:** the issue # fixes
https://github.com/langchain-ai/langchain/issues/12165
- **Dependencies:** this fix is dependent on Minimax instantiation fix
which is introduced in
https://github.com/langchain-ai/langchain/pull/13439, so merge this one
after.
  - **Tag maintainer:** @eyurtsev

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Ran 2023-12-06 01:42:00 +02:00 committed by GitHub
parent 29e993a5f2
commit d22c13ec48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 16 deletions

View File

@ -10,14 +10,14 @@ from typing import (
)
import requests
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ class _MinimaxEndpointClient(BaseModel):
host: str
group_id: str
api_key: str
api_key: SecretStr
api_url: str
@root_validator(pre=True, allow_reuse=True)
@ -40,7 +40,7 @@ class _MinimaxEndpointClient(BaseModel):
return values
def post(self, request: Any) -> Any:
headers = {"Authorization": f"Bearer {self.api_key}"}
headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
response = requests.post(self.api_url, headers=headers, json=request)
# TODO: error handling and automatic retries
if not response.ok:
@ -56,7 +56,7 @@ class _MinimaxEndpointClient(BaseModel):
class MinimaxCommon(BaseModel):
"""Common parameters for Minimax large language models."""
_client: Any = None
_client: _MinimaxEndpointClient
model: str = "abab5.5-chat"
"""Model name to use."""
max_tokens: int = 256
@ -69,13 +69,13 @@ class MinimaxCommon(BaseModel):
"""Holds any model parameters valid for `create` call not explicitly specified."""
minimax_api_host: Optional[str] = None
minimax_group_id: Optional[str] = None
minimax_api_key: Optional[str] = None
minimax_api_key: Optional[SecretStr] = None
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["minimax_api_key"] = get_from_dict_or_env(
values, "minimax_api_key", "MINIMAX_API_KEY"
values["minimax_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
)
values["minimax_group_id"] = get_from_dict_or_env(
values, "minimax_group_id", "MINIMAX_GROUP_ID"
@ -87,6 +87,11 @@ class MinimaxCommon(BaseModel):
"MINIMAX_API_HOST",
default="https://api.minimax.chat",
)
values["_client"] = _MinimaxEndpointClient(
host=values["minimax_api_host"],
api_key=values["minimax_api_key"],
group_id=values["minimax_group_id"],
)
return values
@property
@ -110,14 +115,6 @@ class MinimaxCommon(BaseModel):
"""Return type of llm."""
return "minimax"
def __init__(self, **data: Any):
super().__init__(**data)
self._client = _MinimaxEndpointClient(
host=self.minimax_api_host,
api_key=self.minimax_api_key,
group_id=self.minimax_group_id,
)
class Minimax(MinimaxCommon, LLM):
"""Wrapper around Minimax large language models.

View File

@ -0,0 +1,42 @@
"""Test Minimax llm"""
from typing import cast
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.minimax import Minimax
def test_api_key_is_secret_string() -> None:
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
assert isinstance(llm.minimax_api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("MINIMAX_API_KEY", "secret-api-key")
monkeypatch.setenv("MINIMAX_GROUP_ID", "group_id")
llm = Minimax()
print(llm.minimax_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
print(llm.minimax_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key"