diff --git a/libs/langchain/langchain/chat_models/azureml_endpoint.py b/libs/langchain/langchain/chat_models/azureml_endpoint.py index 53bdc84925..8efa957ad0 100644 --- a/libs/langchain/langchain/chat_models/azureml_endpoint.py +++ b/libs/langchain/langchain/chat_models/azureml_endpoint.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, cast from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel from langchain.llms.azureml_endpoint import AzureMLEndpointClient, ContentFormatterBase -from langchain.pydantic_v1 import validator +from langchain.pydantic_v1 import SecretStr, validator from langchain.schema.messages import ( AIMessage, BaseMessage, @@ -12,7 +12,7 @@ from langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class LlamaContentFormatter(ContentFormatterBase): @@ -94,7 +94,7 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel): """URL of pre-existing Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_URL`.""" - endpoint_api_key: str = "" + endpoint_api_key: SecretStr = convert_to_secret_str("") """Authentication Key for Endpoint. Should be passed to constructor or specified as env var `AZUREML_ENDPOINT_API_KEY`.""" @@ -112,13 +112,15 @@ class AzureMLChatOnlineEndpoint(SimpleChatModel): @classmethod def validate_client(cls, field_value: Any, values: Dict) -> AzureMLEndpointClient: """Validate that api key and python package exist in environment.""" - endpoint_key = get_from_dict_or_env( - values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY" + values["endpoint_api_key"] = convert_to_secret_str( + get_from_dict_or_env(values, "endpoint_api_key", "AZUREML_ENDPOINT_API_KEY") ) endpoint_url = get_from_dict_or_env( values, "endpoint_url", "AZUREML_ENDPOINT_URL" ) - http_client = AzureMLEndpointClient(endpoint_url, endpoint_key) + http_client = AzureMLEndpointClient( + endpoint_url, values["endpoint_api_key"].get_secret_value() + ) return http_client @property diff --git a/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py new file mode 100644 index 0000000000..de1055fc3d --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/test_azureml_endpoint.py @@ -0,0 +1,65 @@ +"""Test AzureML chat endpoint.""" + +import os + +import pytest +from pytest import CaptureFixture, FixtureRequest + +from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint +from langchain.pydantic_v1 import SecretStr + + +@pytest.fixture(scope="class") +def api_passed_via_environment_fixture() -> AzureMLChatOnlineEndpoint: + """Fixture to create an AzureMLChatOnlineEndpoint instance + with API key passed from environment""" + os.environ["AZUREML_ENDPOINT_API_KEY"] = "my-api-key" + azure_chat = AzureMLChatOnlineEndpoint( + endpoint_url="https://..inference.ml.azure.com/score" + ) + del os.environ["AZUREML_ENDPOINT_API_KEY"] + return azure_chat + + +@pytest.fixture(scope="class") +def api_passed_via_constructor_fixture() -> AzureMLChatOnlineEndpoint: + """Fixture to create an AzureMLChatOnlineEndpoint instance + with API key passed from constructor""" + azure_chat = AzureMLChatOnlineEndpoint( + endpoint_url="https://..inference.ml.azure.com/score", + endpoint_api_key="my-api-key", + ) + return azure_chat + + +@pytest.mark.parametrize( + "fixture_name", + ["api_passed_via_constructor_fixture", "api_passed_via_environment_fixture"], +) +class TestAzureMLChatOnlineEndpoint: + def test_api_key_is_secret_string( + self, fixture_name: str, request: FixtureRequest + ) -> None: + """Test that the API key is a SecretStr instance""" + azure_chat = request.getfixturevalue(fixture_name) + assert isinstance(azure_chat.endpoint_api_key, SecretStr) + + def test_api_key_masked( + self, fixture_name: str, request: FixtureRequest, capsys: CaptureFixture + ) -> None: + """Test that the API key is masked""" + azure_chat = request.getfixturevalue(fixture_name) + print(azure_chat.endpoint_api_key, end="") + captured = capsys.readouterr() + assert ( + (str(azure_chat.endpoint_api_key) == "**********") + and (repr(azure_chat.endpoint_api_key) == "SecretStr('**********')") + and (captured.out == "**********") + ) + + def test_api_key_is_readable( + self, fixture_name: str, request: FixtureRequest + ) -> None: + """Test that the real secret value of the API key can be read""" + azure_chat = request.getfixturevalue(fixture_name) + assert azure_chat.endpoint_api_key.get_secret_value() == "my-api-key"