mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Mask API key for Minimax LLM (#14309)
- **Description:** Added masking for the API key for Minimax LLM + tests inspired by https://github.com/langchain-ai/langchain/pull/12418. - **Issue:** the issue # fixes https://github.com/langchain-ai/langchain/issues/12165 - **Dependencies:** this fix is dependent on Minimax instantiation fix which is introduced in https://github.com/langchain-ai/langchain/pull/13439, so merge this one after. - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
29e993a5f2
commit
d22c13ec48
@ -10,14 +10,14 @@ from typing import (
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -27,7 +27,7 @@ class _MinimaxEndpointClient(BaseModel):
|
||||
|
||||
host: str
|
||||
group_id: str
|
||||
api_key: str
|
||||
api_key: SecretStr
|
||||
api_url: str
|
||||
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
@ -40,7 +40,7 @@ class _MinimaxEndpointClient(BaseModel):
|
||||
return values
|
||||
|
||||
def post(self, request: Any) -> Any:
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
headers = {"Authorization": f"Bearer {self.api_key.get_secret_value()}"}
|
||||
response = requests.post(self.api_url, headers=headers, json=request)
|
||||
# TODO: error handling and automatic retries
|
||||
if not response.ok:
|
||||
@ -56,7 +56,7 @@ class _MinimaxEndpointClient(BaseModel):
|
||||
class MinimaxCommon(BaseModel):
|
||||
"""Common parameters for Minimax large language models."""
|
||||
|
||||
_client: Any = None
|
||||
_client: _MinimaxEndpointClient
|
||||
model: str = "abab5.5-chat"
|
||||
"""Model name to use."""
|
||||
max_tokens: int = 256
|
||||
@ -69,13 +69,13 @@ class MinimaxCommon(BaseModel):
|
||||
"""Holds any model parameters valid for `create` call not explicitly specified."""
|
||||
minimax_api_host: Optional[str] = None
|
||||
minimax_group_id: Optional[str] = None
|
||||
minimax_api_key: Optional[str] = None
|
||||
minimax_api_key: Optional[SecretStr] = None
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["minimax_api_key"] = get_from_dict_or_env(
|
||||
values, "minimax_api_key", "MINIMAX_API_KEY"
|
||||
values["minimax_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "minimax_api_key", "MINIMAX_API_KEY")
|
||||
)
|
||||
values["minimax_group_id"] = get_from_dict_or_env(
|
||||
values, "minimax_group_id", "MINIMAX_GROUP_ID"
|
||||
@ -87,6 +87,11 @@ class MinimaxCommon(BaseModel):
|
||||
"MINIMAX_API_HOST",
|
||||
default="https://api.minimax.chat",
|
||||
)
|
||||
values["_client"] = _MinimaxEndpointClient(
|
||||
host=values["minimax_api_host"],
|
||||
api_key=values["minimax_api_key"],
|
||||
group_id=values["minimax_group_id"],
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
@ -110,14 +115,6 @@ class MinimaxCommon(BaseModel):
|
||||
"""Return type of llm."""
|
||||
return "minimax"
|
||||
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
self._client = _MinimaxEndpointClient(
|
||||
host=self.minimax_api_host,
|
||||
api_key=self.minimax_api_key,
|
||||
group_id=self.minimax_group_id,
|
||||
)
|
||||
|
||||
|
||||
class Minimax(MinimaxCommon, LLM):
|
||||
"""Wrapper around Minimax large language models.
|
||||
|
42
libs/langchain/tests/unit_tests/llms/test_minimax.py
Normal file
42
libs/langchain/tests/unit_tests/llms/test_minimax.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Test Minimax llm"""
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.minimax import Minimax
|
||||
|
||||
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||
assert isinstance(llm.minimax_api_key, SecretStr)
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "secret-api-key")
|
||||
monkeypatch.setenv("MINIMAX_GROUP_ID", "group_id")
|
||||
llm = Minimax()
|
||||
print(llm.minimax_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||
print(llm.minimax_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
def test_uses_actual_secret_value_from_secretstr() -> None:
|
||||
"""Test that actual secret is retrieved using `.get_secret_value()`."""
|
||||
llm = Minimax(minimax_api_key="secret-api-key", minimax_group_id="group_id")
|
||||
assert cast(SecretStr, llm.minimax_api_key).get_secret_value() == "secret-api-key"
|
Loading…
Reference in New Issue
Block a user