From fd96878c4b41363ef391b580f85dc78c879b8a79 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 28 Sep 2023 14:21:41 -0400 Subject: [PATCH] Fix anthropic secret key when passed in via init (#11185) Fixes anthropic secret key when passed via init https://github.com/langchain-ai/langchain/issues/11182 --- libs/langchain/langchain/llms/anthropic.py | 21 +++++++++++++++++-- .../chat_models/test_anthropic.py | 7 +++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index c75074d64c..78d3eb1523 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -1,6 +1,16 @@ import re import warnings -from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Union, +) from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -19,6 +29,13 @@ from langchain.utils import ( from langchain.utils.utils import build_extra_kwargs +def _to_secret(value: Union[SecretStr, str]) -> SecretStr: + """Convert a string to a SecretStr if needed.""" + if isinstance(value, SecretStr): + return value + return SecretStr(value) + + class _AnthropicCommon(BaseLanguageModel): client: Any = None #: :meta private: async_client: Any = None #: :meta private: @@ -64,7 +81,7 @@ class _AnthropicCommon(BaseLanguageModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["anthropic_api_key"] = SecretStr( + values["anthropic_api_key"] = _to_secret( get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") ) # Get custom api url from environment. diff --git a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py index 5e3848d382..c2a7247958 100644 --- a/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py +++ b/libs/langchain/tests/integration_tests/chat_models/test_anthropic.py @@ -22,6 +22,13 @@ def test_anthropic_call() -> None: assert isinstance(response.content, str) +def test_anthropic_initialization() -> None: + """Test anthropic initialization.""" + # Verify that chat anthropic can be initialized using a secret key provided + # as a parameter rather than an environment variable. + ChatAnthropic(model="test", anthropic_api_key="test") + + def test_anthropic_generate() -> None: """Test generate method of anthropic.""" chat = ChatAnthropic(model="test")