|
|
|
@ -30,11 +30,20 @@ Output = TypeVar("Output")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
|
|
|
|
|
"""Run a coroutine with a semaphore.
|
|
|
|
|
Args:
|
|
|
|
|
semaphore: The semaphore to use.
|
|
|
|
|
coro: The coroutine to run.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
The result of the coroutine.
|
|
|
|
|
"""
|
|
|
|
|
async with semaphore:
|
|
|
|
|
return await coro
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
|
|
|
|
|
"""Gather coroutines with a limit on the number of concurrent coroutines."""
|
|
|
|
|
if n is None:
|
|
|
|
|
return await asyncio.gather(*coros)
|
|
|
|
|
|
|
|
|
@ -44,6 +53,7 @@ async def gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> lis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
|
|
|
|
"""Check if a callable accepts a run_manager argument."""
|
|
|
|
|
try:
|
|
|
|
|
return signature(callable).parameters.get("run_manager") is not None
|
|
|
|
|
except ValueError:
|
|
|
|
@ -51,6 +61,7 @@ def accepts_run_manager(callable: Callable[..., Any]) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def accepts_config(callable: Callable[..., Any]) -> bool:
|
|
|
|
|
"""Check if a callable accepts a config argument."""
|
|
|
|
|
try:
|
|
|
|
|
return signature(callable).parameters.get("config") is not None
|
|
|
|
|
except ValueError:
|
|
|
|
@ -58,6 +69,8 @@ def accepts_config(callable: Callable[..., Any]) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IsLocalDict(ast.NodeVisitor):
|
|
|
|
|
"""Check if a name is a local dict."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, name: str, keys: Set[str]) -> None:
|
|
|
|
|
self.name = name
|
|
|
|
|
self.keys = keys
|
|
|
|
@ -88,6 +101,8 @@ class IsLocalDict(ast.NodeVisitor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IsFunctionArgDict(ast.NodeVisitor):
|
|
|
|
|
"""Check if the first argument of a function is a dict."""
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self.keys: Set[str] = set()
|
|
|
|
|
|
|
|
|
@ -105,17 +120,22 @@ class IsFunctionArgDict(ast.NodeVisitor):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GetLambdaSource(ast.NodeVisitor):
|
|
|
|
|
"""Get the source code of a lambda function."""
|
|
|
|
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
"""Initialize the visitor."""
|
|
|
|
|
self.source: Optional[str] = None
|
|
|
|
|
self.count = 0
|
|
|
|
|
|
|
|
|
|
def visit_Lambda(self, node: ast.Lambda) -> Any:
|
|
|
|
|
"""Visit a lambda function."""
|
|
|
|
|
self.count += 1
|
|
|
|
|
if hasattr(ast, "unparse"):
|
|
|
|
|
self.source = ast.unparse(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_function_first_arg_dict_keys(func: Callable) -> Optional[List[str]]:
|
|
|
|
|
"""Get the keys of the first argument of a function if it is a dict."""
|
|
|
|
|
try:
|
|
|
|
|
code = inspect.getsource(func)
|
|
|
|
|
tree = ast.parse(textwrap.dedent(code))
|
|
|
|
@ -190,6 +210,8 @@ _T_contra = TypeVar("_T_contra", contravariant=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SupportsAdd(Protocol[_T_contra, _T_co]):
|
|
|
|
|
"""Protocol for objects that support addition."""
|
|
|
|
|
|
|
|
|
|
def __add__(self, __x: _T_contra) -> _T_co:
|
|
|
|
|
...
|
|
|
|
|
|
|
|
|
@ -198,6 +220,7 @@ Addable = TypeVar("Addable", bound=SupportsAdd[Any, Any])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
|
|
|
|
"""Add a sequence of addable objects together."""
|
|
|
|
|
final = None
|
|
|
|
|
for chunk in addables:
|
|
|
|
|
if final is None:
|
|
|
|
@ -208,6 +231,7 @@ def add(addables: Iterable[Addable]) -> Optional[Addable]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
|
|
|
|
"""Asynchronously add a sequence of addable objects together."""
|
|
|
|
|
final = None
|
|
|
|
|
async for chunk in addables:
|
|
|
|
|
if final is None:
|
|
|
|
@ -218,6 +242,8 @@ async def aadd(addables: AsyncIterable[Addable]) -> Optional[Addable]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigurableField(NamedTuple):
|
|
|
|
|
"""A field that can be configured by the user."""
|
|
|
|
|
|
|
|
|
|
id: str
|
|
|
|
|
|
|
|
|
|
name: Optional[str] = None
|
|
|
|
@ -226,6 +252,8 @@ class ConfigurableField(NamedTuple):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigurableFieldSingleOption(NamedTuple):
|
|
|
|
|
"""A field that can be configured by the user with a default value."""
|
|
|
|
|
|
|
|
|
|
id: str
|
|
|
|
|
options: Mapping[str, Any]
|
|
|
|
|
default: str
|
|
|
|
@ -235,6 +263,8 @@ class ConfigurableFieldSingleOption(NamedTuple):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigurableFieldMultiOption(NamedTuple):
|
|
|
|
|
"""A field that can be configured by the user with multiple default values."""
|
|
|
|
|
|
|
|
|
|
id: str
|
|
|
|
|
options: Mapping[str, Any]
|
|
|
|
|
default: Sequence[str]
|
|
|
|
@ -249,6 +279,8 @@ AnyConfigurableField = Union[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConfigurableFieldSpec(NamedTuple):
|
|
|
|
|
"""A field that can be configured by the user. It is a specification of a field."""
|
|
|
|
|
|
|
|
|
|
id: str
|
|
|
|
|
name: Optional[str]
|
|
|
|
|
description: Optional[str]
|
|
|
|
@ -260,6 +292,7 @@ class ConfigurableFieldSpec(NamedTuple):
|
|
|
|
|
def get_unique_config_specs(
|
|
|
|
|
specs: Iterable[ConfigurableFieldSpec],
|
|
|
|
|
) -> Sequence[ConfigurableFieldSpec]:
|
|
|
|
|
"""Get the unique config specs from a sequence of config specs."""
|
|
|
|
|
grouped = groupby(sorted(specs, key=lambda s: s.id), lambda s: s.id)
|
|
|
|
|
unique: List[ConfigurableFieldSpec] = []
|
|
|
|
|
for id, dupes in grouped:
|
|
|
|
|