mirror of https://github.com/bigscience-workshop/petals (synced 2024-11-18 03:25:33 +00:00)

commit e7f716502c
parent a7be94e6e7

    inference session
@@ -1,14 +1,16 @@
 from __future__ import annotations

+import contextlib
 import logging
 import random

 import torch
-from hivemind import DHT, get_logger, use_hivemind_log_handler
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn

-from src import DistributedBloomConfig
+from src import DistributedBloomConfig, RemoteTransformerBlock
 from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
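Two of the new imports drive the addressing trick used in the second hunk: each chosen server is wrapped in a `RemoteTransformerBlock`, whose `_info` is then overwritten with an `ExpertInfo` that packs a whole span of consecutive block UIDs into one space-separated uid, so a single remote module stands in for several blocks. A minimal sketch of that scheme, with hypothetical UIDs and peer id (the real values come from `RemoteSequenceInfo`):

from hivemind.moe.expert_uid import ExpertInfo

# Hypothetical block UIDs and peer id, purely for illustration;
# the real values come from remote_sequence_info.block_uids and span.peer_id.
span_uids = ["bloom.block.3", "bloom.block.4", "bloom.block.5"]
peer_id = "12D3KooW..."  # placeholder libp2p peer id

# One ExpertInfo addresses the whole span: its uid is a space-separated UID list,
# so one RemoteTransformerBlock RPC target covers three consecutive blocks.
span_info = ExpertInfo(" ".join(span_uids), peer_id)
assert span_info.uid == "bloom.block.3 bloom.block.4 bloom.block.5"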
@@ -78,74 +80,52 @@ class RemoteSequential(nn.Sequential):

 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""

-    def __init__(self, remote_sequence_info: RemoteSequenceInfo):
+    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
         self.remote_sequence_info = remote_sequence_info
+        self.p2p = p2p
         self.closed = False
+        self.stack = contextlib.ExitStack()
+        self.active_sessions = []

     def __enter__(self):
         assert not self.closed
+        self.stack.__enter__()
         # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_final_block = 0
-        self.active_chain = []
-
-        while current_final_block != len(remote_sequence_info):
-            candidate_spans = remote_sequence_info.spans_containing_block[current_final_block]
+        current_block = 0
+        while current_block != len(self.remote_sequence_info):
+            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
             chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_final_block < chosen_span.end
+            assert chosen_span.start <= current_block < chosen_span.end

-            self.active_chain.append((current_final_block, chosen_span.end, chosen_span))
-            current_final_block = chosen_span.end
+            # TODO begin throwaway prototype code
+            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
+            remote.info  # lazily fetch the remote module's metadata before patching it
+            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
+            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)

+            self.active_sessions.append(remote.inference_session())
+            print('BEGIN', current_block, remote, self.active_sessions[-1])
+            self.stack.enter_context(self.active_sessions[-1])
+            current_block = chosen_span.end
+            # TODO end throwaway prototype code

         return self

-    # def step(self, new_hidden_states: torch.Tensor):
-    #     """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-    #     if self.closed:
-    #         raise Exception("Session is closed, cannot perform step")
-    #     # serialize inputs and put them into the queue
-    #     inputs = (new_hidden_states,)
-    #     outputs_serialized = RemoteExpertWorker.run_coroutine(
-    #         self._step(
-    #             runtime_pb2.ExpertRequest(
-    #                 uid=self.uid,
-    #                 tensors=[
-    #                     serialize_torch_tensor(tensor, proto.compression)
-    #                     for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-    #                 ],
-    #             )
-    #         )
-    #     )
-    #     outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-    #     assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-    #     return outputs[0]
-    #
-    # async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-    #     """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-    #     await self._inputs_queue.put(inputs_serialized)
-    #     return await anext(self._outputs_stream)
-    #
-    # def close(self):
-    #     """Finish a given inference session, close the underlying connection"""
-    #     if self._outputs_stream is None:
-    #         return  # already closed
-    #     RemoteExpertWorker.run_coroutine(self._aclose_stream())
-    #     self._outputs_stream = self._inputs_queue = None
-    #     self.closed = True
-    #
-    # async def _aclose_stream(self):
-    #     """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-    #     if self._outputs_stream is None:
-    #         return  # already closed
-    #     await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-    #     try:
-    #         await anext(self._outputs_stream)
-    #     except StopAsyncIteration:
-    #         pass
-    #
-    # def __del__(self):
-    #     self.close()
-    #
-    # def __enter__(self):
-    #     assert not self.closed
-    #     return self
-    #
-    # def __exit__(self, *exc_details):
-    #     self.close()
+    def step(self, inputs: torch.Tensor):
+        assert not self.closed
+        for session in self.active_sessions:
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        assert not self.closed
+        self.active_sessions.clear()
+        self.closed = True
+
+    def __exit__(self, *exc_details):
+        self.close()
+
+    def __del__(self):
+        self.close()
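The span-selection loop in `__enter__` is, in effect, a greedy randomized cover of the block range: start at block 0, pick any server span that contains the current block, jump to that span's end, and repeat until every block is covered. A self-contained sketch of that algorithm follows; the `Span` record and the spans list are simplified stand-ins for what `RemoteSequenceInfo` actually provides:

import random
from dataclasses import dataclass

@dataclass(frozen=True)
class Span:
    start: int   # first block held by the server (inclusive)
    end: int     # one past the last block held (exclusive)
    peer_id: str

def build_chain(spans: list[Span], num_blocks: int):
    """Greedily cover blocks [0, num_blocks) with randomly chosen server spans."""
    # Mirrors remote_sequence_info.spans_containing_block: block index -> spans holding it
    spans_containing_block = [[s for s in spans if s.start <= b < s.end] for b in range(num_blocks)]
    chain, current_block = [], 0
    while current_block < num_blocks:
        candidates = spans_containing_block[current_block]
        if not candidates:
            raise RuntimeError(f"no known server holds block {current_block}")
        chosen = random.choice(candidates)  # the commit marks this choice as temporary, not fault-tolerant
        chain.append((current_block, min(chosen.end, num_blocks), chosen))
        current_block = chosen.end
    return chain

# Hypothetical layout: three servers with overlapping spans over 6 blocks
spans = [Span(0, 4, "peerA"), Span(2, 6, "peerB"), Span(4, 6, "peerC")]
print(build_chain(spans, num_blocks=6))  # e.g. [(0, 4, Span(0, 4, 'peerA')), (4, 6, Span(4, 6, 'peerC'))]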
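For context, client code would presumably drive the session along these lines; `remote_sequence_info` and `p2p` are assumed to come from the surrounding RemoteSequential machinery, and the hidden size is a placeholder (a hypothetical usage sketch, not code from the commit):

import torch

# Placeholder shapes: one token per step; hidden size 4096 is illustrative only
hidden_states = torch.randn(1, 1, 4096)  # (batch, seq_len, hidden_size)

# __enter__ builds the chain of per-span inference sessions shown above
with RemoteSequentialInferenceSession(remote_sequence_info, p2p) as session:
    for _ in range(10):  # ten inference steps
        # step() pipes the chunk through every active span session in chain order;
        # each remote session keeps its own state between steps
        hidden_states = session.step(hidden_states)
# leaving the `with` block calls close(), tearing down the per-span sessions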