Add memory cache usage

pull/506/head
Artem Chumachenko 8 months ago
parent 01c3cf8d15
commit 1b21dd3217

@ -57,15 +57,15 @@ class TransformerBackend(ModuleBackend):
assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
)
self.dtype = backend_dtype

@ -15,6 +15,7 @@ from hivemind import (
MSGPackSerializer,
P2PContext,
PeerID,
TensorDescriptor,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
) as cache_handles, self._load_peft_module(
requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
):
background_tasks = set()
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
def _get_active_adapter(self, metadata: dict) -> str:
active_adapter = metadata.get("active_adapter", "")
if active_adapter and (active_adapter not in self.adapters):
raise KeyError(f"adapter {active_adapter} not found")
return active_adapter
if active_adapter:
return active_adapter
return ""
def _serialize_grads(
self,
@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
yield nested_pack(handles, descriptors)
@contextlib.asynccontextmanager
async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
async def _load_peft_module(
self,
backends: Sequence[TransformerBackend],
*,
active_adapter: str,
timeout: float,
):
if active_adapter == "":
yield
elif active_adapter in self.adapters:
yield
else:
try:
_peft_module = backends[0]._peft_module
token = None # TODO: Provide token from user request maybe?
for backend in backends:
adapter_config, adapter_state_dict = _peft_module.load_peft(
active_adapter,
block_idx=backend.block_index,
token=token,
cache_dir=backend.cache_dir,
max_disk_space=backend.max_disk_space,
)
_peft_module = backends[0]._peft_module
token = None # TODO: Provide token from user request maybe?
_peft_module.add_adapter_to_block(
backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
)
finally:
for backend in backends:
_peft_module.remove_adapter_from_block(backend.module, active_adapter)
estimated_peft_size = _peft_module.get_estimated_peft_module_size(
active_adapter,
token=token,
)
fake_descriptor = TensorDescriptor(
size=(estimated_peft_size,),
dtype=torch.int8,
device=backends[0].device,
)
async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
try:
for backend in backends:
adapter_config, adapter_state_dict = _peft_module.load_peft(
active_adapter,
block_idx=backend.block_index,
token=token,
cache_dir=backend.cache_dir,
max_disk_space=backend.max_disk_space,
)
_peft_module.add_adapter_to_block(
backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
)
finally:
for backend in backends:
_peft_module.remove_adapter_from_block(backend.module, active_adapter)
def _log_request(
self,

@ -231,6 +231,8 @@ class Server:
gib = 1024**3
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
self.adapters_cache_bytes = self.attn_cache_bytes
logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
if throughput in ["auto", "eval", "dry_run"]:
@ -335,6 +337,7 @@ class Server:
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
adapters_cache_bytes=self.adapters_cache_bytes,
server_info=self.server_info,
model_info=self.model_info,
block_indices=block_indices,
@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
adapters_cache_bytes: int,
server_info: ServerInfo,
model_info: ModelInfo,
block_indices: List[int],
@ -464,7 +468,7 @@ class ModuleContainer(threading.Thread):
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
@ -517,6 +521,8 @@ class ModuleContainer(threading.Thread):
memory_cache=memory_cache,
backend_dtype=torch_dtype,
max_chunk_size_bytes=max_chunk_size_bytes,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression

@ -111,11 +111,6 @@ async def _get_remote_module_infos(
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

@ -128,6 +128,15 @@ def load_peft(
time.sleep(delay)
def get_estimated_peft_module_size(
repo_id: str,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
):
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
return get_hf_file_metadata(weight_url, token=token).size
class AdapterContextMixin:
"""A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

Loading…
Cancel
Save