forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.6 KiB
Python
81 lines
2.6 KiB
Python
"""Generic utility functions."""
|
|
import functools
import os
from typing import Any, Callable, Dict, Optional, Tuple

from requests import HTTPError, Response
|
|
|
|
|
|
def get_from_dict_or_env(
    data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None
) -> str:
    """Get a value from a dictionary or an environment variable.

    Args:
        data: Mapping checked first for ``key``.
        key: Name of the value to look up in ``data``.
        env_key: Environment variable consulted when ``data`` has no usable value.
        default: Fallback forwarded to ``get_from_env``.

    Returns:
        ``data[key]`` when the key is present and its value is truthy; otherwise
        whatever ``get_from_env(key, env_key, default=default)`` produces.
    """
    # A missing key and a falsy value (None, "", 0, ...) are both treated
    # as "not provided" and defer to the environment lookup.
    if key not in data or not data[key]:
        return get_from_env(key, env_key, default=default)
    return data[key]
|
|
|
|
|
|
def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str:
    """Get a value from an environment variable, falling back to a default.

    Args:
        key: Name of the parameter the value is for; only used in the error
            message to tell the caller how to supply it directly.
        env_key: Name of the environment variable to read.
        default: Value returned when the environment variable is unset or
            empty. ``None`` means "no default".

    Returns:
        The environment variable's value if it is set and non-empty,
        otherwise ``default`` (which may be an empty string).

    Raises:
        ValueError: If the environment variable is unset or empty and
            ``default`` is ``None``.
    """
    # Single lookup; an empty string is treated the same as an unset variable.
    value = os.environ.get(env_key)
    if value:
        return value
    # `is not None` (not truthiness) so an explicit default of "" is honored.
    if default is not None:
        return default
    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
|
|
|
|
|
|
def xor_args(*arg_groups: Tuple[str, ...]) -> Callable:
    """Validate specified keyword args are mutually exclusive.

    Builds a decorator. Each time the decorated function is called, the
    decorator checks every group in ``arg_groups``: exactly one of the
    group's names must be passed as a keyword argument with a value other
    than ``None``. Arguments passed positionally are not counted.

    Args:
        *arg_groups: Tuples of keyword-argument names; each tuple is one group.

    Returns:
        A decorator that wraps a function with the check above. The wrapper
        preserves the wrapped function's metadata (name, docstring,
        ``__wrapped__``) via ``functools.wraps``.

    Raises:
        ValueError: Raised by the wrapped function when any group has zero or
            more than one non-``None`` keyword argument; the message lists
            the offending groups.
    """

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Validate exactly one arg in each group is not None. (A comment
            # rather than a docstring: functools.wraps copies func's __doc__.)
            counts = [
                sum(1 for arg in arg_group if kwargs.get(arg) is not None)
                for arg_group in arg_groups
            ]
            invalid_groups = [i for i, count in enumerate(counts) if count != 1]
            if invalid_groups:
                invalid_group_names = [", ".join(arg_groups[i]) for i in invalid_groups]
                raise ValueError(
                    "Exactly one argument in each of the following"
                    " groups must be defined:"
                    f" {', '.join(invalid_group_names)}"
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
|
|
def raise_for_status_with_text(response: Response) -> None:
    """Turn an HTTP error status into a ValueError carrying the response body.

    Args:
        response: The HTTP response to check.

    Raises:
        ValueError: If ``response.raise_for_status()`` raises ``HTTPError``.
            The message is ``response.text``, and the original ``HTTPError``
            is chained as the cause.
    """
    try:
        response.raise_for_status()
    except HTTPError as http_err:
        error_body = response.text
        raise ValueError(error_body) from http_err
|
|
|
|
|
|
def stringify_value(val: Any) -> str:
    """Convert an arbitrary value to display text.

    Strings are returned unchanged. A dict is rendered with
    ``stringify_dict`` and prefixed with a newline. A list has each element
    rendered recursively, and the pieces are joined with newlines. Anything
    else is passed through ``str()``.
    """
    if isinstance(val, str):
        return val
    if isinstance(val, dict):
        return "\n" + stringify_dict(val)
    if isinstance(val, list):
        rendered_items = [stringify_value(item) for item in val]
        return "\n".join(rendered_items)
    return str(val)
|
|
|
|
|
|
def stringify_dict(data: dict) -> str:
    """Render a dict as one ``key: value`` line per item.

    Each value is converted with ``stringify_value``. Every line, including
    the last, ends with a newline, so an empty dict yields ``""``. Keys are
    joined with ``+``, so they must be strings; any other key type raises
    ``TypeError``.
    """
    # Collect the pieces and join once, instead of growing a string with
    # repeated `+=` (quadratic unless the interpreter optimizes it in place).
    return "".join(
        key + ": " + stringify_value(value) + "\n" for key, value in data.items()
    )
|