Mask API key for AI21 LLM (#12418)

- **Description:** Added masking of the API Key for AI21 LLM when
printed and improved the docstring for AI21 LLM.
- Updated the AI21 LLM to utilize SecretStr from pydantic to securely
manage API key.
- Made improvements in the docstring of AI21 LLM. It now mentions that
the API key can also be passed as a named parameter to the constructor.
    - Added unit tests.
  - **Issue:** #12165 
  - **Tag maintainer:** @eyurtsev

---------

Co-authored-by: Anirudh Gautam <anirudh@Anirudhs-Mac-mini.local>
pull/12530/head
Anirudh Gautam 8 months ago committed by GitHub
parent 35d726dc15
commit b257e6a4e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,11 +1,11 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import BaseModel, Extra, root_validator
from langchain.utils import get_from_dict_or_env
from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class AI21PenaltyData(BaseModel):
@ -23,13 +23,13 @@ class AI21(LLM):
"""AI21 large language models.
To use, you should have the environment variable ``AI21_API_KEY``
set with your API key.
set with your API key or pass it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms import AI21
ai21 = AI21(model="j2-jumbo-instruct")
ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct")
"""
model: str = "j2-jumbo-instruct"
@ -62,7 +62,7 @@ class AI21(LLM):
logitBias: Optional[Dict[str, float]] = None
"""Adjust the probability of specific tokens being generated."""
ai21_api_key: Optional[str] = None
ai21_api_key: Optional[SecretStr] = None
stop: Optional[List[str]] = None
@ -77,7 +77,9 @@ class AI21(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
ai21_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
)
values["ai21_api_key"] = ai21_api_key
return values
@ -141,9 +143,10 @@ class AI21(LLM):
else:
base_url = "https://api.ai21.com/studio/v1"
params = {**self._default_params, **kwargs}
self.ai21_api_key = cast(SecretStr, self.ai21_api_key)
response = requests.post(
url=f"{base_url}/{self.model}/complete",
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"},
json={"prompt": prompt, "stopSequences": stop, **params},
)
if response.status_code != 200:

@ -16,6 +16,7 @@ from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
check_package_version,
convert_to_secret_str,
get_pydantic_field_names,
guard_import,
mock_now,
@ -27,6 +28,7 @@ __all__ = [
"StrictFormatter",
"check_package_version",
"comma_list",
"convert_to_secret_str",
"cosine_similarity",
"cosine_similarity_top_k",
"formatter",

@ -5,11 +5,13 @@ import functools
import importlib
import warnings
from importlib.metadata import version
from typing import Any, Callable, Dict, Optional, Set, Tuple
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from packaging.version import parse
from requests import HTTPError, Response
from langchain.pydantic_v1 import SecretStr
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
"""Validate specified keyword args are mutually exclusive."""
@ -169,3 +171,10 @@ def build_extra_kwargs(
)
return extra_kwargs
def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)

@ -0,0 +1,41 @@
"""Test AI21 llm"""
from typing import cast
from pytest import CaptureFixture, MonkeyPatch
from langchain.llms.ai21 import AI21
from langchain.pydantic_v1 import SecretStr
def test_api_key_is_secret_string() -> None:
llm = AI21(ai21_api_key="secret-api-key")
assert isinstance(llm.ai21_api_key, SecretStr)
def test_api_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("AI21_API_KEY", "secret-api-key")
llm = AI21()
print(llm.ai21_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_api_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
llm = AI21(ai21_api_key="secret-api-key")
print(llm.ai21_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"
def test_uses_actual_secret_value_from_secretstr() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
llm = AI21(ai21_api_key="secret-api-key")
assert cast(SecretStr, llm.ai21_api_key).get_secret_value() == "secret-api-key"
Loading…
Cancel
Save