From 016813d189d146762186dea785536180d9d09efa Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 30 Oct 2023 15:10:25 -0700 Subject: [PATCH] factor out to_secret (#12593) --- .../langchain/langchain/chat_models/baichuan.py | 17 +++++++---------- libs/langchain/langchain/chat_models/hunyuan.py | 17 +++++++---------- libs/langchain/langchain/llms/aleph_alpha.py | 15 ++++----------- libs/langchain/langchain/llms/anthropic.py | 12 ++---------- libs/langchain/langchain/llms/gooseai.py | 13 +++---------- 5 files changed, 23 insertions(+), 51 deletions(-) diff --git a/libs/langchain/langchain/chat_models/baichuan.py b/libs/langchain/langchain/chat_models/baichuan.py index 7f5f5e5f91..611d761251 100644 --- a/libs/langchain/langchain/chat_models/baichuan.py +++ b/libs/langchain/langchain/chat_models/baichuan.py @@ -2,7 +2,7 @@ import hashlib import json import logging import time -from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Type import requests @@ -24,7 +24,11 @@ from langchain.schema.messages import ( HumanMessageChunk, ) from langchain.schema.output import ChatGenerationChunk -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import ( + convert_to_secret_str, + get_from_dict_or_env, + get_pydantic_field_names, +) logger = logging.getLogger(__name__) @@ -71,13 +75,6 @@ def _convert_delta_to_message_chunk( return default_class(content=content) -def _to_secret(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) - - # signature generation def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str: input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp) @@ -171,7 +168,7 @@ class ChatBaichuan(BaseChatModel): "baichuan_api_key", "BAICHUAN_API_KEY", ) - values["baichuan_secret_key"] = _to_secret( + values["baichuan_secret_key"] = convert_to_secret_str( get_from_dict_or_env( values, "baichuan_secret_key", diff --git a/libs/langchain/langchain/chat_models/hunyuan.py b/libs/langchain/langchain/chat_models/hunyuan.py index b87f2748ed..d478bbe9c4 100644 --- a/libs/langchain/langchain/chat_models/hunyuan.py +++ b/libs/langchain/langchain/chat_models/hunyuan.py @@ -4,7 +4,7 @@ import hmac import json import logging import time -from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Type from urllib.parse import urlparse import requests @@ -27,7 +27,11 @@ from langchain.schema.messages import ( HumanMessageChunk, ) from langchain.schema.output import ChatGenerationChunk -from langchain.utils import get_from_dict_or_env, get_pydantic_field_names +from langchain.utils import ( + convert_to_secret_str, + get_from_dict_or_env, + get_pydantic_field_names, +) logger = logging.getLogger(__name__) @@ -116,13 +120,6 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult: return ChatResult(generations=generations, llm_output=llm_output) -def _to_secret(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) - - class ChatHunyuan(BaseChatModel): """Tencent Hunyuan chat models API by Tencent. @@ -213,7 +210,7 @@ class ChatHunyuan(BaseChatModel): "hunyuan_secret_id", "HUNYUAN_SECRET_ID", ) - values["hunyuan_secret_key"] = _to_secret( + values["hunyuan_secret_key"] = convert_to_secret_str( get_from_dict_or_env( values, "hunyuan_secret_key", diff --git a/libs/langchain/langchain/llms/aleph_alpha.py b/libs/langchain/langchain/llms/aleph_alpha.py index 116f6c09bf..e73545756e 100644 --- a/libs/langchain/langchain/llms/aleph_alpha.py +++ b/libs/langchain/langchain/llms/aleph_alpha.py @@ -1,17 +1,10 @@ -from typing import Any, Dict, List, Optional, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.pydantic_v1 import Extra, SecretStr, root_validator -from langchain.utils import get_from_dict_or_env - - -def _to_secret(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) +from langchain.pydantic_v1 import Extra, root_validator +from langchain.utils import convert_to_secret_str, get_from_dict_or_env class AlephAlpha(LLM): @@ -176,7 +169,7 @@ class AlephAlpha(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["aleph_alpha_api_key"] = _to_secret( + values["aleph_alpha_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY") ) try: diff --git a/libs/langchain/langchain/llms/anthropic.py b/libs/langchain/langchain/llms/anthropic.py index 78d3eb1523..e423f5fd5b 100644 --- a/libs/langchain/langchain/llms/anthropic.py +++ b/libs/langchain/langchain/llms/anthropic.py @@ -9,7 +9,6 @@ from typing import ( List, Mapping, Optional, - Union, ) from langchain.callbacks.manager import ( @@ -26,14 +25,7 @@ from langchain.utils import ( get_from_dict_or_env, get_pydantic_field_names, ) -from langchain.utils.utils import build_extra_kwargs - - -def _to_secret(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) +from langchain.utils.utils import build_extra_kwargs, convert_to_secret_str class _AnthropicCommon(BaseLanguageModel): @@ -81,7 +73,7 @@ class _AnthropicCommon(BaseLanguageModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - values["anthropic_api_key"] = _to_secret( + values["anthropic_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY") ) # Get custom api url from environment. diff --git a/libs/langchain/langchain/llms/gooseai.py b/libs/langchain/langchain/llms/gooseai.py index 30947ddb1b..67aeb1de18 100644 --- a/libs/langchain/langchain/llms/gooseai.py +++ b/libs/langchain/langchain/llms/gooseai.py @@ -1,21 +1,14 @@ import logging -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) -def _to_secret(value: Union[SecretStr, str]) -> SecretStr: - """Convert a string to a SecretStr if needed.""" - if isinstance(value, SecretStr): - return value - return SecretStr(value) - - class GooseAI(LLM): """GooseAI large language models. @@ -96,7 +89,7 @@ class GooseAI(LLM): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - gooseai_api_key = _to_secret( + gooseai_api_key = convert_to_secret_str( get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY") ) values["gooseai_api_key"] = gooseai_api_key