Mirror of https://github.com/bigscience-workshop/petals, synced 2024-11-16 06:12:50 +00:00
expel all bloom-specific files to src.bloom
parent 324ea2dc96
commit e5e8c9ed12
@@ -0,0 +1 @@
+from .bloom import *
src/bloom/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
+from src.bloom.model import BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig
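For orientation (not part of the commit): assuming the 'from .bloom import *' hunk above lands in the top-level package __init__, downstream code can import the bloom classes either through the package root or through the new src.bloom subpackage. A minimal sketch:

# hypothetical usage after this commit (illustrative only)
from src.bloom import BloomModel, MemoryEfficientBloomConfig  # new, direct path
from src import BloomForCausalLM  # still resolves via the wildcard re-export above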
@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.ops import (
+from src.bloom.ops import (
     BloomGelu,
     BloomScaledSoftmax,
     attention_mask_func,
@@ -20,8 +20,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
 from transformers.utils import logging
 
-from src.block import BloomBlock
-from src.ops import build_alibi_tensor
+from src.bloom.block import BloomBlock
+from src.bloom.ops import build_alibi_tensor
 
 logger = logging.get_logger(__name__)
 
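Taken together, the renamed import paths above imply roughly this package layout after the commit (inferred only from the paths visible in this diff; files not touched here are omitted, and the location of the 'from .bloom import *' line is presumed to be the top-level __init__):

src/
    __init__.py        # presumed: re-exports the bloom API via 'from .bloom import *'
    bloom/
        __init__.py    # BloomModel, BloomForCausalLM, MemoryEfficientBloomConfig
        block.py       # BloomBlock
        model.py       # the model classes edited in the hunks above
        ops.py         # BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor
    node/
        __init__.py    # empty
        cache.py       # attention/memory cache (added below)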
src/node/__init__.py (new, empty file)
@@ -1,7 +1,6 @@
 """Code for serving bloom blocks via hivemind-server"""
 import contextlib
 import threading
-from typing import AsyncIterator, Tuple, List, Dict, Optional
+from typing import AsyncIterator, Tuple, Optional
 
 import torch
 from hivemind import P2PContext, DHT
@@ -13,6 +12,8 @@ from hivemind.moe.server.server import Server
 from hivemind.proto import runtime_pb2
 from torch import nn
 
+from src.node.cache import AttentionCache
+
 
 class BloomServer(Server):
     """Serves one or more bloom layers for inference, forward and backward; announces oneself to the DHT"""
@@ -62,36 +63,7 @@ class _BloomBlockBackend(ExpertBackend):
         with self.attention_cache.use_cache(attention_cache_handle) as (current_length, cached_keys, cached_values):
             raise NotImplementedError("TODO")
 
 
-class AttentionCache:
-    lock: mp.Lock
-    data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
-    @contextlib.asynccontextmanager
-    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> int:
-        """
-        Allocate buffers for attention cache on the compute device, return a unique handle;
-        This function should be called by connection handler processes, may be called concurrently
-        """
-        try:
-            async with acquire_asynchronously(self.lock):
-                handle: int = generate_unique_handle()  # or just use counter mpvalue and increment each time
-                assert handle not in data
-                self.data[handle] = todo_allocate(self, size, dtype)
-            yield handle
-        finally:
-            todo_deallocate(self, handle)
-            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
-            self.data.pop(handle, None);
-
-    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
-        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
-        with self.lock:
-            yield self.data[handle]
 
 
 
-# later:
-# - if possible, do not change how DHTHandler handles for now
-# - do not worry about OOM in cache for now! - just make sure that nothing except cache could oom.
-# - contiguous attention cache with max size
-# - select a subset of experts
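To make the handle-based flow above concrete, here is a self-contained, single-process toy sketch (illustrative names such as ToyAttentionCache and run_backend_step are not part of the project): a connection handler allocates key/value buffers asynchronously and receives an integer handle, and the runtime later looks those buffers up by that handle, which is the division of labour that the use_cache call in _BloomBlockBackend relies on.

import asyncio
import contextlib
from typing import Dict, Tuple

import torch


class ToyAttentionCache:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._counter = 0
        self._data: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    @contextlib.asynccontextmanager
    async def allocate_cache(self, num_tokens: int, hidden_size: int, dtype=torch.float32):
        """Reserve key/value buffers, yield an integer handle, free the buffers on exit."""
        async with self._lock:
            self._counter += 1
            handle = self._counter
            keys = torch.zeros(num_tokens, hidden_size, dtype=dtype)
            values = torch.zeros(num_tokens, hidden_size, dtype=dtype)
            self._data[handle] = (keys, values)
        try:
            yield handle
        finally:
            self._data.pop(handle, None)  # release the buffers when the request ends

    @contextlib.contextmanager
    def use_cache(self, handle: int):
        """Runtime side: fetch previously allocated buffers by handle."""
        yield self._data[handle]


async def handle_inference_request(cache: ToyAttentionCache):
    # "connection handler" side: allocate, then hand the handle to the runtime
    async with cache.allocate_cache(num_tokens=8, hidden_size=4) as handle:
        run_backend_step(cache, handle)


def run_backend_step(cache: ToyAttentionCache, handle: int):
    # "runtime" side: look the buffers up and write new keys/values into them
    with cache.use_cache(handle) as (keys, values):
        keys[0].fill_(1.0)
        values[0].fill_(2.0)


if __name__ == "__main__":
    asyncio.run(handle_inference_request(ToyAttentionCache()))

The point of the handle is that serialized requests only need to carry a small integer, while the cached tensors themselves never leave the server process.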
src/node/cache.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+import contextlib
+import ctypes
+import multiprocessing as mp
+from typing import Dict, Tuple
+
+import torch
+
+
+class MemoryCache:
+    lock: mp.Lock
+    runtime_pid: int
+    handle_counter: mp.Value[ctypes.c_uint64]
+    current_size: mp.Value[ctypes.c_uint64]
+    _runtime_data: Dict[int, SomeKindOfTupleWithTensors]  # workaround for now, while we are on CPU
+
+    @contextlib.asynccontextmanager
+    async def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> Optional[int]:
+        """
+        Allocate buffers for attention cache on the compute device, return a unique handle;
+        This function should be called by connection handler processes, may be called concurrently
+        """
+        assert os.getpid() != self.runtime_pid
+        try:
+            async with acquire_asynchronously(self.lock):
+                check_and_update_size(current_size, size, dtype)
+                if enough_space:
+                    self.handle_counter.value += 1
+                    handle = int(self.handle_counter.value)
+                    # note: you cannot allocate data here because this is
+                    TODO_SOMEHOW_COMUNICATE_WITH_RUNTIME_TO_CREATE_THE_RIGHT_DATA
+            yield handle
+        finally:
+            todo_deallocate(self, handle)
+            # ^-- this should NOT move any data. But it may mark data for movement during next allocation
+            self.data.pop(handle, None);
+
+    def use_cache(self, handle: int) -> Tuple[mp.Value, torch.Tensor, torch.Tensor]:
+        """Return a previously allocated cache, called by ExpertBackend in runtime (a single process)"""
+        assert os.getpid() == self.runtime_pid
+        with self.lock:
+            if first_time:
+                allocate_stuff(self._runtime_data)
+            yield self.data[handle]
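The new MemoryCache above is still a sketch (acquire_asynchronously, check_and_update_size, todo_deallocate and the TODO placeholder are not defined). A minimal, runnable illustration of the bookkeeping it describes follows, under simplifying assumptions: no asyncio, no real device allocation, and a ToyMemoryCache name to make clear it is not the project's implementation. Handles and the running byte count live in shared memory so connection-handler processes can reserve space concurrently, while the actual tensors are created lazily in the runtime process only, mirroring the two os.getpid() asserts above.

import ctypes
import multiprocessing as mp
import os
from typing import Dict, Optional, Tuple

import torch


class ToyMemoryCache:
    def __init__(self, max_size_bytes: int):
        self.lock = mp.Lock()
        self.runtime_pid = os.getpid()  # assume the cache object is created by the runtime process
        self.max_size_bytes = max_size_bytes
        self.handle_counter = mp.Value(ctypes.c_uint64, 0, lock=False)  # protected by self.lock
        self.current_size = mp.Value(ctypes.c_uint64, 0, lock=False)    # bytes currently reserved
        self._manager = mp.Manager()
        self._descriptors = self._manager.dict()  # handle -> (shape, dtype name, nbytes), shared
        self._runtime_data: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}  # runtime process only

    def allocate_cache(self, size: torch.Size, dtype: torch.dtype) -> Optional[int]:
        """Connection-handler side: reserve space for key+value buffers, return a handle or None."""
        nbytes = 2 * int(torch.Size(size).numel()) * torch.empty((), dtype=dtype).element_size()
        with self.lock:
            if self.current_size.value + nbytes > self.max_size_bytes:
                return None  # not enough space; the caller should reject the request
            self.current_size.value += nbytes
            self.handle_counter.value += 1
            handle = int(self.handle_counter.value)
            self._descriptors[handle] = (tuple(size), str(dtype).replace("torch.", ""), nbytes)
            return handle

    def use_cache(self, handle: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Runtime side: materialize the buffers lazily on first use and return (keys, values)."""
        assert os.getpid() == self.runtime_pid
        if handle not in self._runtime_data:  # first use: create the tensors here, not in the handler
            shape, dtype_name, _ = self._descriptors[handle]
            dtype = getattr(torch, dtype_name)
            self._runtime_data[handle] = (torch.zeros(shape, dtype=dtype), torch.zeros(shape, dtype=dtype))
        return self._runtime_data[handle]

    def free(self, handle: int) -> None:
        """Release the reservation; does not move any other cached data."""
        with self.lock:
            _, _, nbytes = self._descriptors.pop(handle)
            self.current_size.value -= nbytes
        self._runtime_data.pop(handle, None)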