Add memory cache usage

Artem Chumachenko 2023-09-06 12:10:32 +04:00
parent 01c3cf8d15
commit 1b21dd3217
5 changed files with 63 additions and 32 deletions

View File

@@ -57,15 +57,15 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
-        device = self.module.devices[self.module.output_device_index]
+        self.device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
         )
         self.dtype = backend_dtype
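
Storing the output device on the backend (self.device) lets other server components build cache allocations that land on the same device as the block's outputs. A minimal sketch of that pattern, assuming a hivemind-style TensorDescriptor and a backend exposing .device (the helper name make_reservation_descriptor is illustrative, not from this commit):

import torch
from hivemind import TensorDescriptor


def make_reservation_descriptor(backend, nbytes: int) -> TensorDescriptor:
    # A flat int8 tensor with `nbytes` elements reserves exactly `nbytes` bytes
    # in the shared memory cache, on the same device as the block's outputs.
    return TensorDescriptor(size=(nbytes,), dtype=torch.int8, device=backend.device)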

View File

@@ -15,6 +15,7 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     PeerID,
+    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
+                ) as cache_handles, self._load_peft_module(
+                    requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
+                ):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
-        if active_adapter and (active_adapter not in self.adapters):
-            raise KeyError(f"adapter {active_adapter} not found")
-        return active_adapter
+        if active_adapter:
+            return active_adapter
+        return ""
 
     def _serialize_grads(
         self,
@@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield nested_pack(handles, descriptors)
 
     @contextlib.asynccontextmanager
-    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+    async def _load_peft_module(
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        active_adapter: str,
+        timeout: float,
+    ):
         if active_adapter == "":
             yield
         elif active_adapter in self.adapters:
             yield
         else:
-            try:
-                _peft_module = backends[0]._peft_module
-                token = None  # TODO: Provide token from user request maybe?
-                for backend in backends:
-                    adapter_config, adapter_state_dict = _peft_module.load_peft(
-                        active_adapter,
-                        block_idx=backend.block_index,
-                        token=token,
-                        cache_dir=backend.cache_dir,
-                        max_disk_space=backend.max_disk_space,
-                    )
-                    _peft_module.add_adapter_to_block(
-                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
-                    )
-            finally:
-                for backend in backends:
-                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+            _peft_module = backends[0]._peft_module
+            token = None  # TODO: Provide token from user request maybe?
+
+            estimated_peft_size = _peft_module.get_estimated_peft_module_size(
+                active_adapter,
+                token=token,
+            )
+
+            fake_descriptor = TensorDescriptor(
+                size=(estimated_peft_size,),
+                dtype=torch.int8,
+                device=backends[0].device,
+            )
+
+            async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
+                try:
+                    for backend in backends:
+                        adapter_config, adapter_state_dict = _peft_module.load_peft(
+                            active_adapter,
+                            block_idx=backend.block_index,
+                            token=token,
+                            cache_dir=backend.cache_dir,
+                            max_disk_space=backend.max_disk_space,
+                        )
+                        _peft_module.add_adapter_to_block(
+                            backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                        )
+                finally:
+                    for backend in backends:
+                        _peft_module.remove_adapter_from_block(backend.module, active_adapter)
 
     def _log_request(
         self,
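
The reworked _load_peft_module reserves the adapter's estimated footprint in the server's shared memory cache before any weights are downloaded, so adapter memory and attention caches compete under a single budget and are released together. A minimal sketch of that allocate-then-load pattern, assuming a MemoryCache exposing an async allocate_cache(descriptor, timeout=...) context manager as used above (the helper name reserve_bytes is illustrative):

import contextlib

import torch
from hivemind import TensorDescriptor


@contextlib.asynccontextmanager
async def reserve_bytes(memory_cache, nbytes: int, device, timeout: float):
    # The reservation is a "fake" int8 tensor: only the byte count matters, so
    # concurrent attention-cache allocations see the adapter's footprint too.
    descriptor = TensorDescriptor(size=(nbytes,), dtype=torch.int8, device=device)
    async with memory_cache.allocate_cache(descriptor, timeout=timeout):
        yield  # download and attach adapters here; the bytes stay reserved until exit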

View File

@@ -231,6 +231,8 @@ class Server:
         gib = 1024**3
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+        self.adapters_cache_bytes = self.attn_cache_bytes
+        logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
         if throughput in ["auto", "eval", "dry_run"]:
@@ -335,6 +337,7 @@
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
+            adapters_cache_bytes=self.adapters_cache_bytes,
             server_info=self.server_info,
             model_info=self.model_info,
             block_indices=block_indices,
@@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
        converted_model_name_or_path: str,
        block_config: PretrainedConfig,
        attn_cache_bytes: int,
+       adapters_cache_bytes: int,
        server_info: ServerInfo,
        model_info: ModelInfo,
        block_indices: List[int],
@@ -464,7 +468,7 @@
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
@@ -517,6 +521,8 @@
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,
                     max_chunk_size_bytes=max_chunk_size_bytes,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
                     args_schema=(
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
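
Since adapters_cache_bytes mirrors the attention budget and the two are summed when the MemoryCache is created, the shared cache is effectively sized at twice the attention-cache estimate. A back-of-the-envelope sketch (the per-block figure is an illustrative assumption, not a value from this commit):

gib = 1024**3

cache_bytes_per_block = int(0.5 * gib)  # assumption: ~0.5 GiB of attention cache per block
num_blocks = 8

attn_cache_bytes = cache_bytes_per_block * num_blocks        # 4.00 GiB
adapters_cache_bytes = attn_cache_bytes                      # mirrors the attention budget
total_cache_bytes = attn_cache_bytes + adapters_cache_bytes  # MemoryCache budget: 8.00 GiB

print(f"attention: {attn_cache_bytes / gib:.2f} GiB, "
      f"adapters: {adapters_cache_bytes / gib:.2f} GiB, "
      f"total: {total_cache_bytes / gib:.2f} GiB")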

View File

@@ -111,11 +111,6 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
                 servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
                 logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

View File

@@ -128,6 +128,15 @@ def load_peft(
             time.sleep(delay)
 
 
+def get_estimated_peft_module_size(
+    repo_id: str,
+    revision: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
+):
+    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+    return get_hf_file_metadata(weight_url, token=token).size
+
+
 class AdapterContextMixin:
     """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""