Mask API key for Anyscale LLM (#12406)

Description: Add masking of API Key for Anyscale LLM when printed.
Issue: #12165 
Dependencies: None
Tag maintainer: @eyurtsev

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
Aidos Kanapyanov 2023-11-01 20:22:26 +06:00 committed by GitHub
parent 5ae51a8a85
commit ae63c186af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 14 deletions

View File

@ -13,9 +13,9 @@ from langchain.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
)
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema.messages import BaseMessage
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
if TYPE_CHECKING:
import tiktoken
@ -53,7 +53,7 @@ class ChatAnyscale(ChatOpenAI):
def lc_secrets(self) -> Dict[str, str]:
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
anyscale_api_key: Optional[str] = None
anyscale_api_key: Optional[SecretStr] = None
"""AnyScale Endpoints API keys."""
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
"""Model name to use."""
@ -98,11 +98,13 @@ class ChatAnyscale(ChatOpenAI):
@root_validator(pre=True)
def validate_environment_override(cls, values: dict) -> dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"anyscale_api_key",
"ANYSCALE_API_KEY",
)
)
values["openai_api_base"] = get_from_dict_or_env(
values,
"anyscale_api_base",
@ -138,7 +140,7 @@ class ChatAnyscale(ChatOpenAI):
model_name = values["model_name"]
available_models = cls.get_available_models(
values["openai_api_key"],
values["openai_api_key"].get_secret_value(),
values["openai_api_base"],
)

View File

@ -9,6 +9,7 @@ from typing import (
Optional,
Set,
Tuple,
cast,
)
from langchain.callbacks.manager import (
@ -20,10 +21,10 @@ from langchain.llms.openai import (
acompletion_with_retry,
completion_with_retry,
)
from langchain.pydantic_v1 import Field, root_validator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
def update_token_usage(
@ -84,7 +85,7 @@ class Anyscale(BaseOpenAI):
"""Key word arguments to pass to the model."""
anyscale_api_base: Optional[str] = None
anyscale_api_key: Optional[str] = None
anyscale_api_key: Optional[SecretStr] = None
prefix_messages: List = Field(default_factory=list)
@ -94,9 +95,10 @@ class Anyscale(BaseOpenAI):
values["anyscale_api_base"] = get_from_dict_or_env(
values, "anyscale_api_base", "ANYSCALE_API_BASE"
)
values["anyscale_api_key"] = get_from_dict_or_env(
values, "anyscale_api_key", "ANYSCALE_API_KEY"
values["anyscale_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
)
try:
import openai
@ -126,7 +128,7 @@ class Anyscale(BaseOpenAI):
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
openai_creds: Dict[str, Any] = {
"api_key": self.anyscale_api_key,
"api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(),
"api_base": self.anyscale_api_base,
}
return {**openai_creds, **{"model": self.model_name}, **super()._default_params}

View File

@ -0,0 +1,41 @@
"""Test Anyscale llm"""
import pytest
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.anyscale import Anyscale
from langchain.pydantic_v1 import SecretStr
@pytest.mark.requires("openai")
def test_api_key_is_secret_string() -> None:
llm = Anyscale(
anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
)
assert isinstance(llm.anyscale_api_key, SecretStr)
@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("ANYSCALE_API_KEY", "secret-api-key")
llm = Anyscale(anyscale_api_base="test", model_name="test")
print(llm.anyscale_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
@pytest.mark.requires("openai")
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = Anyscale(
anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
)
print(llm.anyscale_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"