forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
"""Generic utility functions."""
|
|
import os
|
|
from typing import Any, Callable, Dict, Optional, Tuple
|
|
|
|
|
|
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Resolve a setting from ``data``, the environment, or a default.

    Sources are tried in this order, and the first one that produces a
    value wins:

    1. ``data[key]``, if ``key`` is present and its value is truthy;
    2. ``os.environ[env_key]``, if the variable is set and non-empty;
    3. ``default``, if it is not ``None`` (an empty string counts).

    Args:
        data: Mapping of explicitly supplied values.
        key: Name to look up in ``data``.
        env_key: Name of the environment variable used as a fallback.
        default: Value returned when neither source provides one.

    Returns:
        The first value found by the resolution order above.

    Raises:
        ValueError: If no source yields a value and ``default`` is ``None``.
    """
    # Explicit values take precedence, but falsy ones ("" / None / 0) are
    # treated as "not provided" so the environment can still fill them in.
    if key in data:
        explicit = data[key]
        if explicit:
            return explicit

    # An empty environment variable is likewise treated as unset.
    if env_key in os.environ:
        from_env = os.environ[env_key]
        if from_env:
            return from_env

    if default is not None:
        return default

    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
|
|
|
|
|
|
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive.

    Each positional argument to ``xor_args`` is a tuple of keyword-argument
    names forming one group. Every call to the decorated function must pass
    exactly one name from each group as a keyword argument whose value is
    not ``None``.

    Only ``**kwargs`` are inspected. Arguments passed positionally are not
    counted toward any group.

    Args:
        *arg_groups: Groups of keyword-argument names. Exactly one name per
            group must be supplied (non-``None``) at call time.

    Returns:
        A decorator that wraps a function with this validation.

    Raises:
        ValueError: When the decorated function is called and any group has
            zero, or more than one, non-``None`` keyword argument.
    """
    # functools.wraps keeps the decorated function's __name__, __doc__ and
    # __wrapped__. Without it, the wrapper's own metadata replaced them.
    from functools import wraps

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Validate exactly one arg in each group is not None.
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            invalid_groups = [i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|