From 37fdcb3fe066a45ae80c3419cc60c658cbcbb594 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Fri, 14 Jul 2023 22:04:55 +0300 Subject: [PATCH] Switch adapters slightly faster (#353) Currently, each `TransformerBackend.inference_step` looks for adapters and sets the correct adapter type for each block. This is not very expensive, but it can measurably affect inference time. This pull request uses faster adapter switching with just one variable assignment, without iterating over block.modules(). --- src/petals/server/backend.py | 35 +++++++++-------------- src/petals/server/handler.py | 18 ++++++++---- src/petals/server/server.py | 1 + src/petals/utils/peft.py | 55 +++++++++++++++++++++++++++++++----- 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 51c6ee0..4220546 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -24,9 +24,15 @@ logger = get_logger(__name__) class TransformerBackend(ModuleBackend): """A wrapper for a transformer block that can process requests for forward, backward and inference""" + _peft_module = None + def __init__( self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs ): + import petals.utils.peft as _peft_module + + self._peft_module = _peft_module + super().__init__(*args, **kwargs) assert isinstance(self.module, TensorParallel) self.config = config @@ -82,13 +88,13 @@ class TransformerBackend(ModuleBackend): def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - self.load_adapter_(active_adapter) - return super().forward(*inputs) + with self._peft_module.using_adapter(active_adapter): + return super().forward(*inputs) def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: *inputs, active_adapter = inputs - self.load_adapter_(active_adapter) - return super().backward(*inputs) + with self._peft_module.using_adapter(active_adapter): + return super().backward(*inputs) @torch.inference_mode() def inference_step( @@ -98,8 +104,9 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" - self.load_adapter_(inference_info.active_adapter) - with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: + with self.memory_cache.use_cache( + *inference_info.cache_handles + ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter): self._reorder_cache_inplace(cache_tensors, hypo_ids) layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True) @@ -150,22 +157,6 @@ class TransformerBackend(ModuleBackend): for p in self.module.parameters(): p.data = dummy - def load_adapter_(self, active_adapter: Optional[str] = None) -> bool: - """Activate a given adapter set if available. Return True if available (or no adapter), False if missing""" - - # Import petals.utils.peft only when necessary to avoid importing bitsandbytes - from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt - - loaded = False - for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter - if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)): - layer.active_adapter = active_adapter # empty string for no adapter - if active_adapter in layer.lora_A.keys(): - loaded = True - - if active_adapter and not loaded: - raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded") - def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index d9a5025..12fd6eb 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler): dht: DHT, module_backends: Dict[str, TransformerBackend], *, + adapters: Optional[Sequence[str]], dht_prefix: str, push_manager: multiprocessing.managers.SyncManager, session_queues: Dict[str, multiprocessing.managers.BaseProxy], # BaseProxy for queue.Queue @@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler): for module_backend in self.module_backends.values(): assert isinstance(module_backend, TransformerBackend) self.dht_prefix = dht_prefix + self.adapters = adapters self._push_manager = push_manager self._session_queues = session_queues self._executor = ThreadPoolExecutor(max_workers=float("inf")) # For waiting on self.session_queues @@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) session_id = metadata.get("session_id") @@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) - active_adapter = metadata.get("active_adapter", "") + active_adapter = self._get_active_adapter(metadata) points = metadata.get("points", 0) assert isinstance( points, (float, int) @@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler): for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE): yield runtime_pb2.ExpertResponse(tensors=[part]) + def _get_active_adapter(self, metadata: dict) -> str: + active_adapter = metadata.get("active_adapter", "") + if active_adapter and (active_adapter not in self.adapters): + raise KeyError(f"adapter {active_adapter} not found") + return active_adapter + def _serialize_grads( self, grads: Sequence[torch.Tensor], diff --git a/src/petals/server/server.py b/src/petals/server/server.py index c90ae44..83a94e3 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -534,6 +534,7 @@ class ModuleContainer(threading.Thread): TransformerConnectionHandler( dht, self.module_backends, + adapters=adapters, dht_prefix=dht_prefix, push_manager=self.push_manager, session_queues=session_queues, diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py index c537a32..b182181 100644 --- a/src/petals/utils/peft.py +++ b/src/petals/utils/peft.py @@ -1,3 +1,4 @@ +import contextlib import re import time from typing import Optional, Sequence @@ -118,6 +119,47 @@ def load_peft( time.sleep(delay) +class AdapterContextMixin: + """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context""" + + ADAPTER_NOT_SET = "__ADAPTER_NOT_SET" + _context_active_adapter = ADAPTER_NOT_SET + + @staticmethod + @contextlib.contextmanager + def using_adapter(active_adapter: Optional[str]): + prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter + try: + yield + finally: + AdapterContextMixin._context_active_adapter = prev + + @property + def active_adapter(self): + if self._context_active_adapter == self.ADAPTER_NOT_SET: + logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug") + return self._context_active_adapter + + @active_adapter.setter + def active_adapter(self, value: Optional[str]): + assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" "" + + +using_adapter = AdapterContextMixin.using_adapter + + +class LoraLinear(lora.Linear, AdapterContextMixin): + """LoRA linear layer that uses adapter selected via using_adapter""" + + +class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin): + """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter""" + + +class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin): + """LoRA linear 4-bit that uses adapter selected via using_adapter""" + + def create_lora_adapter(block, quant_type: QuantType): for _, module in block.named_modules(): for child_name, child in module.named_children(): @@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType): "threshold": 6.0, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear8bitLt( - child_name, + lora_wrapped_child = LoraLinear8bitLt( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, @@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType): "blocksize": 64, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear4bit( - child_name, + lora_wrapped_child = LoraLinear4bit( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, ) else: bias = hasattr(child, "bias") and child.bias is not None - lora_wrapped_child = lora.Linear( - child_name, + lora_wrapped_child = LoraLinear( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, bias=bias, ) if lora_wrapped_child: - lora_wrapped_child.active_adapter = None lora_wrapped_child.weight = child.weight lora_wrapped_child.bias = child.bias for p in lora_wrapped_child.parameters():