mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Mask API key for Anyscale LLM (#12406)
Description: Add masking of API Key for Anyscale LLM when printed. Issue: #12165 Dependencies: None Tag maintainer: @eyurtsev --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
parent
5ae51a8a85
commit
ae63c186af
@ -13,9 +13,9 @@ from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_import_tiktoken,
|
||||
)
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import tiktoken
|
||||
@ -53,7 +53,7 @@ class ChatAnyscale(ChatOpenAI):
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"anyscale_api_key": "ANYSCALE_API_KEY"}
|
||||
|
||||
anyscale_api_key: Optional[str] = None
|
||||
anyscale_api_key: Optional[SecretStr] = None
|
||||
"""AnyScale Endpoints API keys."""
|
||||
model_name: str = Field(default=DEFAULT_MODEL, alias="model")
|
||||
"""Model name to use."""
|
||||
@ -98,11 +98,13 @@ class ChatAnyscale(ChatOpenAI):
|
||||
@root_validator(pre=True)
|
||||
def validate_environment_override(cls, values: dict) -> dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["openai_api_key"] = get_from_dict_or_env(
|
||||
values["openai_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_key",
|
||||
"ANYSCALE_API_KEY",
|
||||
)
|
||||
)
|
||||
values["openai_api_base"] = get_from_dict_or_env(
|
||||
values,
|
||||
"anyscale_api_base",
|
||||
@ -138,7 +140,7 @@ class ChatAnyscale(ChatOpenAI):
|
||||
model_name = values["model_name"]
|
||||
|
||||
available_models = cls.get_available_models(
|
||||
values["openai_api_key"],
|
||||
values["openai_api_key"].get_secret_value(),
|
||||
values["openai_api_base"],
|
||||
)
|
||||
|
||||
|
@ -9,6 +9,7 @@ from typing import (
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@ -20,10 +21,10 @@ from langchain.llms.openai import (
|
||||
acompletion_with_retry,
|
||||
completion_with_retry,
|
||||
)
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain.schema import Generation, LLMResult
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
def update_token_usage(
|
||||
@ -84,7 +85,7 @@ class Anyscale(BaseOpenAI):
|
||||
|
||||
"""Key word arguments to pass to the model."""
|
||||
anyscale_api_base: Optional[str] = None
|
||||
anyscale_api_key: Optional[str] = None
|
||||
anyscale_api_key: Optional[SecretStr] = None
|
||||
|
||||
prefix_messages: List = Field(default_factory=list)
|
||||
|
||||
@ -94,9 +95,10 @@ class Anyscale(BaseOpenAI):
|
||||
values["anyscale_api_base"] = get_from_dict_or_env(
|
||||
values, "anyscale_api_base", "ANYSCALE_API_BASE"
|
||||
)
|
||||
values["anyscale_api_key"] = get_from_dict_or_env(
|
||||
values, "anyscale_api_key", "ANYSCALE_API_KEY"
|
||||
values["anyscale_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
|
||||
)
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
@ -126,7 +128,7 @@ class Anyscale(BaseOpenAI):
|
||||
def _invocation_params(self) -> Dict[str, Any]:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
openai_creds: Dict[str, Any] = {
|
||||
"api_key": self.anyscale_api_key,
|
||||
"api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(),
|
||||
"api_base": self.anyscale_api_base,
|
||||
}
|
||||
return {**openai_creds, **{"model": self.model_name}, **super()._default_params}
|
||||
|
41
libs/langchain/tests/unit_tests/llms/test_anyscale.py
Normal file
41
libs/langchain/tests/unit_tests/llms/test_anyscale.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Test Anyscale llm"""
|
||||
import pytest
|
||||
from pytest import CaptureFixture, MonkeyPatch
|
||||
|
||||
from langchain.llms.anyscale import Anyscale
|
||||
from langchain.pydantic_v1 import SecretStr
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_api_key_is_secret_string() -> None:
|
||||
llm = Anyscale(
|
||||
anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
|
||||
)
|
||||
assert isinstance(llm.anyscale_api_key, SecretStr)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_api_key_masked_when_passed_from_env(
|
||||
monkeypatch: MonkeyPatch, capsys: CaptureFixture
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via an env variable"""
|
||||
monkeypatch.setenv("ANYSCALE_API_KEY", "secret-api-key")
|
||||
llm = Anyscale(anyscale_api_base="test", model_name="test")
|
||||
print(llm.anyscale_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_api_key_masked_when_passed_via_constructor(
|
||||
capsys: CaptureFixture,
|
||||
) -> None:
|
||||
"""Test initialization with an API key provided via the initializer"""
|
||||
llm = Anyscale(
|
||||
anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
|
||||
)
|
||||
print(llm.anyscale_api_key, end="")
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert captured.out == "**********"
|
Loading…
Reference in New Issue
Block a user