From 01c3cf8d1567f3632cfe1cfb1e280cfaa54dfe9f Mon Sep 17 00:00:00 2001
From: Artem Chumachenko
Date: Wed, 6 Sep 2023 10:46:10 +0400
Subject: [PATCH] Add first version

---
 src/petals/server/backend.py      |  6 ++++++
 src/petals/server/handler.py      | 33 +++++++++++++++++++++++++++++++--
 src/petals/server/server.py       |  3 +++
 src/petals/utils/convert_block.py |  7 ++++---
 src/petals/utils/peft.py          | 16 ++++++++++++++++
 5 files changed, 60 insertions(+), 5 deletions(-)

diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 3a9b63e..f95b89d 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
     def __init__(
         self,
         *args,
+        block_index: int,
         config: PretrainedConfig,
         memory_cache: MemoryCache,
         backend_dtype: torch.dtype,
         max_chunk_size_bytes: int,
+        cache_dir: str,
+        max_disk_space: int,
         **kwargs,
     ):
         import petals.utils.peft as _peft_module
@@ -41,9 +44,12 @@ class TransformerBackend(ModuleBackend):
 
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
+        self.block_index = block_index
         self.config = config
         self.memory_cache = memory_cache
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
 
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d8f0ec0..100d808 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -152,6 +152,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 session_id = metadata.get("session_id")
                 alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
+                active_adapter = self._get_active_adapter(metadata)
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -169,12 +170,12 @@ class TransformerConnectionHandler(ConnectionHandler):
 
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles:
+                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
-                        active_adapter=self._get_active_adapter(metadata),
+                        active_adapter=active_adapter,
                         input_iterator=self._iterate_inference_steps(
                             request, requests, session_id, requested_uids, context
                         ),
@@ -546,6 +547,34 @@ class TransformerConnectionHandler(ConnectionHandler):
         async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
+    @contextlib.asynccontextmanager
+    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+        if active_adapter == "":
+            yield
+        elif active_adapter in self.adapters:
+            yield
+        else:
+            try:
+                _peft_module = backends[0]._peft_module
+                token = None  # TODO: Maybe pass a token from the user request?
+
+                for backend in backends:
+                    adapter_config, adapter_state_dict = _peft_module.load_peft(
+                        active_adapter,
+                        block_idx=backend.block_index,
+                        token=token,
+                        cache_dir=backend.cache_dir,
+                        max_disk_space=backend.max_disk_space,
+                    )
+
+                    _peft_module.add_adapter_to_block(
+                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                    )
+                yield
+            finally:
+                for backend in backends:
+                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+
     def _log_request(
         self,
         method: str,
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index fd9f766..c17de6b 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -512,6 +512,9 @@ class ModuleContainer(threading.Thread):
             blocks[module_uid] = TransformerBackend(
                 module_uid,
                 block,
+                block_index=block_index,
+                cache_dir=cache_dir,
+                max_disk_space=max_disk_space,
                 config=block_config,
                 memory_cache=memory_cache,
                 backend_dtype=torch_dtype,
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 94d3e29..30aa969 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -58,10 +58,11 @@ def convert_block(
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
-    if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+    from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
+    create_lora_adapter(block, quant_type=quant_type)
 
-        create_lora_adapter(block, quant_type=quant_type)
+    if adapters:
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index e4d29fc..56bec87 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -267,6 +267,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
     logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
+def remove_adapter_from_block(block, adapter_name):
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if adapter_name in child.lora_A:
+                del child.lora_A[adapter_name]
+            if adapter_name in child.lora_B:
+                del child.lora_B[adapter_name]
+
+    # TODO: check whether this is needed
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def estimate_adapter_memory_per_block(
     block_config: transformers.PretrainedConfig,
     torch_dtype: Optional[torch.dtype],
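
Usage sketch (not part of the patch): the snippet below illustrates the load/attach/detach cycle that `_load_peft_module` performs for every backend of an inference session, applied to a single block. It is a minimal sketch under stated assumptions: `block` stands for a server-side transformer block that has already passed through `convert_block` (so `create_lora_adapter` was applied), and the adapter repo id, cache path, and disk quota are placeholder values.

    # Sketch only. Assumes `block` is an already-converted transformer block;
    # the repo id, cache_dir, and max_disk_space below are placeholder values.
    from petals.utils.peft import add_adapter_to_block, load_peft, remove_adapter_from_block

    adapter_name = "example-org/my-lora-adapter"  # placeholder PEFT repo id
    block_index = 0

    adapter_config, adapter_state_dict = load_peft(
        adapter_name,
        block_idx=block_index,
        token=None,
        cache_dir="./petals_cache",
        max_disk_space=10 * 1024**3,  # bytes
    )
    try:
        # Attach the adapter, run the inference steps that need it, then detach it,
        # mirroring the try/finally in TransformerConnectionHandler._load_peft_module.
        add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
        # ... run inference with active_adapter=adapter_name ...
    finally:
        remove_adapter_from_block(block, adapter_name)

The try/finally mirrors the handler: an adapter loaded on demand is always detached from the block once the session ends, while adapters preloaded at server start (those in `self.adapters`) are left untouched.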