diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index e0841cbd98..d509fca250 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Union def env_var_is_set(env_var: str) -> bool: @@ -22,13 +22,37 @@ def env_var_is_set(env_var: str) -> bool: def get_from_dict_or_env( - data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None + data: Dict[str, Any], + key: Union[str, List[str]], + env_key: str, + default: Optional[str] = None, ) -> str: - """Get a value from a dictionary or an environment variable.""" - if key in data and data[key]: - return data[key] + """Get a value from a dictionary or an environment variable. + + Args: + data: The dictionary to look up the key in. + key: The key to look up in the dictionary. This can be a list of keys to try + in order. + env_key: The environment variable to look up if the key is not + in the dictionary. + default: The default value to return if the key is not in the dictionary + or the environment. + """ + if isinstance(key, (list, tuple)): + for k in key: + if k in data and data[k]: + return data[k] + + if isinstance(key, str): + if key in data and data[key]: + return data[key] + + if isinstance(key, (list, tuple)): + key_for_err = key[0] else: - return get_from_env(key, env_key, default=default) + key_for_err = key + + return get_from_env(key_for_err, env_key, default=default) def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: diff --git a/libs/core/tests/unit_tests/utils/test_env.py b/libs/core/tests/unit_tests/utils/test_env.py new file mode 100644 index 0000000000..3cf6d02735 --- /dev/null +++ b/libs/core/tests/unit_tests/utils/test_env.py @@ -0,0 +1,64 @@ +import pytest + +from langchain_core.utils.env import get_from_dict_or_env + + +def test_get_from_dict_or_env() -> None: + assert ( + get_from_dict_or_env( + { + "a": "foo", + }, + ["a"], + "__SOME_KEY_IN_ENV", + ) + == "foo" + ) + + assert ( + get_from_dict_or_env( + { + "a": "foo", + }, + ["b", "a"], + "__SOME_KEY_IN_ENV", + ) + == "foo" + ) + + assert ( + get_from_dict_or_env( + { + "a": "foo", + }, + "a", + "__SOME_KEY_IN_ENV", + ) + == "foo" + ) + + assert ( + get_from_dict_or_env( + { + "a": "foo", + }, + "not exists", + "__SOME_KEY_IN_ENV", + default="default", + ) + == "default" + ) + + # Not the most obvious behavior, but + # this is how it works right now + with pytest.raises(ValueError): + assert ( + get_from_dict_or_env( + { + "a": "foo", + }, + "not exists", + "__SOME_KEY_IN_ENV", + ) + is None + )