|
|
|
@ -1,4 +1,4 @@
|
|
|
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
|
from typing import Any, Tuple, Sequence
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import nested_flatten, nested_pack
|
|
|
|
@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int:
|
|
|
|
|
return int(item[3:])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
|
|
|
|
|
def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
|
|
|
|
|
"""
|
|
|
|
|
Check the function's arguments and pack all tensors into different flattened lists.
|
|
|
|
|
:returns: a flattened list of tensors and args and kwargs, where tensors were masked
|
|
|
|
@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
|
|
|
|
|
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
|
|
|
|
|
def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any):
|
|
|
|
|
"""
|
|
|
|
|
Restore arguments after `pack_args_kwargs` function.
|
|
|
|
|
:returns: list of args and dict of kwargs
|
|
|
|
|