mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
factor out to_secret (#12593)
This commit is contained in:
parent
630ae24b28
commit
016813d189
@ -2,7 +2,7 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
|
||||
import requests
|
||||
|
||||
@ -24,7 +24,11 @@ from langchain.schema.messages import (
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from langchain.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -71,13 +75,6 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||
"""Convert a string to a SecretStr if needed."""
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
|
||||
|
||||
# signature generation
|
||||
def _signature(secret_key: SecretStr, payload: Dict[str, Any], timestamp: int) -> str:
|
||||
input_str = secret_key.get_secret_value() + json.dumps(payload) + str(timestamp)
|
||||
@ -171,7 +168,7 @@ class ChatBaichuan(BaseChatModel):
|
||||
"baichuan_api_key",
|
||||
"BAICHUAN_API_KEY",
|
||||
)
|
||||
values["baichuan_secret_key"] = _to_secret(
|
||||
values["baichuan_secret_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"baichuan_secret_key",
|
||||
|
@ -4,7 +4,7 @@ import hmac
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional, Type
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@ -27,7 +27,11 @@ from langchain.schema.messages import (
|
||||
HumanMessageChunk,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
|
||||
from langchain.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -116,13 +120,6 @@ def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||
"""Convert a string to a SecretStr if needed."""
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
|
||||
|
||||
class ChatHunyuan(BaseChatModel):
|
||||
"""Tencent Hunyuan chat models API by Tencent.
|
||||
|
||||
@ -213,7 +210,7 @@ class ChatHunyuan(BaseChatModel):
|
||||
"hunyuan_secret_id",
|
||||
"HUNYUAN_SECRET_ID",
|
||||
)
|
||||
values["hunyuan_secret_key"] = _to_secret(
|
||||
values["hunyuan_secret_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(
|
||||
values,
|
||||
"hunyuan_secret_key",
|
||||
|
@ -1,17 +1,10 @@
|
||||
from typing import Any, Dict, List, Optional, Sequence, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import Extra, SecretStr, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||
"""Convert a string to a SecretStr if needed."""
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
from langchain.pydantic_v1 import Extra, root_validator
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
|
||||
class AlephAlpha(LLM):
|
||||
@ -176,7 +169,7 @@ class AlephAlpha(LLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["aleph_alpha_api_key"] = _to_secret(
|
||||
values["aleph_alpha_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "aleph_alpha_api_key", "ALEPH_ALPHA_API_KEY")
|
||||
)
|
||||
try:
|
||||
|
@ -9,7 +9,6 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@ -26,14 +25,7 @@ from langchain.utils import (
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain.utils.utils import build_extra_kwargs
|
||||
|
||||
|
||||
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||
"""Convert a string to a SecretStr if needed."""
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
from langchain.utils.utils import build_extra_kwargs, convert_to_secret_str
|
||||
|
||||
|
||||
class _AnthropicCommon(BaseLanguageModel):
|
||||
@ -81,7 +73,7 @@ class _AnthropicCommon(BaseLanguageModel):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["anthropic_api_key"] = _to_secret(
|
||||
values["anthropic_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
|
||||
)
|
||||
# Get custom api url from environment.
|
||||
|
@ -1,21 +1,14 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.pydantic_v1 import Extra, Field, SecretStr, root_validator
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_secret(value: Union[SecretStr, str]) -> SecretStr:
|
||||
"""Convert a string to a SecretStr if needed."""
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
|
||||
|
||||
class GooseAI(LLM):
|
||||
"""GooseAI large language models.
|
||||
|
||||
@ -96,7 +89,7 @@ class GooseAI(LLM):
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
gooseai_api_key = _to_secret(
|
||||
gooseai_api_key = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "gooseai_api_key", "GOOSEAI_API_KEY")
|
||||
)
|
||||
values["gooseai_api_key"] = gooseai_api_key
|
||||
|
Loading…
Reference in New Issue
Block a user