mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[minor]: Add factory for looking up secrets from the env (#25198)
Add factory method for looking secrets from the env.
This commit is contained in:
parent
da9281feb2
commit
429a0ee7fd
@ -27,6 +27,7 @@ from langchain_core.utils.utils import (
|
||||
guard_import,
|
||||
mock_now,
|
||||
raise_for_status_with_text,
|
||||
secret_from_env,
|
||||
xor_args,
|
||||
)
|
||||
|
||||
@ -56,4 +57,5 @@ __all__ = [
|
||||
"batch_iterate",
|
||||
"abatch_iterate",
|
||||
"from_env",
|
||||
"secret_from_env",
|
||||
]
|
||||
|
@ -313,11 +313,11 @@ def from_env(
|
||||
This will be raised as a ValueError.
|
||||
"""
|
||||
|
||||
def get_from_env_fn() -> str: # type: ignore
|
||||
def get_from_env_fn() -> Optional[str]:
|
||||
"""Get a value from an environment variable."""
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
elif isinstance(default, str):
|
||||
elif isinstance(default, (str, type(None))):
|
||||
return default
|
||||
else:
|
||||
if error_message:
|
||||
@ -330,3 +330,62 @@ def from_env(
|
||||
)
|
||||
|
||||
return get_from_env_fn
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(
|
||||
key: str, /, *, default: None
|
||||
) -> Callable[[], Optional[SecretStr]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...
|
||||
|
||||
|
||||
def secret_from_env(
|
||||
key: str,
|
||||
/,
|
||||
*,
|
||||
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
|
||||
"""Secret from env.
|
||||
|
||||
Args:
|
||||
key: The environment variable to look up.
|
||||
default: The default value to return if the environment variable is not set.
|
||||
error_message: the error message which will be raised if the key is not found
|
||||
and no default value is provided.
|
||||
This will be raised as a ValueError.
|
||||
|
||||
Returns:
|
||||
factory method that will look up the secret from the environment.
|
||||
"""
|
||||
|
||||
def get_secret_from_env() -> Optional[SecretStr]:
|
||||
"""Get a value from an environment variable."""
|
||||
if key in os.environ:
|
||||
return SecretStr(os.environ[key])
|
||||
elif isinstance(default, str):
|
||||
return SecretStr(default)
|
||||
elif isinstance(default, type(None)):
|
||||
return None
|
||||
else:
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
return get_secret_from_env
|
||||
|
@ -26,6 +26,7 @@ EXPECTED_ALL = [
|
||||
"stringify_value",
|
||||
"pre_init",
|
||||
"from_env",
|
||||
"secret_from_env",
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
import os
|
||||
import re
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_core import utils
|
||||
from langchain_core.pydantic_v1 import SecretStr
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
from_env,
|
||||
@ -15,6 +16,7 @@ from langchain_core.utils import (
|
||||
)
|
||||
from langchain_core.utils._merge import merge_dicts
|
||||
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
|
||||
from langchain_core.utils.utils import secret_from_env
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -254,3 +256,107 @@ def test_from_env_with_default_error_message() -> None:
|
||||
get_value = from_env(key)
|
||||
with pytest.raises(ValueError, match=f"Did not find {key}"):
|
||||
get_value()
|
||||
|
||||
|
||||
def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Set the environment variable
|
||||
monkeypatch.setenv("TEST_KEY", "secret_value")
|
||||
|
||||
# Get the function
|
||||
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY")
|
||||
|
||||
# Assert that it returns the correct value
|
||||
assert get_secret() == SecretStr("secret_value")
|
||||
|
||||
|
||||
def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Unset the environment variable
|
||||
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||
|
||||
# Get the function with a default value
|
||||
get_secret: Callable[[], SecretStr] = secret_from_env(
|
||||
"TEST_KEY", default="default_value"
|
||||
)
|
||||
|
||||
# Assert that it returns the default value
|
||||
assert get_secret() == SecretStr("default_value")
|
||||
|
||||
|
||||
def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Unset the environment variable
|
||||
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||
|
||||
# Get the function with a default value of None
|
||||
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env(
|
||||
"TEST_KEY", default=None
|
||||
)
|
||||
|
||||
# Assert that it returns None
|
||||
assert get_secret() is None
|
||||
|
||||
|
||||
def test_secret_from_env_without_default_raises_error(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Unset the environment variable
|
||||
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||
|
||||
# Get the function without a default value
|
||||
get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY")
|
||||
|
||||
# Assert that it raises a ValueError with the correct message
|
||||
with pytest.raises(ValueError, match="Did not find TEST_KEY"):
|
||||
get_secret()
|
||||
|
||||
|
||||
def test_secret_from_env_with_custom_error_message(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Unset the environment variable
|
||||
monkeypatch.delenv("TEST_KEY", raising=False)
|
||||
|
||||
# Get the function without a default value but with a custom error message
|
||||
get_secret: Callable[[], SecretStr] = secret_from_env(
|
||||
"TEST_KEY", error_message="Custom error message"
|
||||
)
|
||||
|
||||
# Assert that it raises a ValueError with the custom message
|
||||
with pytest.raises(ValueError, match="Custom error message"):
|
||||
get_secret()
|
||||
|
||||
|
||||
def test_using_secret_from_env_as_default_factory(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# Set the environment variable
|
||||
monkeypatch.setenv("TEST_KEY", "secret_value")
|
||||
# Get the function
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class Foo(BaseModel):
|
||||
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
|
||||
|
||||
assert Foo().secret.get_secret_value() == "secret_value"
|
||||
|
||||
class Bar(BaseModel):
|
||||
secret: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env("TEST_KEY_2", default=None)
|
||||
)
|
||||
|
||||
assert Bar().secret is None
|
||||
|
||||
class Buzz(BaseModel):
|
||||
secret: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env("TEST_KEY_2", default="hello")
|
||||
)
|
||||
|
||||
# We know it will be SecretStr rather than Optional[SecretStr]
|
||||
assert Buzz().secret.get_secret_value() == "hello" # type: ignore
|
||||
|
||||
class OhMy(BaseModel):
|
||||
secret: Optional[SecretStr] = Field(
|
||||
default_factory=secret_from_env("FOOFOOFOOBAR")
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"):
|
||||
OhMy()
|
||||
|
Loading…
Reference in New Issue
Block a user