basic chained inference (multiple blocks per one RPC call)

fix-auth-token
justheuristic 2 years ago
parent ce9556dcb0
commit 1cdf8a77fb

@@ -51,8 +51,9 @@ import torch
import hivemind
from src import get_remote_module
dht = hivemind.DHT(
initial_peers=["/ip4/127.0.0.1/COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS"],
initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS], # e.g. /ip4/127.0.0.1/...
client_mode=True, start=True,
)
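
For orientation, the README snippet above could look roughly like this once a real server address is pasted in. The multiaddress and the block uid below are placeholders for illustration (the real address must be copied from one of the running servers), and the `get_remote_module(dht, uid)` call mirrors how the test at the bottom of this commit uses it; treat this as a sketch, not documented usage.

```python
import hivemind
from src import get_remote_module

dht = hivemind.DHT(
    # placeholder multiaddress: copy the full /ip4/... address printed by one of your servers
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmEXAMPLEPeerID"],
    client_mode=True, start=True,
)

# fetch a remote transformer block by uid (placeholder uid for illustration);
# get_remote_module returns None if the uid cannot be found in the DHT
remote_block = get_remote_module(dht, "bloom6b3.0")
assert remote_block is not None, "block not found in DHT"
```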

@@ -8,7 +8,7 @@ import contextlib
import ctypes
import multiprocessing as mp
import os
from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, AsyncContextManager
import hivemind
import torch
@@ -54,7 +54,7 @@ class MemoryCache:
self._handle_counter.value = value
@contextlib.asynccontextmanager
async def allocate_cache(self, descr: TensorDescriptor) -> Handle:
async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
"""
Create a handle that is associated with buffers on a unique device. If the cache is full, raises AllocationFailed.

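The new return annotation reflects how `allocate_cache` is meant to be used: it is wrapped in `asynccontextmanager`, so callers enter it with `async with`, receive the handle, and the cache memory is freed when the block exits. A minimal usage sketch, assuming `memory_cache` is a `MemoryCache` instance and using a made-up descriptor shape in place of the real `(2, 1, MAX_LENGTH, num_heads, head_dim)`:

```python
import torch
from hivemind import TensorDescriptor

async def with_attention_cache(memory_cache):
    # hypothetical shape for illustration: [key_or_value, batch_size, max_length, num_heads, head_dim]
    descr = TensorDescriptor(size=(2, 1, 2048, 32, 128), dtype=torch.float32)
    # raises AllocationFailed if the cache is full; the handle is only valid inside the block
    async with memory_cache.allocate_cache(descr) as cache_handle:
        ...  # pass cache_handle (and the current prefix length) along with each inference step
```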
@@ -1,21 +1,24 @@
from typing import AsyncIterator, Dict
import contextlib
from typing import AsyncIterator, Dict, Sequence
import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import anext
from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend
class TransformerConnectionHandler(ConnectionHandler):
"""Handles three request types: forward, backward and forward-incremental (inference)"""
module_backends: Dict[ModuleUID, TransformerBackend]
def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
for module_backend in module_backends.values():
assert isinstance(module_backend, TransformerBackend)
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
async def rpc_inference(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
@@ -24,28 +27,69 @@ class TransformerConnectionHandler(ConnectionHandler):
try:
print("OPENED RPC_INFERENCE")
request = await anext(requests)
if not request.uid:
raise RuntimeError("User did not provide any uids.")
backend = self.module_backends[request.uid]
assert isinstance(backend, TransformerBackend)
# prepare attention cache
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
requested_uids = self._check_header(request)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64) # [cache_handle, prefix_length]
cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
prefix_length = 0
async with backend.memory_cache.allocate_cache(cache_descriptor) as cache_handle:
while request.uid or request.tensors: # iterate while user is willing to supply tensors
inputs = [cache_metadata, *(deserialize_torch_tensor(tensor) for tensor in request.tensors)]
print("INPUTS:", inputs)
assert len(inputs) == 2 and inputs[1].ndim == 3, "send only hidden states for now"
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
outputs = await self._process_inputs(inputs, backend.inference_pool, backend.outputs_schema)
yield runtime_pb2.ExpertResponse(tensors=outputs)
async with self._allocate_caches(requested_backends) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# run request tensors through all requested modules, update caches
for backend, cache_handle in zip(requested_backends, cache_handles):
cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \
f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
assert isinstance(hidden_states, (list, tuple))
assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
prefix_length += inputs[1].shape[1]
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(tensors=[
serialize_torch_tensor(result, proto.compression, allow_inplace=True)
for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
])
# prepare for next step
prefix_length += hidden_states[0].shape[1]
request = await (anext(requests))
finally:
print("CLOSED RPC_INFERENCE")
def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
"""Check that the first request to rpc_inference is valid"""
uids = (request.uid or '').split(CHAIN_DELIMITER)
if not uids:
raise RuntimeError("User did not provide any uids")
for uid in uids:
if uid not in self.module_backends:
raise RuntimeError(f"Remote peer does not serve {uid}")
return tuple(uids)
@contextlib.asynccontextmanager
async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> Sequence[int]:
"""Allocate memory caches for each transformer block, return cache handles"""
async with contextlib.AsyncExitStack() as stack:
handles = []
for backend in backends:
num_heads = backend.module.self_attention.num_heads
head_dim = backend.module.self_attention.head_dim
cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
# [key_or_value, batch_size, max_length, num_heads, head_dim]
handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
yield handles
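
`_allocate_caches` leans on `contextlib.AsyncExitStack` to enter one `allocate_cache` context per requested block, keep all of them open for the lifetime of the streaming RPC, and release them together (in reverse order) when the generator exits. A self-contained toy sketch of that pattern, with a stub in place of the real `MemoryCache.allocate_cache`:

```python
import asyncio
import contextlib

@contextlib.asynccontextmanager
async def allocate_cache(uid: str):
    # stand-in for MemoryCache.allocate_cache: yield a fake handle, release it on exit
    print(f"allocated cache for {uid}")
    try:
        yield hash(uid)
    finally:
        print(f"released cache for {uid}")

@contextlib.asynccontextmanager
async def allocate_caches(uids):
    # the same pattern as _allocate_caches: enter one context per block, keep all of
    # them open until the outer context exits, then release them in reverse order
    async with contextlib.AsyncExitStack() as stack:
        handles = [await stack.enter_async_context(allocate_cache(uid)) for uid in uids]
        yield handles

async def main():
    async with allocate_caches(["bloom6b3.3", "bloom6b3.4"]) as handles:
        print("cache handles:", handles)

asyncio.run(main())
```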

@@ -0,0 +1,64 @@
######
# Warning: this test is a work in progress. It will be modified soon.
# - if you want more stable tests, see test_block_exact_match
# - if you want to figure out chained inference, ask yozh
import os
import hivemind
import torch
from hivemind.moe.expert_uid import ExpertInfo
from src.bloom.from_pretrained import load_pretrained_block
from src.client.remote_block import RemoteTransformerBlock
from src.dht_utils import get_remote_module
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()
BLOCK_UID = os.environ.get("BLOCK_UID")
if not BLOCK_UID:
raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
def test_remote_block_exact_match(atol_inference=1e-4):
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
remote_block = get_remote_module(dht, BLOCK_UID)
assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
assert isinstance(remote_block, RemoteTransformerBlock)
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info
remote_block._info = ExpertInfo('bloom6b3.3 bloom6b3.4', remote_block._info.peer_id)
inputs = torch.randn(1, 8, 4096)
outputs_inference = []
with remote_block.begin_inference_session() as sess:
for i in range(inputs.shape[1]):
outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
outputs_inference = torch.cat(outputs_inference, dim=1)
ref_blocks = [
load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32)
]
outputs_ref = []
caches = [None, None]
for i in range(inputs.shape[1]):
new_caches = []
hidden_states = inputs[:, i : i + 1, :]
for ref_block, cache in zip(ref_blocks, caches):
with torch.no_grad():
hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
new_caches.append(new_cache)
outputs_ref.append(hidden_states)
caches = new_caches
outputs_ref = torch.cat(outputs_ref, dim=1)
assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)
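
The forged `ExpertInfo` uid `'bloom6b3.3 bloom6b3.4'` is exactly the string that reaches the server's `_check_header`, which splits it on `CHAIN_DELIMITER` and validates each requested block. A stand-alone sketch of that round trip, assuming `CHAIN_DELIMITER` is a single space (as the forged uid suggests) and using a hypothetical set of served blocks:

```python
# stand-alone sketch of the server-side validation in _check_header (not the real implementation)
CHAIN_DELIMITER = " "  # assumption: a single space, matching the forged uid in the test above
served_blocks = {"bloom6b3.3", "bloom6b3.4"}  # hypothetical set of blocks this peer serves

def check_header(uid_string: str):
    uids = (uid_string or "").split(CHAIN_DELIMITER)
    if not uids or not any(uids):
        raise RuntimeError("User did not provide any uids")
    for uid in uids:
        if uid not in served_blocks:
            raise RuntimeError(f"Remote peer does not serve {uid}")
    return tuple(uids)

assert check_header("bloom6b3.3 bloom6b3.4") == ("bloom6b3.3", "bloom6b3.4")
```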