black-isort

pull/18/head
justheuristic 2 years ago
parent 894cd5d586
commit 6e3db6bed6

@@ -68,7 +68,7 @@ def main():
     compression = getattr(CompressionType, compression_type)
     use_auth_token = args.pop("use_auth_token")
-    args['use_auth_token'] = True if use_auth_token in ('True', 'true', '') else use_auth_token
+    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
     server = Server.create(**args, start=True, compression=compression)
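
Beyond the quote restyle, the changed line encodes real behavior: it coerces the CLI's string value of use_auth_token into a boolean. A minimal sketch of that coercion with hypothetical inputs ("hf_xxx" stands in for a real access token):

# An empty string or a "true"-like string becomes the boolean True;
# any other value (e.g. an actual token) passes through unchanged.
for raw in ("True", "true", "", "hf_xxx"):
    coerced = True if raw in ("True", "true", "") else raw
    print(repr(raw), "->", repr(coerced))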

@@ -34,7 +34,7 @@ def load_pretrained_block(
     block_index: int,
     config: Optional[DistributedBloomConfig] = None,
     torch_dtype: Union[torch.dtype, str] = "auto",
-    use_auth_token: Optional[str]=None
+    use_auth_token: Optional[str] = None,
 ) -> BloomBlock:
     """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
     if config is None:
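
For context, a hedged usage sketch of load_pretrained_block as its signature appears in this hunk; the import path and checkpoint name are assumptions (the test hunks further down load blocks from a REF_NAME constant the same way):

import torch
from src.bloom.from_pretrained import load_pretrained_block  # assumed module path

block = load_pretrained_block(
    "bigscience/test-bloomd-6b3",  # hypothetical converted checkpoint name
    3,  # block_index
    torch_dtype=torch.float32,  # or the "auto" default shown above
    use_auth_token=None,  # pass a token string for gated/private repos
)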

@@ -8,7 +8,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
-from typing import Dict, Optional, Union, AsyncContextManager
+from typing import AsyncContextManager, Dict, Optional, Union

 import hivemind
 import torch
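
This reorder is isort's default behavior: names inside a from-import are alphabetized, and stdlib imports stay in a separate block from third-party ones (hivemind, torch). The same transformation can be reproduced with isort's Python API:

import isort

before = "from typing import Dict, Optional, Union, AsyncContextManager\n"
print(isort.code(before))
# from typing import AsyncContextManager, Dict, Optional, Union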

@@ -14,6 +14,7 @@ from src.server.backend import MAX_LENGTH, TransformerBackend

 class TransformerConnectionHandler(ConnectionHandler):
     """Handles three request types: forward, backward and forward-incremental (inference)"""
+
     module_backends: Dict[ModuleUID, TransformerBackend]

     def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):

@@ -42,18 +43,23 @@ class TransformerConnectionHandler(ConnectionHandler):
                 # run request tensors through all requested modules, update caches
                 for backend, cache_handle in zip(requested_backends, cache_handles):
                     cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
-                    assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \
-                        f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
+                    assert (
+                        len(hidden_states) == 1 and hidden_states[0].ndim == 3
+                    ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                     hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                     assert isinstance(hidden_states, (list, tuple))
                     assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                 # serialize and send last layer outputs
-                yield runtime_pb2.ExpertResponse(tensors=[
-                    serialize_torch_tensor(result, proto.compression, allow_inplace=True)
-                    for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
-                ])
+                yield runtime_pb2.ExpertResponse(
+                    tensors=[
+                        serialize_torch_tensor(result, proto.compression, allow_inplace=True)
+                        for result, proto in zip(
+                            hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
+                        )
+                    ]
+                )

                 # prepare for next step
                 prefix_length += hidden_states[0].shape[1]
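
The assert rewrap above is black's standard output for an over-long assert-with-message: the condition is parenthesized and the message follows the closing paren. It can be reproduced with black's format_str API; line_length=120 is an assumption about this repo's configuration, inferred from the long lines black left untouched:

import black

src = (
    "assert len(hidden_states) == 1 and hidden_states[0].ndim == 3, \\\n"
    '    f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"\n'
)
print(black.format_str(src, mode=black.Mode(line_length=120)))
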
@@ -63,7 +69,7 @@ class TransformerConnectionHandler(ConnectionHandler):
     def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
         """Check that the first request to rpc_inference is valid"""
-        uids = (request.uid or '').split(CHAIN_DELIMITER)
+        uids = (request.uid or "").split(CHAIN_DELIMITER)
         if not uids:
             raise RuntimeError("User did not provide any uids")
         for uid in uids:
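
_check_header splits a single request uid into a chain of module uids. Judging by the test below, which assigns the uid string "bloom6b3.3 bloom6b3.4", CHAIN_DELIMITER appears to be a single space; a sketch under that assumption:

CHAIN_DELIMITER = " "  # assumption inferred from the test's chained uid string
request_uid = "bloom6b3.3 bloom6b3.4"
uids = (request_uid or "").split(CHAIN_DELIMITER)
assert uids == ["bloom6b3.3", "bloom6b3.4"]
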
@@ -86,11 +92,3 @@ class TransformerConnectionHandler(ConnectionHandler):
+                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
             yield handles
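
These lines come from the cache-allocation helper: each backend's attention cache is entered on a shared exit stack, so all allocations are released together when the caller leaves the context. A hedged reconstruction of that pattern (names follow the lines above; the signature itself is an assumption):

import contextlib

@contextlib.asynccontextmanager
async def _allocate_caches(backends, cache_descriptors):
    async with contextlib.AsyncExitStack() as stack:
        handles = []
        for backend, cache_descriptor in zip(backends, cache_descriptors):
            # allocate_cache(...) is an async context manager; entering it on the
            # shared stack ties the allocation's lifetime to this generator
            handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))
        yield handles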

@@ -143,7 +143,7 @@ class Server(threading.Thread):
                 block_index,
                 block_config,
                 torch_dtype=torch_dtype,
-                use_auth_token=use_auth_token
+                use_auth_token=use_auth_token,
             )
             for param in block.parameters():
                 param.requires_grad = False
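
The kept lines show the Server freezing each block right after loading; setting requires_grad = False keeps autograd from tracking or storing gradients for served weights. The same pattern on a toy module:

import torch

layer = torch.nn.Linear(4, 4)
for param in layer.parameters():
    param.requires_grad = False  # freeze: no gradients tracked for this module
assert not any(p.requires_grad for p in layer.parameters())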

@@ -34,7 +34,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     assert isinstance(remote_block, RemoteTransformerBlock)
     _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
-    remote_block._info = ExpertInfo('bloom6b3.3 bloom6b3.4', remote_block._info.peer_id)
+    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
     inputs = torch.randn(1, 8, 4096)

@@ -46,7 +46,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     ref_blocks = [
         load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32)
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
     ]
     outputs_ref = []
     caches = [None, None]
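
These reference blocks feed the exact-match check the test is named for: the remote block's outputs must agree with the local float32 blocks within atol_inference. A toy version of that final comparison:

import torch

atol_inference = 1e-4  # tolerance from the test's signature above
outputs = torch.randn(1, 8, 4096)
outputs_ref = outputs + 1e-5 * torch.randn_like(outputs)  # stand-in for the local result
assert torch.allclose(outputs, outputs_ref, atol=atol_inference)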
