You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
langchain/langchain/utils.py

47 lines
1.6 KiB
Python

"""Generic utility functions."""
import functools
import os
from typing import Any, Callable, Dict, Optional, Tuple
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Look up a setting in ``data``, then the environment, then ``default``.

    Sources are consulted in order and the first truthy value wins:

    1. ``data[key]``, when ``key`` is present and its value is truthy.
    2. ``os.environ[env_key]``, when it is set to a non-empty string.
    3. ``default``, when it is not ``None``.

    Args:
        data: Mapping that may already hold the value under ``key``.
        key: Name to look up in ``data``.
        env_key: Name of the environment variable used as a fallback.
        default: Value returned when neither source provides one.

    Returns:
        The first value found, following the order above.

    Raises:
        ValueError: If no source yields a value and ``default`` is ``None``.
    """
    # Falsy entries (empty string, None, 0, ...) are treated as "not provided".
    if key in data:
        from_data = data[key]
        if from_data:
            return from_data

    # An environment variable that is unset or empty is skipped the same way.
    from_env = os.environ.get(env_key)
    if from_env:
        return from_env

    if default is not None:
        return default

    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive.

    Build a decorator that, on every call of the decorated function, checks
    each group in ``arg_groups`` and requires exactly one of the names in
    that group to be passed as a keyword argument with a non-``None`` value.

    Only keyword arguments are inspected: a value passed positionally is not
    counted towards its group.

    Args:
        *arg_groups: Tuples of argument names; each tuple is one group.

    Returns:
        A decorator that wraps a function with the validation described above.

    Raises:
        ValueError: Raised by the wrapped function at call time when any
            group has zero, or more than one, non-``None`` keyword argument.
    """

    def decorator(func: Callable) -> Callable:
        # functools.wraps keeps func's __name__, __doc__, etc. on the wrapper
        # so introspection and documentation tools still see the original.
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Count, per group, how many members were supplied as non-None
            # keyword arguments.
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            # A group is valid only when exactly one member was supplied.
            invalid_groups = [i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator