factor out to_secret (#12593)

This commit is contained in:
Bagatur 2023-10-30 15:10:25 -07:00 committed by GitHub
parent 630ae24b28
commit 016813d189
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 23 additions and 51 deletions

View File

@ -2,7 +2,7 @@ import hashlib
import json
import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
import requests
@ -24,7 +24,11 @@ from langchain.schema.messages import (
HumanMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__)
@ -71,13 +75,6 @@ def _convert_delta_to_message_chunk(
return default_class(content=content)
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
# signature generation
def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp)
@ -171,7 +168,7 @@ class ChatBaichuan(BaseChatModel):
"baichuan_api_key",
"BAICHUAN_API_KEY",
)
values["baichuan_secret_key"] = _to_secret(
values["baichuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"baichuan_secret_key",

View File

@ -4,7 +4,7 @@ import hmac
import json
import logging
import time
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
from urllib.parse import urlparse
import requests
@ -27,7 +27,11 @@ from langchain.schema.messages import (
HumanMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
logger = logging.getLogger(__name__)
@ -116,13 +120,6 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
return ChatResult(generations=generations, llm_output=llm_output)
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
class ChatHunyuan(BaseChatModel):
"""Tencent Hunyuan chat models API by Tencent.
@ -213,7 +210,7 @@ class ChatHunyuan(BaseChatModel):
"hunyuan_secret_id",
"HUNYUAN_SECRET_ID",
)
values["hunyuan_secret_key"] = _to_secret(
values["hunyuan_secret_key"] = convert_to_secret_str(
get_from_dict_or_env(
values,
"hunyuan_secret_key",

View File

@ -1,17 +1,10 @@
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, SecretStr, root_validator
from langchain.utils import get_from_dict_or_env
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
from langchain.pydantic_v1 import Extra, root_validator
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
class AlephAlpha(LLM):
@ -176,7 +169,7 @@ class AlephAlpha(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["aleph_alpha_api_key"] = _to_secret(
values["aleph_alpha_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY")
)
try:

View File

@ -9,7 +9,6 @@ from typing import (
List,
Mapping,
Optional,
Union,
)
from langchain.callbacks.manager import (
@ -26,14 +25,7 @@ from langchain.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
from langchain.utils.utils import build_extra_kwargs, convert_to_secret_str
class _AnthropicCommon(BaseLanguageModel):
@ -81,7 +73,7 @@ class _AnthropicCommon(BaseLanguageModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = _to_secret(
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
)
# Get custom api url from environment.

View File

@ -1,21 +1,14 @@
import logging
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain.utils import get_from_dict_or_env
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__)
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
"""Convert a string to a SecretStr if needed."""
if isinstance(value, SecretStr):
return value
return SecretStr(value)
class GooseAI(LLM):
"""GooseAI large language models.
@ -96,7 +89,7 @@ class GooseAI(LLM):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
gooseai_api_key = _to_secret(
gooseai_api_key = convert_to_secret_str(
get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY")
)
values["gooseai_api_key"] = gooseai_api_key