From b257e6a4e8160d1326ce16877edcbd2d4f0a6864 Mon Sep 17 00:00:00 2001 From: Anirudh Gautam <55285536+gautamanirudh@users.noreply.github.com> Date: Mon, 30 Oct 2023 03:23:41 +0530 Subject: [PATCH] Mask API key for AI21 LLM (#12418) - **Description:** Added masking of the API Key for AI21 LLM when printed and improved the docstring for AI21 LLM. - Updated the AI21 LLM to utilize SecretStr from pydantic to securely manage API key. - Made improvements in the docstring of AI21 LLM. It now mentions that the API key can also be passed as a named parameter to the constructor. - Added unit tests. - **Issue:** #12165 - **Tag maintainer:** @eyurtsev --------- Co-authored-by: Anirudh Gautam --- libs/langchain/langchain/llms/ai21.py | 19 +++++---- libs/langchain/langchain/utils/__init__.py | 2 + libs/langchain/langchain/utils/utils.py | 11 ++++- .../tests/unit_tests/llms/test_ai21.py | 41 +++++++++++++++++++ 4 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/llms/test_ai21.py diff --git a/libs/langchain/langchain/llms/ai21.py b/libs/langchain/langchain/llms/ai21.py index dc079c2010..cec4983224 100644 --- a/libs/langchain/langchain/llms/ai21.py +++ b/libs/langchain/langchain/llms/ai21.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast import requests from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM -from langchain.pydantic_v1 import BaseModel, Extra, root_validator -from langchain.utils import get_from_dict_or_env +from langchain.pydantic_v1 import BaseModel, Extra, SecretStr, root_validator +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class AI21PenaltyData(BaseModel): @@ -23,13 +23,13 @@ class AI21(LLM): """AI21 large language models. To use, you should have the environment variable ``AI21_API_KEY`` - set with your API key. + set with your API key or pass it as a named parameter to the constructor. Example: .. code-block:: python from langchain.llms import AI21 - ai21 = AI21(model="j2-jumbo-instruct") + ai21 = AI21(ai21_api_key="my-api-key", model="j2-jumbo-instruct") """ model: str = "j2-jumbo-instruct" @@ -62,7 +62,7 @@ class AI21(LLM): logitBias: Optional[Dict[str, float]] = None """Adjust the probability of specific tokens being generated.""" - ai21_api_key: Optional[str] = None + ai21_api_key: Optional[SecretStr] = None stop: Optional[List[str]] = None @@ -77,7 +77,9 @@ class AI21(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" - ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") + ai21_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") + ) values["ai21_api_key"] = ai21_api_key return values @@ -141,9 +143,10 @@ class AI21(LLM): else: base_url = "https://api.ai21.com/studio/v1" params = {**self._default_params, **kwargs} + self.ai21_api_key = cast(SecretStr, self.ai21_api_key) response = requests.post( url=f"{base_url}/{self.model}/complete", - headers={"Authorization": f"Bearer {self.ai21_api_key}"}, + headers={"Authorization": f"Bearer {self.ai21_api_key.get_secret_value()}"}, json={"prompt": prompt, "stopSequences": stop, **params}, ) if response.status_code != 200: diff --git a/libs/langchain/langchain/utils/__init__.py b/libs/langchain/langchain/utils/__init__.py index 61be7e8706..7a3a7b759d 100644 --- a/libs/langchain/langchain/utils/__init__.py +++ b/libs/langchain/langchain/utils/__init__.py @@ -16,6 +16,7 @@ from langchain.utils.math import cosine_similarity, cosine_similarity_top_k from langchain.utils.strings import comma_list, stringify_dict, stringify_value from langchain.utils.utils import ( check_package_version, + convert_to_secret_str, get_pydantic_field_names, guard_import, mock_now, @@ -27,6 +28,7 @@ __all__ = [ "StrictFormatter", "check_package_version", "comma_list", + "convert_to_secret_str", "cosine_similarity", "cosine_similarity_top_k", "formatter", diff --git a/libs/langchain/langchain/utils/utils.py b/libs/langchain/langchain/utils/utils.py index 26533514a6..ece5f6aa1b 100644 --- a/libs/langchain/langchain/utils/utils.py +++ b/libs/langchain/langchain/utils/utils.py @@ -5,11 +5,13 @@ import functools import importlib import warnings from importlib.metadata import version -from typing import Any, Callable, Dict, Optional, Set, Tuple +from typing import Any, Callable, Dict, Optional, Set, Tuple, Union from packaging.version import parse from requests import HTTPError, Response +from langchain.pydantic_v1 import SecretStr + def xor_args(*arg_groups: Tuple[str, ...]) -> Callable: """Validate specified keyword args are mutually exclusive.""" @@ -169,3 +171,10 @@ def build_extra_kwargs( ) return extra_kwargs + + +def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) diff --git a/libs/langchain/tests/unit_tests/llms/test_ai21.py b/libs/langchain/tests/unit_tests/llms/test_ai21.py new file mode 100644 index 0000000000..87df10ea51 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_ai21.py @@ -0,0 +1,41 @@ +"""Test AI21 llm""" +from typing import cast + +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.ai21 import AI21 +from langchain.pydantic_v1 import SecretStr + + +def test_api_key_is_secret_string() -> None: + llm = AI21(ai21_api_key="secret-api-key") + assert isinstance(llm.ai21_api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("AI21_API_KEY", "secret-api-key") + llm = AI21() + print(llm.ai21_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = AI21(ai21_api_key="secret-api-key") + print(llm.ai21_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = AI21(ai21_api_key="secret-api-key") + assert cast(SecretStr, llm.ai21_api_key).get_secret_value() == "secret-api-key"