Mirror of https://github.com/bigscience-workshop/petals
Add memory cache usage
Commit 1b21dd3217 (parent 01c3cf8d15)
@@ -57,15 +57,15 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
-        device = self.module.devices[self.module.output_device_index]
+        self.device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype
@@ -15,6 +15,7 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     PeerID,
+    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             async with self._allocate_cache(
                 requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-            ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
+            ) as cache_handles, self._load_peft_module(
+                requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
+            ):
                 background_tasks = set()
                 async for output_tensors, can_push in iterate_rpc_inference(
                     requested_uids=requested_uids,
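A note on the hunk above: the chained `async with self._allocate_cache(...) as cache_handles, self._load_peft_module(...)` enters both asynchronous context managers left to right and exits them right to left, so the on-demand adapter is detached before the attention cache is released. A minimal, self-contained sketch of that pattern with toy names (not Petals code):

import asyncio
import contextlib


@contextlib.asynccontextmanager
async def allocate_cache(nbytes: int):
    # Toy stand-in for reserving attention-cache memory.
    print(f"reserved {nbytes} bytes")
    try:
        yield "cache-handle"
    finally:
        print("released cache")


@contextlib.asynccontextmanager
async def load_adapter(name: str):
    # Toy stand-in for attaching a LoRA adapter to the hosted blocks.
    print(f"attached adapter {name!r}")
    try:
        yield
    finally:
        print(f"detached adapter {name!r}")


async def main():
    # Entered left to right and exited right to left, mirroring the chained
    # `async with` in rpc_inference above.
    async with allocate_cache(1 << 20) as handle, load_adapter("demo-lora"):
        print("running inference with", handle)


asyncio.run(main())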
@@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
-        if active_adapter and (active_adapter not in self.adapters):
-            raise KeyError(f"adapter {active_adapter} not found")
-        return active_adapter
+        if active_adapter:
+            return active_adapter
+        return ""
 
     def _serialize_grads(
         self,
@@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
         yield nested_pack(handles, descriptors)
 
     @contextlib.asynccontextmanager
-    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+    async def _load_peft_module(
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        active_adapter: str,
+        timeout: float,
+    ):
         if active_adapter == "":
             yield
         elif active_adapter in self.adapters:
             yield
         else:
-            try:
-                _peft_module = backends[0]._peft_module
-                token = None  # TODO: Provide token from user request maybe?
-
-                for backend in backends:
-                    adapter_config, adapter_state_dict = _peft_module.load_peft(
-                        active_adapter,
-                        block_idx=backend.block_index,
-                        token=token,
-                        cache_dir=backend.cache_dir,
-                        max_disk_space=backend.max_disk_space,
-                    )
-
-                    _peft_module.add_adapter_to_block(
-                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
-                    )
-            finally:
-                for backend in backends:
-                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+            _peft_module = backends[0]._peft_module
+            token = None  # TODO: Provide token from user request maybe?
+
+            estimated_peft_size = _peft_module.get_estimated_peft_module_size(
+                active_adapter,
+                token=token,
+            )
+
+            fake_descriptor = TensorDescriptor(
+                size=(estimated_peft_size,),
+                dtype=torch.int8,
+                device=backends[0].device,
+            )
+
+            async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
+                try:
+                    for backend in backends:
+                        adapter_config, adapter_state_dict = _peft_module.load_peft(
+                            active_adapter,
+                            block_idx=backend.block_index,
+                            token=token,
+                            cache_dir=backend.cache_dir,
+                            max_disk_space=backend.max_disk_space,
+                        )
+
+                        _peft_module.add_adapter_to_block(
+                            backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                        )
+                finally:
+                    for backend in backends:
+                        _peft_module.remove_adapter_from_block(backend.module, active_adapter)
 
     def _log_request(
         self,
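A note on the reservation pattern above: `get_estimated_peft_module_size` returns the adapter's weight-file size in bytes, and because `torch.int8` occupies one byte per element, the placeholder `TensorDescriptor` of shape `(estimated_peft_size,)` asks the server's `MemoryCache` for exactly that many bytes. Holding this allocation for the lifetime of the request makes adapter weights and attention caches compete for a single byte budget. Below is a heavily simplified, self-contained sketch of the same idea (a toy byte budget, not the Petals `MemoryCache`); names and numbers are illustrative only:

import asyncio
import contextlib


class ToyMemoryBudget:
    """A drastically simplified stand-in for a server-side memory cache:
    a shared byte budget with awaitable reservations."""

    def __init__(self, total_bytes: int):
        self.total_bytes = total_bytes
        self.used_bytes = 0
        self._changed = asyncio.Condition()

    @contextlib.asynccontextmanager
    async def allocate(self, nbytes: int):
        # Block until the requested bytes fit (the real cache also enforces a timeout).
        async with self._changed:
            await self._changed.wait_for(lambda: self.used_bytes + nbytes <= self.total_bytes)
            self.used_bytes += nbytes
        try:
            yield
        finally:
            async with self._changed:
                self.used_bytes -= nbytes
                self._changed.notify_all()


async def serve_request(budget: ToyMemoryBudget, estimated_adapter_bytes: int):
    # Reserve the adapter's estimated footprint before downloading/attaching it,
    # so adapters and attention caches draw from one shared budget.
    async with budget.allocate(estimated_adapter_bytes):
        print(f"adapter resident, budget used: {budget.used_bytes}/{budget.total_bytes} bytes")


asyncio.run(serve_request(ToyMemoryBudget(total_bytes=1 << 30), estimated_adapter_bytes=64 << 20))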
@@ -231,6 +231,8 @@ class Server:
         gib = 1024**3
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+        self.adapters_cache_bytes = self.attn_cache_bytes
+        logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
         if throughput in ["auto", "eval", "dry_run"]:
@@ -335,6 +337,7 @@ class Server:
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
+            adapters_cache_bytes=self.adapters_cache_bytes,
             server_info=self.server_info,
             model_info=self.model_info,
             block_indices=block_indices,
@@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
+        adapters_cache_bytes: int,
         server_info: ServerInfo,
         model_info: ModelInfo,
         block_indices: List[int],
@@ -464,7 +468,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
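For scale, the hunks above double the server's cache budget: `adapters_cache_bytes` is set equal to `attn_cache_bytes`, and the single `MemoryCache` is now sized to their sum. A quick back-of-the-envelope calculation with hypothetical figures (the real values depend on the model and the number of hosted blocks):

# Hypothetical figures, only to make the budget arithmetic concrete.
gib = 1024**3
cache_bytes_per_block = gib // 2   # assumed per-block attention-cache estimate
num_blocks = 8                     # assumed number of hosted blocks

attn_cache_bytes = cache_bytes_per_block * num_blocks          # 4.00 GiB
adapters_cache_bytes = attn_cache_bytes                        # 4.00 GiB, same budget again
memory_cache_bytes = attn_cache_bytes + adapters_cache_bytes   # 8.00 GiB handed to MemoryCache

print(f"attention: {attn_cache_bytes / gib:.2f} GiB, "
      f"adapters: {adapters_cache_bytes / gib:.2f} GiB, "
      f"total MemoryCache: {memory_cache_bytes / gib:.2f} GiB")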
@@ -517,6 +521,8 @@ class ModuleContainer(threading.Thread):
                 memory_cache=memory_cache,
                 backend_dtype=torch_dtype,
                 max_chunk_size_bytes=max_chunk_size_bytes,
+                cache_dir=cache_dir,
+                max_disk_space=max_disk_space,
                 args_schema=(
                     BatchTensorDescriptor(
                         1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
@@ -111,11 +111,6 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
                 servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
                 logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
@@ -128,6 +128,15 @@ def load_peft(
             time.sleep(delay)
 
 
+def get_estimated_peft_module_size(
+    repo_id: str,
+    revision: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
+):
+    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+    return get_hf_file_metadata(weight_url, token=token).size
+
+
 class AdapterContextMixin:
     """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
 
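A note on the new helper above: it estimates an adapter's footprint without downloading it, by asking the Hugging Face Hub for the HTTP metadata (including the size) of the adapter's safetensors weight file. A standalone sketch of an equivalent lookup using only `huggingface_hub`; the file-name constant and the repo id below are assumptions for illustration:

from typing import Optional, Union

from huggingface_hub import get_hf_file_metadata, hf_hub_url

# PEFT adapters are conventionally stored as "adapter_model.safetensors";
# the diff refers to this name via SAFETENSORS_WEIGHTS_NAME instead of hard-coding it.
ADAPTER_WEIGHTS_FILE = "adapter_model.safetensors"


def estimated_adapter_bytes(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[Union[str, bool]] = None,
):
    # Only HTTP metadata is fetched here; the weights themselves are not downloaded.
    weight_url = hf_hub_url(repo_id, ADAPTER_WEIGHTS_FILE, revision=revision)
    return get_hf_file_metadata(weight_url, token=token).size


if __name__ == "__main__":
    # Hypothetical adapter repo id, for illustration only.
    print(estimated_adapter_bytes("some-user/some-lora-adapter"))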