swap to int64 (rationale: pytorch does not support uint64)

inference_chain
justheuristic 2 years ago
parent 62d7fde8af
commit 8092bd31ff

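Note on the rationale above: the cache handles end up wrapped in tensors further down in this diff, and the PyTorch releases available at the time exposed no unsigned 64-bit dtype, so the unsigned counters have to become signed. A minimal, version-dependent check of that constraint:

```python
import torch

torch.tensor([42], dtype=torch.int64)  # fine: handles can travel as int64 tensors
hasattr(torch, "uint64")               # False on PyTorch releases of this era, hence the swap
```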
@@ -6,8 +6,7 @@ from hivemind import DHT
from torch import nn
from src import DistributedBloomConfig
MAX_LENGTH = 128 #TODO un-hardcode
from src.server.backend import MAX_LENGTH
class RemoteInferenceChain(nn.Module):

@@ -2,6 +2,7 @@ from concurrent.futures import Future
from functools import partial
from typing import List, Optional, Union, Sequence
import torch
from hivemind.moe.client import RemoteExpert
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.moe.expert_uid import ExpertUID
@@ -10,18 +11,30 @@ from hivemind.p2p import StubBase, P2P
from hivemind.proto.runtime_pb2 import ExpertInfo
from hivemind.dht import DHT
from hivemind.utils import MPFuture, DHTExpiration
from src import DistributedBloomConfig
from src.server.backend import MAX_LENGTH
from src.server.handler import TransformerConnectionHandler
class RemoteTransformerBlock(RemoteExpert):
class RemoteTransformerBlockSession(RemoteExpert):
"""A class that interacts with a specific remote server for forward/backward or inference"""
def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
super().__init__(info, p2p)
self._config = config
self._inputs_cache = torch.empty(1, MAX_LENGTH, config.hidden_size, dtype=config.dtype)
self._active_stream: Optional[RemoteTransformerStream] = None
@property
def stub(self) -> StubBase:
return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
def get_remote_module(
dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
"""
:param uids: find experts with these ids from across the DHT
:param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
@@ -35,7 +48,7 @@ def get_remote_module(
def create_remote_module(
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
if return_future:
async def _unpack(infos_future: MPFuture, dht: DHT):
@@ -48,10 +61,10 @@ def create_remote_module(
def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
experts: List[Optional[RemoteTransformerBlock]] = []
experts: List[Optional[RemoteTransformerBlockSession]] = []
for info in infos:
if info is not None:
experts.append(RemoteTransformerBlock(info, p2p))
experts.append(RemoteTransformerBlockSession(info, p2p))
else:
experts.append(None)
return experts

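A rough usage sketch for the renamed session class: the DHT construction below follows hivemind's public API, and the peer address and block uids are placeholders; `get_remote_module` and `stub` are taken from the diff above.

```python
from hivemind import DHT

# join an existing swarm as a lightweight client (the multiaddr below is a placeholder)
dht = DHT(initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmPlaceholderPeerId"], client_mode=True, start=True)

# look up blocks by uid; each entry is a RemoteTransformerBlockSession or None if the uid was not found
sessions = get_remote_module(dht, uids=["bloom.0", "bloom.1"])

for session in sessions:
    if session is not None:
        _ = session.stub  # per-peer stub built from TransformerConnectionHandler, as defined above
```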
@@ -7,6 +7,8 @@ from hivemind.moe.server.task_pool import TaskPool
from src.server.cache import MemoryCache
MAX_LENGTH = 2048
class TransformerBackend(ModuleBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
@@ -22,7 +24,10 @@ class TransformerBackend(ModuleBackend):
self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
def inference_step(self, *inputs: torch.Tensor, attention_cache_handle: torch.IntTensor) -> Tuple[torch.Tensor, ...]:
attention_cache_handle = int(attention_cache_handle.item())
print('HANDLE:', attention_cache_handle)
with self.memory_cache.use_cache(attention_cache_handle) as cache:
cache[...] += 1
return inputs[0] + cache

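The handle is passed to `inference_step` as a one-element tensor, presumably because the inference TaskPool only ships tensors between processes; the first thing the step does is unpack it back into a Python int. A standalone sketch of that round trip:

```python
import torch

handle = 7                                          # e.g. a value issued by MemoryCache.allocate_cache
packed = torch.tensor([handle], dtype=torch.int64)  # int64: no unsigned 64-bit dtype available
unpacked = int(packed.item())                       # mirrors the first line of inference_step
assert unpacked == handle
```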
@@ -35,8 +35,8 @@ class MemoryCache:
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.device = device
self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
self.runtime_pid = os.getpid()

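The counters changed in this hunk live in shared memory so that all server processes see the same running size and handle counter; `c_uint64` would count just as well, but its values could not be wrapped in a torch tensor without a signed cast. A minimal sketch of the pattern (assuming callers guard updates with `lock_metadata`, since `lock=False` disables the built-in lock):

```python
import ctypes
import multiprocessing as mp
import torch

handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)   # raw shared value, no built-in lock

handle_counter.value += 1                                   # issue a fresh cache handle
handle_tensor = torch.tensor([handle_counter.value], dtype=torch.int64)  # representable as a tensor
```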
@@ -26,7 +26,8 @@ class TransformerConnectionHandler(ConnectionHandler):
assert isinstance(backend, TransformerBackend)
inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)):
async with backend.memory_cache.allocate_cache(TensorDescriptor(size=(1,2,3), dtype=torch.float32)) as handle:
inputs.append(torch.tensor([handle], dtype=torch.int64))
outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
yield runtime_pb2.ExpertResponse(tensors=outputs)

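Taken together, the flow sketched by these hunks appears to be: the connection handler allocates a cache slot, forwards its handle to the backend as an extra int64 tensor, and the backend re-opens the same slot by handle. A condensed, self-contained illustration with a plain dict standing in for MemoryCache (everything not named in the diff is a placeholder):

```python
from contextlib import contextmanager
import itertools
import torch

_cache, _handles = {}, itertools.count()

@contextmanager
def allocate_cache(descr_size, dtype):
    """Toy stand-in for MemoryCache.allocate_cache: yields an integer handle to a zero-filled buffer."""
    handle = next(_handles)
    _cache[handle] = torch.zeros(*descr_size, dtype=dtype)
    try:
        yield handle
    finally:
        _cache.pop(handle, None)  # the real cache frees the slot when the context exits

def inference_step(inputs: torch.Tensor, attention_cache_handle: torch.Tensor) -> torch.Tensor:
    handle = int(attention_cache_handle.item())  # unpack the int64 handle, as in the diff
    cache = _cache[handle]
    cache[...] += 1                              # same placeholder update as in the diff
    return inputs + cache

with allocate_cache((1, 2, 3), torch.float32) as handle:
    out = inference_step(torch.zeros(1, 2, 3), torch.tensor([handle], dtype=torch.int64))
```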