Merge branch 'main' into examples_fix_hivemind

commit 622532a9a7 (pull/88/head)
Alexander Borzunov, committed by GitHub

@@ -10,14 +10,60 @@ from hivemind.compression.serialization import deserialize_tensor_stream, deseri
 from hivemind.p2p import StubBase
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
+from hivemind.utils.asyncio import aiter_with_timeout, iter_as_aiter
 from hivemind.utils.streaming import split_for_streaming
 
 from src.data_structures import ModuleUID, RPCInfo
 
 
+async def _forward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in outputs.tensors]
+
+
+async def _backward_unary(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
+        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs),
+        timeout=timeout,
+    )
+    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
+
+
+async def _forward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    outputs = await asyncio.wait_for(stub.rpc_forward_stream(iter_as_aiter(parts)), timeout)
+    outputs = aiter_with_timeout(outputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in outputs)
+
+
+async def _backward_stream(
+    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, timeout: float, **kwargs
+) -> List[torch.Tensor]:
+    parts = (
+        runtime_pb2.ExpertRequest(uid=uid, tensors=[part], **kwargs)
+        for tensor in serialized_tensors
+        for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
+    )
+    grad_inputs = await asyncio.wait_for(stub.rpc_backward_stream(iter_as_aiter(parts)), timeout)
+    grad_inputs = aiter_with_timeout(grad_inputs, timeout)
+    return await deserialize_tensor_stream(msg.tensors async for msg in grad_inputs)
+
+
 async def run_remote_forward(
-    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
+    uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, timeout: float, **kwargs
 ) -> Tuple[torch.Tensor, ...]:
     """
     Serializes input tensors and calls "rpc_forward" on a remote server.
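
Review note: the streaming helpers above now bound two separate waits. asyncio.wait_for(...) limits how long opening the stream RPC may take, while aiter_with_timeout(...) re-arms a timeout for every message that follows, so a stalled peer cannot hang the call indefinitely. For readers without hivemind at hand, a simplified stand-in with the same per-message semantics might look like the sketch below (an assumption about the helper's behavior, not hivemind's actual implementation):

import asyncio
from typing import AsyncIterable, AsyncIterator, TypeVar

T = TypeVar("T")

async def aiter_with_timeout_sketch(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
    # Yield items from an async iterable, raising asyncio.TimeoutError
    # if any single item takes longer than `timeout` seconds to arrive.
    iterator = iterable.__aiter__()
    while True:
        try:
            item = await asyncio.wait_for(iterator.__anext__(), timeout)
        except StopAsyncIteration:
            break
        yield item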
@@ -57,53 +103,13 @@ async def run_remote_forward(
     # call RPC on remote server
     size = sum(t.element_size() * t.nelement() for t in inputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
 
 
-async def _forward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
-
-    outputs = await stub.rpc_forward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
-async def _forward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in outputs.tensors]
-
-
-async def _backward_stream(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
-
-    grad_inputs = await stub.rpc_backward_stream(
-        amap_in_executor(
-            lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
-            iter_as_aiter(split),
-        ),
-    )
-    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
-    return await deserialize_tensor_stream(tensors_stream)
-
-
 async def run_remote_backward(
     uid: ModuleUID,
     stub: StubBase,
@@ -111,6 +117,7 @@ async def run_remote_backward(
     inputs: torch.Tensor,
     grad_outputs: List[torch.Tensor],
     *extra_tensors: torch.Tensor,
+    timeout: float,
     **kwargs,
 ) -> Sequence[torch.Tensor]:
     """
@@ -140,17 +147,8 @@ async def run_remote_backward(
     size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
     if size > MAX_UNARY_PAYLOAD_SIZE:
-        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, timeout, **kwargs)
     else:
-        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
+        deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, timeout, **kwargs)
 
     return deserialized_grad_inputs
-
-
-async def _backward_unary(
-    uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
-) -> List[torch.Tensor]:
-    grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
-        runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
-    )
-    return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
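
Review note: both entry points keep the same dispatch rule: a request whose tensors fit under MAX_UNARY_PAYLOAD_SIZE goes out as a single unary RPC, anything larger takes the chunked streaming path. The size estimate is raw tensor bytes before serialization. A quick illustration of the arithmetic (the 2 MiB threshold below is an assumed placeholder, not hivemind's real constant):

import torch

MAX_UNARY_PAYLOAD_SIZE = 2 * 1024**2  # assumed 2 MiB, for illustration only

inputs = [torch.randn(1, 128, 4096), torch.randn(1, 128, 4096)]  # e.g. hidden states + prompts
size = sum(t.element_size() * t.nelement() for t in inputs)  # 2 * 128 * 4096 * 4 bytes = 4 MiB
use_streaming = size > MAX_UNARY_PAYLOAD_SIZE  # True: this payload takes the streaming path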

@@ -24,7 +24,15 @@ class RemoteSequenceManager:
     In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
     """
 
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        p2p: P2P,
+        max_retries: int = 3,
+        timeout: float = 5,
+        min_backoff: float = 1,
+    ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
         self.dht, self.p2p = dht, p2p
         self.block_uids: List[ModuleUID] = list(block_uids)
@@ -33,6 +41,7 @@ class RemoteSequenceManager:
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
         self.last_update_time: DHTExpiration = -float("inf")
         self.max_retries = max_retries
+        self.timeout, self.min_backoff = timeout, min_backoff
         self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.update_()
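
Review note: with timeout and min_backoff owned by the manager, retry behavior is configured once at construction time instead of being threaded through every call. A hypothetical construction (dht, block_uids, and p2p are assumed to already exist; the values shown are just the defaults):

sequence_manager = RemoteSequenceManager(
    dht,
    block_uids,
    p2p,
    max_retries=3,  # attempts per span before giving up
    timeout=5,      # seconds allowed for each remote forward/backward RPC
    min_backoff=1,  # base of the exponential retry delay, in seconds
)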

@@ -24,7 +24,6 @@ async def sequential_forward(
     sequence_manager: RemoteSequenceManager,
     start_index: int = 0,
     end_index: Optional[int] = None,
-    min_backoff: float = 1.0,
 ) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
     """
     Constructs a routing path from <start_index> to <end_index>.
@@ -53,7 +52,9 @@ async def sequential_forward(
             stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
             inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
-            (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
+            (outputs,) = await run_remote_forward(
+                span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
+            )
 
             assert isinstance(outputs, torch.Tensor)
             assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -66,7 +67,7 @@ async def sequential_forward(
             break
         except Exception as e:
             logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-            await asyncio.sleep(min_backoff * 2**attempt_no)
+            await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
             backup_sequences = sequence_manager.make_sequence(span.start)
             assert backup_sequences[0].start == span.start
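
Review note: the retry delay doubles with each attempt, so with the default min_backoff=1 the sleeps form the schedule 1 s, 2 s, 4 s, 8 s, ... A one-liner to sanity-check it:

min_backoff = 1  # RemoteSequenceManager's default
print([min_backoff * 2**attempt_no for attempt_no in range(4)])  # [1, 2, 4, 8]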
@@ -81,7 +82,6 @@ async def sequential_backward(
     prompts: torch.Tensor,
     forward_sequences: List[RemoteSpanInfo],
     sequence_manager: RemoteSequenceManager,
-    min_backoff: float = 1.0,
 ) -> Sequence[torch.Tensor]:
     """
     Performs chained backward for each forward subsequence.
@@ -98,14 +98,20 @@ async def sequential_backward(
             try:
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
-                    span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
+                    span_uids,
+                    stub,
+                    sequence_manager.rpc_info,
+                    inputs,
+                    grad_outputs,
+                    prompts[span.start : span.end],
+                    timeout=sequence_manager.timeout,
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
                 logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(min_backoff * 2**attempt_no)
+                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
 
                 _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                     inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
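
Review note: call sites no longer pass min_backoff per call; both the RPC timeout and the backoff base now travel with the sequence manager, so a typical invocation shrinks to the sketch below (inputs, prompts, and sequence_manager are assumed to already exist):

# Hypothetical call site; retry/backoff knobs live on the manager now.
outputs, intermediate_inputs, forward_sequences = await sequential_forward(inputs, prompts, sequence_manager)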
