diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index f95b89d..9987a79 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -57,15 +57,15 @@ class TransformerBackend(ModuleBackend):
             assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
-        device = self.module.devices[self.module.output_device_index]
+        self.device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 100d808..ef0e9b9 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -15,6 +15,7 @@ from hivemind import (
     MSGPackSerializer,
     P2PContext,
     PeerID,
+    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -170,7 +171,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
+                ) as cache_handles, self._load_peft_module(
+                    requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
+                ):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
@@ -490,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
 
     def _get_active_adapter(self, metadata: dict) -> str:
         active_adapter = metadata.get("active_adapter", "")
-        if active_adapter and (active_adapter not in self.adapters):
-            raise KeyError(f"adapter {active_adapter} not found")
-        return active_adapter
+        if active_adapter:
+            return active_adapter
+        return ""
 
     def _serialize_grads(
         self,
@@ -548,31 +551,49 @@ class TransformerConnectionHandler(ConnectionHandler):
             yield nested_pack(handles, descriptors)
 
     @contextlib.asynccontextmanager
-    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+    async def _load_peft_module(
+        self,
+        backends: Sequence[TransformerBackend],
+        *,
+        active_adapter: str,
+        timeout: float,
+    ):
         if active_adapter == "":
             yield
         elif active_adapter in self.adapters:
             yield
         else:
-            try:
-                _peft_module = backends[0]._peft_module
-                token = None  # TODO: Provide token from user request maybe?
-
-                for backend in backends:
-                    adapter_config, adapter_state_dict = _peft_module.load_peft(
-                        active_adapter,
-                        block_idx=backend.block_index,
-                        token=token,
-                        cache_dir=backend.cache_dir,
-                        max_disk_space=backend.max_disk_space,
-                    )
+            _peft_module = backends[0]._peft_module
+            token = None  # TODO: Provide token from user request maybe?
 
-                    _peft_module.add_adapter_to_block(
-                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
-                    )
-            finally:
-                for backend in backends:
-                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+            estimated_peft_size = _peft_module.get_estimated_peft_module_size(
+                active_adapter,
+                token=token,
+            )
+
+            fake_descriptor = TensorDescriptor(
+                size=(estimated_peft_size,),
+                dtype=torch.int8,
+                device=backends[0].device,
+            )
+
+            async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
+                try:
+                    for backend in backends:
+                        adapter_config, adapter_state_dict = _peft_module.load_peft(
+                            active_adapter,
+                            block_idx=backend.block_index,
+                            token=token,
+                            cache_dir=backend.cache_dir,
+                            max_disk_space=backend.max_disk_space,
+                        )
+
+                        _peft_module.add_adapter_to_block(
+                            backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                        )
+                finally:
+                    for backend in backends:
+                        _peft_module.remove_adapter_from_block(backend.module, active_adapter)
 
     def _log_request(
         self,
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index c17de6b..5370c99 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -231,6 +231,8 @@ class Server:
         gib = 1024**3
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+        self.adapters_cache_bytes = self.attn_cache_bytes
+        logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
         if throughput in ["auto", "eval", "dry_run"]:
@@ -335,6 +337,7 @@ class Server:
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
             attn_cache_bytes=self.attn_cache_bytes,
+            adapters_cache_bytes=self.adapters_cache_bytes,
             server_info=self.server_info,
             model_info=self.model_info,
             block_indices=block_indices,
@@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
+        adapters_cache_bytes: int,
         server_info: ServerInfo,
         model_info: ModelInfo,
         block_indices: List[int],
@@ -464,7 +468,7 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
@@ -517,6 +521,8 @@ class ModuleContainer(threading.Thread):
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,
                     max_chunk_size_bytes=max_chunk_size_bytes,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
                     args_schema=(
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
diff --git a/src/petals/utils/dht.py b/src/petals/utils/dht.py
index 0710f60..6a9952d 100644
--- a/src/petals/utils/dht.py
+++ b/src/petals/utils/dht.py
@@ -111,11 +111,6 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
                 servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
                 logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 56bec87..76d00ad 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -128,6 +128,15 @@ def load_peft(
             time.sleep(delay)
 
 
+def get_estimated_peft_module_size(
+    repo_id: str,
+    revision: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
+):
+    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+    return get_hf_file_metadata(weight_url, token=token).size
+
+
 class AdapterContextMixin:
     """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
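How the pieces above compose, as a minimal sketch: the adapter's safetensors checkpoint size is fetched from the Hub via the new `get_estimated_peft_module_size` helper and booked in the shared `MemoryCache` as a flat int8 tensor before any weights are downloaded, mirroring the new `_load_peft_module` flow. The helper name `reserve_adapter_memory` and the repo id `"my-org/my-lora"` are illustrative placeholders, not part of this diff.

```python
# Illustrative sketch only; mirrors the `_load_peft_module` change above.
# Assumes `backend` is a TransformerBackend with the `device` attribute and shared
# `memory_cache` introduced in this diff; "my-org/my-lora" is a placeholder repo id.
import torch
from hivemind import TensorDescriptor

from petals.utils.peft import get_estimated_peft_module_size


async def reserve_adapter_memory(backend, adapter_repo: str = "my-org/my-lora", timeout: float = 10.0):
    # Size (in bytes) of the adapter's safetensors checkpoint on the Hub.
    num_bytes = get_estimated_peft_module_size(adapter_repo)

    # Reserve the same number of bytes in the shared cache as a flat int8 tensor, so
    # adapter loading and attention caches draw from the same memory budget.
    descriptor = TensorDescriptor(size=(num_bytes,), dtype=torch.int8, device=backend.device)
    async with backend.memory_cache.allocate_cache(descriptor, timeout=timeout):
        ...  # download the adapter and attach it to the block while the reservation is held
```

The reservation is released when the context manager exits, which lines up with the `finally` block in `_load_peft_module` that detaches the adapter from each block.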