Switch adapters slightly faster (#353)

Currently, each call to `TransformerBackend.inference_step` (as well as `forward` and `backward`) iterates over the block's submodules, finds the LoRA-wrapped layers, and sets the requested adapter on each of them. This is not very expensive, but it can measurably affect inference time.

This pull request switches adapters with a single context-variable assignment instead of iterating over `block.modules()` on every call.
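
Concretely, the server now wraps each forward/backward/inference call in a context manager that flips one class-level attribute, which all LoRA-wrapped layers consult at call time. A minimal sketch of the idea (illustrative names such as `AdapterContext` and `LoraLayerType`, not the Petals API itself):

```python
import contextlib


class AdapterContext:
    """Illustrative stand-in: every LoRA layer reads this one class attribute at call time."""

    _active = None  # switching adapters is a single assignment to this attribute

    @staticmethod
    @contextlib.contextmanager
    def using(adapter_name):
        prev, AdapterContext._active = AdapterContext._active, adapter_name
        try:
            yield
        finally:
            AdapterContext._active = prev  # restore the previous adapter on exit


# Old approach (per request): walk over every submodule of the block
#   for layer in block.modules():
#       if isinstance(layer, LoraLayerType):
#           layer.active_adapter = adapter_name
#
# New approach (per request): one assignment, independent of the block size
#   with AdapterContext.using("my-lora-adapter"):
#       outputs = block(hidden_states)
```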
justheuristic committed 37fdcb3fe0 (parent 9703358df0, pull/357/head)

@@ -24,9 +24,15 @@ logger = get_logger(__name__)
 class TransformerBackend(ModuleBackend):
     """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
+    _peft_module = None
+
     def __init__(
         self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
     ):
+        import petals.utils.peft as _peft_module
+
+        self._peft_module = _peft_module
+
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
@@ -82,13 +88,13 @@ class TransformerBackend(ModuleBackend):
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().forward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        self.load_adapter_(active_adapter)
-        return super().backward(*inputs)
+        with self._peft_module.using_adapter(active_adapter):
+            return super().backward(*inputs)
 
     @torch.inference_mode()
     def inference_step(
@@ -98,8 +104,9 @@ class TransformerBackend(ModuleBackend):
         inference_info: InferenceMetadata,
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        self.load_adapter_(inference_info.active_adapter)
-        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
+        with self.memory_cache.use_cache(
+            *inference_info.cache_handles
+        ) as cache_tensors, self._peft_module.using_adapter(inference_info.active_adapter):
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
             layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
             hidden_states, new_kvs = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
@@ -150,22 +157,6 @@ class TransformerBackend(ModuleBackend):
         for p in self.module.parameters():
             p.data = dummy
 
-    def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
-        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
-        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
-        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
-
-        loaded = False
-        for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
-                layer.active_adapter = active_adapter  # empty string for no adapter
-                if active_adapter in layer.lora_A.keys():
-                    loaded = True
-        if active_adapter and not loaded:
-            raise KeyError(f"Could not find adapter {active_adapter}, perhaps it is not loaded")
-
-
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
     """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

@@ -68,6 +68,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         dht: DHT,
         module_backends: Dict[str, TransformerBackend],
         *,
+        adapters: Optional[Sequence[str]],
         dht_prefix: str,
         push_manager: multiprocessing.managers.SyncManager,
         session_queues: Dict[str, multiprocessing.managers.BaseProxy],  # BaseProxy for queue.Queue
@@ -81,6 +82,7 @@ class TransformerConnectionHandler(ConnectionHandler):
         for module_backend in self.module_backends.values():
             assert isinstance(module_backend, TransformerBackend)
         self.dht_prefix = dht_prefix
+        self.adapters = adapters
         self._push_manager = push_manager
         self._session_queues = session_queues
         self._executor = ThreadPoolExecutor(max_workers=float("inf"))  # For waiting on self.session_queues
@@ -141,7 +143,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             max_length = metadata.get("max_length")
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             session_id = metadata.get("session_id")
@@ -355,7 +357,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -382,7 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_forward_stream", requested_uids, context)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -433,7 +435,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -458,7 +460,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             self._log_request("rpc_backward_stream", requested_uids, context)
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
-            active_adapter = metadata.get("active_adapter", "")
+            active_adapter = self._get_active_adapter(metadata)
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -476,6 +478,12 @@ class TransformerConnectionHandler(ConnectionHandler):
             for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                 yield runtime_pb2.ExpertResponse(tensors=[part])
 
+    def _get_active_adapter(self, metadata: dict) -> str:
+        active_adapter = metadata.get("active_adapter", "")
+        if active_adapter and (active_adapter not in self.adapters):
+            raise KeyError(f"adapter {active_adapter} not found")
+        return active_adapter
+
     def _serialize_grads(
         self,
         grads: Sequence[torch.Tensor],
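
Clients keep requesting a LoRA adapter through the `active_adapter` field of the request metadata; the new `_get_active_adapter` helper now rejects adapter names the server has not loaded instead of silently running the base weights. A self-contained variant of that check (hypothetical helper and example values, for illustration only):

```python
from typing import Optional, Sequence


def get_active_adapter(metadata: dict, loaded_adapters: Optional[Sequence[str]]) -> str:
    """Return the adapter requested in metadata, or "" for the base model; reject unknown names."""
    active_adapter = metadata.get("active_adapter", "")
    if active_adapter and (active_adapter not in (loaded_adapters or ())):
        raise KeyError(f"adapter {active_adapter} not found")
    return active_adapter


loaded = ["example/lora-adapter"]  # adapters this server was started with (placeholder value)
assert get_active_adapter({}, loaded) == ""  # no adapter requested: run the base model
assert get_active_adapter({"active_adapter": "example/lora-adapter"}, loaded) == "example/lora-adapter"
# get_active_adapter({"active_adapter": "unknown/adapter"}, loaded)  # would raise KeyError
```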

@@ -534,6 +534,7 @@ class ModuleContainer(threading.Thread):
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=adapters,
                 dht_prefix=dht_prefix,
                 push_manager=self.push_manager,
                 session_queues=session_queues,

@@ -1,3 +1,4 @@
+import contextlib
 import re
 import time
 from typing import Optional, Sequence
@@ -118,6 +119,47 @@ def load_peft(
             time.sleep(delay)
 
 
+class AdapterContextMixin:
+    """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
+
+    ADAPTER_NOT_SET = "__ADAPTER_NOT_SET"
+    _context_active_adapter = ADAPTER_NOT_SET
+
+    @staticmethod
+    @contextlib.contextmanager
+    def using_adapter(active_adapter: Optional[str]):
+        prev, AdapterContextMixin._context_active_adapter = AdapterContextMixin._context_active_adapter, active_adapter
+        try:
+            yield
+        finally:
+            AdapterContextMixin._context_active_adapter = prev
+
+    @property
+    def active_adapter(self):
+        if self._context_active_adapter == self.ADAPTER_NOT_SET:
+            logger.warning(f"Layer {self} was called without using_adapter. This should only be used for debug")
+        return self._context_active_adapter
+
+    @active_adapter.setter
+    def active_adapter(self, value: Optional[str]):
+        assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
+
+
+using_adapter = AdapterContextMixin.using_adapter
+
+
+class LoraLinear(lora.Linear, AdapterContextMixin):
+    """LoRA linear layer that uses adapter selected via using_adapter"""
+
+
+class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
+    """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
+
+
+class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
+    """LoRA linear 4-bit that uses adapter selected via using_adapter"""
+
+
 def create_lora_adapter(block, quant_type: QuantType):
     for _, module in block.named_modules():
         for child_name, child in module.named_children():
@@ -130,8 +172,8 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "threshold": 6.0,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear8bitLt(
-                    child_name,
+                lora_wrapped_child = LoraLinear8bitLt(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
@@ -143,22 +185,21 @@ def create_lora_adapter(block, quant_type: QuantType):
                     "blocksize": 64,
                     "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = lora.Linear4bit(
-                    child_name,
+                lora_wrapped_child = LoraLinear4bit(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     **kwargs,
                 )
             else:
                 bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = lora.Linear(
-                    child_name,
+                lora_wrapped_child = LoraLinear(
+                    AdapterContextMixin.ADAPTER_NOT_SET,
                     child.in_features,
                     child.out_features,
                     bias=bias,
                 )
             if lora_wrapped_child:
-                lora_wrapped_child.active_adapter = None
                 lora_wrapped_child.weight = child.weight
                 lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():
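
Because `active_adapter` is now a property defined on the mixin, every LoRA-wrapped layer reads the shared context attribute at call time, and the constructors receive `ADAPTER_NOT_SET` as a placeholder name that the setter accepts. A toy reproduction of that mechanism (illustrative classes `AdapterAware`, `ToyLoraLayer`, and `ContextLoraLayer`, not the peft/Petals classes):

```python
import contextlib


class AdapterAware:
    """Toy stand-in for AdapterContextMixin: active_adapter is a property backed by one class-level slot."""

    NOT_SET = "__NOT_SET"
    _context_active = NOT_SET

    @staticmethod
    @contextlib.contextmanager
    def using(name):
        prev, AdapterAware._context_active = AdapterAware._context_active, name
        try:
            yield
        finally:
            AdapterAware._context_active = prev

    @property
    def active_adapter(self):
        # looked up on the class, so all instances see the same context value
        return self._context_active


class ToyLoraLayer:
    """Toy stand-in for a LoRA layer that consults self.active_adapter inside forward()."""

    def forward(self, x):
        return f"{x} (adapter={self.active_adapter})"


class ContextLoraLayer(ToyLoraLayer, AdapterAware):
    """Mixing AdapterAware in makes forward() obey whatever adapter the context selects."""


layers = [ContextLoraLayer() for _ in range(3)]
with AdapterAware.using("my-adapter"):
    print([layer.forward("x") for layer in layers])  # every layer sees "my-adapter"
print(layers[0].active_adapter)  # back to "__NOT_SET" outside the context
```

One consequence visible in the sketch: the selection lives in a plain class attribute, so it is shared by every layer in the process; accordingly, the server enters `using_adapter(...)` around each individual forward, backward, and inference call, as shown in the backend diff above.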
