diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 51c6ee0..4220546 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -24,9 +24,15 @@ logger = get_logger(__name__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -82,13 +88,13 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().forward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().backward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -98,8 +104,9 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        self.load_adapter_(inference_info.active_adapter)
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
@@ -150,22 +157,6 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
-    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
-        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
-
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
-
-        loaded = False
-        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
-                layer.active_adapter = active_adapter  # empty string for no adapter
-                if active_adapter in layer.lora_A.keys():
-                    loaded = True
-
-        if active_adapter and not loaded:
-            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
-
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d9a5025..12fd6eb 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        adapters: Optional[Sequence[str]],
         dht_prefix: str,
         push_manager: multiprocessing.managers.SyncManager,
         session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
@@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
+        self.adapters = adapters
         self._push_manager = push_manager
         self._session_queues = session_queues
         self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
@@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
                 requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
                 max_length = metadata.get("max_length")
-                active_adapter = metadata.get("active_adapter", "")
+                active_adapter = self._get_active_adapter(metadata)
                 points = metadata.get("points", 0)
                 session_id = metadata.get("session_id")
 
@@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler):
 
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
            points = metadata.get("points", 0)
            assert isinstance(
                points, (float, int)
@@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
 
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                     yield runtime_pb2.ExpertResponse(tensors=[part])
 
+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+
     def _serialize_grads(
         self,
         grads: Sequence[torch.Tensor],
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index c90ae44..83a94e3 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -534,6 +534,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index c537a32..b182181 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -1,3 +1,4 @@
+import contextlib
 import re
 import time
 from typing import Optional, Sequence
@@ -118,6 +119,47 @@ def load_peft(
             time.sleep(delay)
 
 
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    _context_active_adapter = ADAPTER_NOT_SET
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
+        try:
+            yield
+        finally:
+            AdapterContextMixin._context_active_adapter = prev
+
+    @property
+    def active_adapter(self):
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debugging")
This should only be used for debug") + return self._context_active_adapter + + @active_adapter.setter + def active_adapter(self, value: Optional[str]): + assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" "" + + +using_adapter = AdapterContextMixin.using_adapter + + +class LoraLinear(lora.Linear, AdapterContextMixin): + """LoRA linear layer that uses adapter selected via using_adapter""" + + +class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin): + """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter""" + + +class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin): + """LoRA linear 4-bit that uses adapter selected via using_adapter""" + + def create_lora_adapter(block, quant_type: QuantType): for _, module in block.named_modules(): for child_name, child in module.named_children(): @@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType): "threshold": 6.0, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear8bitLt( - child_name, + lora_wrapped_child = LoraLinear8bitLt( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, @@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType): "blocksize": 64, "bias": hasattr(child, "bias") and child.bias is not None, } - lora_wrapped_child = lora.Linear4bit( - child_name, + lora_wrapped_child = LoraLinear4bit( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, **kwargs, ) else: bias = hasattr(child, "bias") and child.bias is not None - lora_wrapped_child = lora.Linear( - child_name, + lora_wrapped_child = LoraLinear( + AdapterContextMixin.ADAPTER_NOT_SET, child.in_features, child.out_features, bias=bias, ) if lora_wrapped_child: - lora_wrapped_child.active_adapter = None lora_wrapped_child.weight = child.weight lora_wrapped_child.bias = child.bias for p in lora_wrapped_child.parameters():