black-isort

pull/9/head
justheuristic 2 years ago
parent eea9287182
commit 14e316b52a

@@ -51,4 +51,3 @@ if __name__ == "__main__":
model.transformer.h = torch.nn.ModuleList()
torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))

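For context, this first hunk is the tail of a conversion script that saves only the client-side weights: it replaces every transformer block with an empty ModuleList so that just the embeddings, final layer norm and LM head end up in client.pth. A minimal sketch of that export step, assuming a stock transformers BloomForCausalLM and a writable output directory (both the model name and the path below are placeholders):

import os

import torch
from transformers import BloomForCausalLM

model_name, output_path = "bigscience/bloom-560m", "./converted_model"  # placeholders

model = BloomForCausalLM.from_pretrained(model_name)
model.transformer.h = torch.nn.ModuleList()  # drop all transformer blocks; servers host them instead
os.makedirs(output_path, exist_ok=True)
torch.save(model.state_dict(), os.path.join(output_path, "client.pth"))
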
@@ -1,4 +1,5 @@
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add path to src
import configargparse

@@ -1 +1 @@
from .bloom import *
from .bloom import *

@@ -1 +1 @@
from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig
from src.bloom.model import BloomModel, BloomForCausalLM, DistributedBloomConfig

@@ -15,7 +15,8 @@ from src.bloom.ops import (
attention_mask_func,
dropout_add,
pre_process_alibi_for_pad,
split_tensor_along_last_dim, build_alibi_tensor,
split_tensor_along_last_dim,
build_alibi_tensor,
)

@@ -1 +1 @@
from src.client.remote_block import RemoteTransformerBlock
from src.client.remote_block import RemoteTransformerBlock

@@ -37,9 +37,11 @@ def create_remote_module(
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
if return_future:
async def _unpack(infos_future: MPFuture, dht: DHT):
p2p = await dht.replicate_p2p()
return _create_remote_experts(await infos_future, p2p)
return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
return _create_remote_experts(infos, p2p)
@@ -53,6 +55,3 @@ def _create_remote_experts(infos: Sequence[Optional[ExpertInfo]], p2p: P2P) -> List[Optional[RemoteTransformerBlock]]:
else:
experts.append(None)
return experts

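The two hunks above only re-wrap create_remote_module and _create_remote_experts; the behaviour stays the same: with return_future=True the still-pending info lookup is resolved inside RemoteExpertWorker's event loop and a Future is returned, otherwise the P2P handle is fetched synchronously and the wrappers are built immediately. A hedged usage sketch, assuming infos came from a prior DHT lookup and that the function is importable from src.client (the import path and peer address are assumptions):

from hivemind import DHT

from src.client import create_remote_module  # import path assumed for this revision

initial_peers = ["/ip4/127.0.0.1/tcp/31337/p2p/QmPeerID"]  # placeholder multiaddr
dht = DHT(initial_peers=initial_peers, start=True)
infos = ...  # Sequence[Optional[ExpertInfo]] (or an MPFuture of them) from a DHT lookup

# synchronous path: returns List[Optional[RemoteTransformerBlock]] right away
blocks = create_remote_module(infos, dht)

# deferred path: everything resolves inside RemoteExpertWorker; a Future comes back
blocks_future = create_remote_module(infos, dht, return_future=True)
blocks = blocks_future.result()
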
@@ -19,6 +19,7 @@ from src.server.cache import MemoryCache
class BloomBlockBackend(ExpertBackend):
"""A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
def __init__(self, *args, memory_cache: MemoryCache, **kwargs):
super().__init__(*args, **kwargs) # to bypass super.__init__
self.memory_cache = memory_cache
@@ -31,4 +32,3 @@ class BloomBlockBackend(ExpertBackend):
def forward_incremental(self, *inputs: torch.Tensor, attention_cache_handle: int) -> Tuple[torch.Tensor, ...]:
with self.memory_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
raise NotImplementedError("TODO")

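forward_incremental is still a stub in this commit (the body raises NotImplementedError). Purely as an illustration of what the (current_length, cached_keys, cached_values) triple yielded by use_cache is shaped for, and not the authors' implementation, here is a standalone sketch of the append-to-prefix bookkeeping an incremental attention step needs (all shapes are assumptions):

import torch

batch, heads, max_len, head_dim = 1, 16, 2048, 64  # assumed cache layout
cached_keys = torch.zeros(batch, heads, max_len, head_dim)
cached_values = torch.zeros(batch, heads, max_len, head_dim)
current_length = 0  # number of tokens already stored for this inference session

def append_kv(new_keys: torch.Tensor, new_values: torch.Tensor) -> int:
    """Write the keys/values of newly processed tokens right after the cached prefix."""
    global current_length
    n_new = new_keys.shape[2]
    cached_keys[:, :, current_length:current_length + n_new] = new_keys
    cached_values[:, :, current_length:current_length + n_new] = new_values
    current_length += n_new
    # attention for the next step then runs over cached_*[:, :, :current_length]
    return current_length
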
@@ -32,7 +32,7 @@ class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2 ** 64 - 1)
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.device = device
self.lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
self._current_size = mp.Value(ctypes.c_uint64, 0, lock=False)
@@ -77,12 +77,14 @@ class MemoryCache:
try:
async with hivemind.utils.enter_asynchronously(self.lock_metadata):
if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated.")
raise AllocationFailed(
f"Could not allocate {allocated_size_bytes} bytes in cache; cache size = "
f"{self.max_size_bytes} bytes; {self.current_size_bytes} already allocated."
)
allocated_handle = int(self.handle_counter)
self.current_size_bytes += allocated_size_bytes
self.handle_counter += 1 # note: this will eventually overflow and it is okay
self.handle_counter += 1 # note: this will eventually overflow and it is okay
self._pending_messages.value += 1
self._pipe_send.send((allocated_handle, descr))

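The allocation hunk above only re-wraps the AllocationFailed message, but it documents the allocation protocol: under the metadata lock the cache either rejects a request that would overflow max_size_bytes or hands out a monotonically increasing integer handle and forwards the tensor descriptor to the runtime process through a pipe. A hedged sketch of the caller side, assuming an async allocate_cache context manager and hivemind's TensorDescriptor (both the method name and the descriptor layout are assumptions):

import torch
from hivemind.utils import TensorDescriptor  # descriptor class assumed

async def reserve_attention_cache(memory_cache, batch: int, heads: int, max_len: int, head_dim: int):
    # one buffer holding both keys and values; dtype and layout are assumptions
    descr = TensorDescriptor(size=(2, batch, heads, max_len, head_dim), dtype=torch.float32)
    # allocate_cache is assumed to yield an int handle and to give the reserved
    # bytes back to current_size_bytes when the block exits
    async with memory_cache.allocate_cache(descr) as handle:
        ...  # pass `handle` into forward_incremental calls for this session
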
@@ -23,15 +23,24 @@ logger = get_logger(__file__)
class Server(threading.Thread):
"""Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
def __init__(
self, dht: DHT, module_backends: Dict[str, BloomBlockBackend], *,
device: torch.device, num_connection_handlers: int = 8,
update_period: float = 30, expiration: Optional[float] = None,
start: bool, **kwargs
self,
dht: DHT,
module_backends: Dict[str, BloomBlockBackend],
*,
device: torch.device,
num_connection_handlers: int = 8,
update_period: float = 30,
expiration: Optional[float] = None,
start: bool,
**kwargs,
):
threading.Thread.__init__(self)
self.dht, self.module_backends, self.update_period = dht, module_backends, update_period
self.conn_handlers = [TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)]
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
]
self.runtime = Runtime(self.module_backends, device=device, **kwargs)
self.dht_handler_thread = DHTHandlerThread(self.module_backends, dht, update_period, expiration, daemon=True)
self.checkpoint_saver = None # no need to save checkpoints since we do not change model state
@@ -71,23 +80,23 @@ class Server(threading.Thread):
# noinspection PyMethodOverriding
@classmethod
def create(
cls,
num_blocks: int,
block_config: str,
num_handlers: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 4096,
cache_size_bytes: Optional[int] = None,
device: Union[str, torch.device] = None,
initial_peers: Sequence[str] = (),
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
*,
start: bool,
**kwargs,
cls,
num_blocks: int,
block_config: str,
num_handlers: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 4096,
cache_size_bytes: Optional[int] = None,
device: Union[str, torch.device] = None,
initial_peers: Sequence[str] = (),
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
*,
start: bool,
**kwargs,
) -> Server:
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
if custom_module_path is not None:
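
Server.create bundles everything needed to serve a slice of the model, and the reformatted signature above lists its knobs (number of blocks, batching limits, cache size, device, initial DHT peers, compression, reporting). A hedged launch sketch using only the parameters shown in this signature; run_server.py is the documented entry point, and the config name and peer multiaddr below are placeholders:

server = Server.create(
    num_blocks=4,
    block_config="bigscience/bloom-6b3",  # placeholder config name
    device="cuda",
    cache_size_bytes=1 << 30,
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/QmPeerID"],  # placeholder multiaddr
    start=True,
)
server.join()  # Server is a threading.Thread; block until it shuts down
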
@@ -181,4 +190,3 @@ class Server(threading.Thread):
self.runtime.shutdown()
logger.info("Server shutdown successfully")
