partial_rollback
Your Name 10 months ago
parent 063e94b4c8
commit 1879788705

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Tuple
from typing import Any, Tuple, Sequence
import torch
from hivemind import nested_flatten, nested_pack
@ -18,7 +18,7 @@ def _get_tensor_index(item: bytes) -> int:
return int(item[3:])
def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
def pack_args_kwargs(*args, **kwargs) -> Tuple[Sequence[torch.Tensor], Any]:
"""
Check the function's arguments and pack all tensors into different flattened lists.
:returns: a flattened list of tensors and args and kwargs, where tensors were masked
@ -35,7 +35,7 @@ def pack_args_kwargs(*args, **kwargs) -> Tuple[List[torch.Tensor], Any]:
return flat_tensors, nested_pack(masked_flat_values, (args, kwargs))
def unpack_args_kwargs(flat_tensors: List[torch.Tensor], args_structure: Any):
def unpack_args_kwargs(flat_tensors: Sequence[torch.Tensor], args_structure: Any):
"""
Restore arguments after `pack_args_kwargs` function.
:returns: list of args and dict of kwargs

Loading…
Cancel
Save