From 4f75b230ed6098bb4652abf501ec7418544ef2b4 Mon Sep 17 00:00:00 2001 From: aditya thomas Date: Sat, 13 Apr 2024 01:49:31 +0530 Subject: [PATCH] partner[ai21]: masking of the api key for ai21 models (#20257) **Description:** Masking of the API key for AI21 models **Issue:** Fixes #12165 for AI21 **Dependencies:** None Note: This fix came in originally through #12418 but was possibly missed in the refactor to the AI21 partner package --------- Co-authored-by: Erick Friis --- .../ai21/tests/unit_tests/test_chat_models.py | 38 ++++++++++++++++++- .../ai21/tests/unit_tests/test_llms.py | 38 +++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/libs/partners/ai21/tests/unit_tests/test_chat_models.py b/libs/partners/ai21/tests/unit_tests/test_chat_models.py index f95c73db90..499d21bd6d 100644 --- a/libs/partners/ai21/tests/unit_tests/test_chat_models.py +++ b/libs/partners/ai21/tests/unit_tests/test_chat_models.py @@ -1,5 +1,5 @@ """Test chat model integration.""" -from typing import List, Optional +from typing import List, Optional, cast from unittest.mock import Mock, call import pytest @@ -14,6 +14,8 @@ from langchain_core.messages import ( from langchain_core.messages import ( ChatMessage as LangChainChatMessage, ) +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain_ai21.chat_models import ( ChatAI21, @@ -236,3 +238,37 @@ def test_generate(mock_client_with_chat: Mock) -> None: ), ] ) + + +def test_api_key_is_secret_string() -> None: + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + assert isinstance(llm.api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("AI21_API_KEY", "secret-api-key") + llm = ChatAI21(model="j2-ultra") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = ChatAI21(model="j2-ultra", api_key="secret-api-key") + assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key" diff --git a/libs/partners/ai21/tests/unit_tests/test_llms.py b/libs/partners/ai21/tests/unit_tests/test_llms.py index 2c47ec234a..1854df0e77 100644 --- a/libs/partners/ai21/tests/unit_tests/test_llms.py +++ b/libs/partners/ai21/tests/unit_tests/test_llms.py @@ -1,4 +1,6 @@ """Test AI21 Chat API wrapper.""" + +from typing import cast from unittest.mock import Mock, call import pytest @@ -6,6 +8,8 @@ from ai21 import MissingApiKeyError from ai21.models import ( Penalty, ) +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch from langchain_ai21 import AI21LLM from tests.unit_tests.conftest import ( @@ -106,3 +110,37 @@ def test_generate(mock_client_with_completion: Mock) -> None: ), ] ) + + +def test_api_key_is_secret_string() -> None: + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + assert isinstance(llm.api_key, SecretStr) + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("AI21_API_KEY", "secret-api-key") + llm = AI21LLM(model="j2-ultra") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_api_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + print(llm.api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secretstr() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + llm = AI21LLM(model="j2-ultra", api_key="secret-api-key") + assert cast(SecretStr, llm.api_key).get_secret_value() == "secret-api-key"