mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[minor]: Add from_env utility (#25189)
Add a utility that can be used as a default factory The goal will be to start migrating from of the pydantic models to use `from_env` as a default factory if possible. ```python from pydantic import Field, BaseModel from langchain_core.utils import from_env class Foo(BaseModel): name: str = Field(default_factory=from_env('HELLO')) ```
This commit is contained in:
parent
98779797fe
commit
30fb345342
@ -22,6 +22,7 @@ from langchain_core.utils.utils import (
|
||||
build_extra_kwargs,
|
||||
check_package_version,
|
||||
convert_to_secret_str,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
mock_now,
|
||||
@ -54,4 +55,5 @@ __all__ = [
|
||||
"pre_init",
|
||||
"batch_iterate",
|
||||
"abatch_iterate",
|
||||
"from_env",
|
||||
]
|
||||
|
@ -4,9 +4,10 @@ import contextlib
|
||||
import datetime
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import warnings
|
||||
from importlib.metadata import version
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union, overload
|
||||
|
||||
from packaging.version import parse
|
||||
from requests import HTTPError, Response
|
||||
@ -260,3 +261,72 @@ def convert_to_secret_str(value: Union[SecretStr, str]) -> SecretStr:
|
||||
if isinstance(value, SecretStr):
|
||||
return value
|
||||
return SecretStr(value)
|
||||
|
||||
|
||||
class _NoDefaultType:
|
||||
"""Type to indicate no default value is provided."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
_NoDefault = _NoDefaultType()
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, default: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(key: str, /, *, error_message: str) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: str, /, *, default: str, error_message: Optional[str]
|
||||
) -> Callable[[], str]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def from_env(
|
||||
key: str, /, *, default: None, error_message: Optional[str]
|
||||
) -> Callable[[], Optional[str]]: ...
|
||||
|
||||
|
||||
def from_env(
|
||||
key: str,
|
||||
/,
|
||||
*,
|
||||
default: Union[str, _NoDefaultType, None] = _NoDefault,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Union[Callable[[], str], Callable[[], Optional[str]]]:
|
||||
"""Create a factory method that gets a value from an environment variable.
|
||||
|
||||
Args:
|
||||
key: The environment variable to look up.
|
||||
default: The default value to return if the environment variable is not set.
|
||||
error_message: the error message which will be raised if the key is not found
|
||||
and no default value is provided.
|
||||
This will be raised as a ValueError.
|
||||
"""
|
||||
|
||||
def get_from_env_fn() -> str: # type: ignore
|
||||
"""Get a value from an environment variable."""
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
elif isinstance(default, str):
|
||||
return default
|
||||
else:
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Did not find {key}, please add an environment variable"
|
||||
f" `{key}` which contains it, or pass"
|
||||
f" `{key}` as a named parameter."
|
||||
)
|
||||
|
||||
return get_from_env_fn
|
||||
|
@ -25,6 +25,7 @@ EXPECTED_ALL = [
|
||||
"comma_list",
|
||||
"stringify_value",
|
||||
"pre_init",
|
||||
"from_env",
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import re
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple, Type, Union
|
||||
@ -8,6 +9,7 @@ import pytest
|
||||
from langchain_core import utils
|
||||
from langchain_core.utils import (
|
||||
check_package_version,
|
||||
from_env,
|
||||
get_pydantic_field_names,
|
||||
guard_import,
|
||||
)
|
||||
@ -219,3 +221,36 @@ def test_get_pydantic_field_names_v1() -> None:
|
||||
result = get_pydantic_field_names(PydanticModel)
|
||||
expected = {"field1", "field2", "aliased_field", "alias_field"}
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_from_env_with_env_variable() -> None:
|
||||
key = "TEST_KEY"
|
||||
value = "test_value"
|
||||
with patch.dict(os.environ, {key: value}):
|
||||
get_value = from_env(key)
|
||||
assert get_value() == value
|
||||
|
||||
|
||||
def test_from_env_with_default_value() -> None:
|
||||
key = "TEST_KEY"
|
||||
default_value = "default_value"
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
get_value = from_env(key, default=default_value)
|
||||
assert get_value() == default_value
|
||||
|
||||
|
||||
def test_from_env_with_error_message() -> None:
|
||||
key = "TEST_KEY"
|
||||
error_message = "Custom error message"
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
get_value = from_env(key, error_message=error_message)
|
||||
with pytest.raises(ValueError, match=error_message):
|
||||
get_value()
|
||||
|
||||
|
||||
def test_from_env_with_default_error_message() -> None:
|
||||
key = "TEST_KEY"
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
get_value = from_env(key)
|
||||
with pytest.raises(ValueError, match=f"Did not find {key}"):
|
||||
get_value()
|
||||
|
Loading…
Reference in New Issue
Block a user