core[minor]: Add factory for looking up secrets from the env (#25198)

Add factory method for looking secrets from the env.
This commit is contained in:
Eugene Yurtsev 2024-08-08 16:41:58 -04:00 committed by GitHub
parent da9281feb2
commit 429a0ee7fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 171 additions and 3 deletions

View File

@ -27,6 +27,7 @@ from langchain_core.utils.utils import (
guard_import,
mock_now,
raise_for_status_with_text,
secret_from_env,
xor_args,
)
@ -56,4 +57,5 @@ __all__ = [
"batch_iterate",
"abatch_iterate",
"from_env",
"secret_from_env",
]

View File

@ -313,11 +313,11 @@ def from_env(
This will be raised as a ValueError.
"""
def get_from_env_fn() -> str: # type: ignore
def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable."""
if key in os.environ:
return os.environ[key]
elif isinstance(default, str):
elif isinstance(default, (str, type(None))):
return default
else:
if error_message:
@ -330,3 +330,62 @@ def from_env(
)
return get_from_env_fn
@overload
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...
@overload
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...
@overload
def secret_from_env(
key: str, /, *, default: None
) -> Callable[[], Optional[SecretStr]]: ...
@overload
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...
def secret_from_env(
key: str,
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
error_message: Optional[str] = None,
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
"""Secret from env.
Args:
key: The environment variable to look up.
default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found
and no default value is provided.
This will be raised as a ValueError.
Returns:
factory method that will look up the secret from the environment.
"""
def get_secret_from_env() -> Optional[SecretStr]:
"""Get a value from an environment variable."""
if key in os.environ:
return SecretStr(os.environ[key])
elif isinstance(default, str):
return SecretStr(default)
elif isinstance(default, type(None)):
return None
else:
if error_message:
raise ValueError(error_message)
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)
return get_secret_from_env

View File

@ -26,6 +26,7 @@ EXPECTED_ALL = [
"stringify_value",
"pre_init",
"from_env",
"secret_from_env",
]

View File

@ -1,12 +1,13 @@
import os
import re
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from unittest.mock import patch
import pytest
from langchain_core import utils
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils import (
check_package_version,
from_env,
@ -15,6 +16,7 @@ from langchain_core.utils import (
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.utils import secret_from_env
@pytest.mark.parametrize(
@ -254,3 +256,107 @@ def test_from_env_with_default_error_message() -> None:
get_value = from_env(key)
with pytest.raises(ValueError, match=f"Did not find {key}"):
get_value()
def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None:
# Set the environment variable
monkeypatch.setenv("TEST_KEY", "secret_value")
# Get the function
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY")
# Assert that it returns the correct value
assert get_secret() == SecretStr("secret_value")
def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)
# Get the function with a default value
get_secret: Callable[[], SecretStr] = secret_from_env(
"TEST_KEY", default="default_value"
)
# Assert that it returns the default value
assert get_secret() == SecretStr("default_value")
def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)
# Get the function with a default value of None
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env(
"TEST_KEY", default=None
)
# Assert that it returns None
assert get_secret() is None
def test_secret_from_env_without_default_raises_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)
# Get the function without a default value
get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY")
# Assert that it raises a ValueError with the correct message
with pytest.raises(ValueError, match="Did not find TEST_KEY"):
get_secret()
def test_secret_from_env_with_custom_error_message(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)
# Get the function without a default value but with a custom error message
get_secret: Callable[[], SecretStr] = secret_from_env(
"TEST_KEY", error_message="Custom error message"
)
# Assert that it raises a ValueError with the custom message
with pytest.raises(ValueError, match="Custom error message"):
get_secret()
def test_using_secret_from_env_as_default_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Set the environment variable
monkeypatch.setenv("TEST_KEY", "secret_value")
# Get the function
from langchain_core.pydantic_v1 import BaseModel, Field
class Foo(BaseModel):
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
assert Foo().secret.get_secret_value() == "secret_value"
class Bar(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("TEST_KEY_2", default=None)
)
assert Bar().secret is None
class Buzz(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("TEST_KEY_2", default="hello")
)
# We know it will be SecretStr rather than Optional[SecretStr]
assert Buzz().secret.get_secret_value() == "hello" # type: ignore
class OhMy(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("FOOFOOFOOBAR")
)
with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"):
OhMy()