Support peft LoRA adapters (#335)

Implement an option to deploy PEFT adapters to a server. Clients can set active_adapter=... to use these adapters.

---------

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
This commit is contained in:
Artem Chumachenko 2023-07-12 16:22:28 +04:00 committed by GitHub
parent dfc6578c8e
commit b9f0a5467f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 444 additions and 34 deletions

View File

@ -33,10 +33,11 @@ jobs:
run: |
export MODEL_NAME=bigscience/bloom-560m
export REF_NAME=bigscience/bloom-560m
export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
--torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
SERVER1_PID=$!
sleep 5 # wait for the first server to initialize DHT
@ -45,17 +46,17 @@ jobs:
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
SERVER2_PID=$!
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
SERVER3_PID=$!
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
--initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
SERVER4_PID=$!
tail -n 100 -f server*.log &

View File

@ -46,6 +46,8 @@ install_requires =
cpufeature>=0.2.0
packaging>=20.9
sentencepiece>=0.1.99
peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
safetensors>=0.3.1
[options.extras_require]
dev =

View File

@ -146,6 +146,8 @@ def main():
help="Skip checking this server's reachability via health.petals.ml "
"when connecting to the public swarm. If you connect to a private swarm, "
"the check is skipped by default. Use this option only if you know what you are doing")
parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
# fmt:on
args = vars(parser.parse_args())

View File

@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
dht: Optional[DHT] = None,
start_block: Optional[int] = None,
end_block: Optional[int] = None,
**kwargs,
):
super().__init__()
self.config = config
@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
if end_block is None:
end_block = self.config.num_hidden_layers
block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

View File

@ -43,6 +43,7 @@ class SequenceManagerConfig:
min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
max_backoff: float = 60 # limit maximal sleep time between retries to this value
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
active_adapter: Optional[str] = None
@dataclasses.dataclass
@ -78,6 +79,7 @@ class RemoteSequenceManager:
*,
dht: Optional[DHT] = None,
state: Optional[SequenceManagerState] = None,
active_adapter: Optional[str] = None,
):
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@ -115,7 +117,9 @@ class RemoteSequenceManager:
if state.sequence_info.last_updated_time is None:
# Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
# in the first _update() instead of the latest ones. This makes the first .update() faster.
petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
)
self._need_latest_infos = False
else:
assert block_uids == state.sequence_info.block_uids
@ -179,7 +183,7 @@ class RemoteSequenceManager:
def _update(self):
"""Perform an immediate and synchronous refresh, may take time"""
new_block_infos = petals.dht_utils.get_remote_module_infos(
self.dht, self.block_uids, latest=self._need_latest_infos
self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
)
self._need_latest_infos = True # All future _update() should use latest infos
@ -307,7 +311,7 @@ class RemoteSequenceManager:
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs))
return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)
def shutdown(self):
self._thread.shutdown()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Tuple
from typing import Any, Dict, Optional, Tuple
from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
@ -57,3 +57,4 @@ class InferenceMetadata:
uid: ExpertUID
prefix_length: int
cache_handles: Tuple[Handle, ...]
active_adapter: Optional[str]

View File

@ -22,6 +22,7 @@ def declare_active_modules(
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: Optional[Sequence[str]] = None,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
@ -39,6 +40,7 @@ def declare_active_modules(
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(
_declare_active_modules,
@ -46,6 +48,7 @@ def declare_active_modules(
expiration_time=expiration_time,
state=state,
throughput=throughput,
adapters=list(adapters or []),
),
return_future=not wait,
)
@ -58,12 +61,13 @@ async def _declare_active_modules(
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: List[str],
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[(state.value, throughput)] * len(uids),
values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
@ -73,18 +77,30 @@ def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@ -105,7 +121,13 @@ async def _get_remote_module_infos(
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
state, throughput = server_info.value
state, throughput = server_info.value[:2]
extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
adapters = extra_info.get("adapters", [])
if bool(active_adapter) and active_adapter not in adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
if not (
isinstance(state, int)
and isinstance(throughput, float)

View File

@ -2,8 +2,9 @@ from __future__ import annotations
from collections import Counter
from itertools import chain
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple, Union
import peft
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
@ -80,6 +81,18 @@ class TransformerBackend(ModuleBackend):
cache_tensors.extend((keys, values))
return cache_tensors
def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
if not self.load_adapter_(active_adapter):
raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
return super().forward(*inputs)
def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
if not self.load_adapter_(active_adapter):
raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
return super().backward(*inputs)
@torch.inference_mode()
def inference_step(
self,
@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend):
inference_info: InferenceMetadata,
) -> Tuple[torch.Tensor, ...]:
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
if not self.load_adapter_(inference_info.active_adapter):
raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend):
for p in self.module.parameters():
p.data = dummy
def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
"""Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
adapter_was_loaded = False
for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter
if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
layer.active_adapter = active_adapter # empty string for no adapter
if active_adapter in layer.lora_A.keys():
adapter_was_loaded = True
return adapter_was_loaded or not active_adapter
def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
"""Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""

View File

@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
session_id = metadata.get("session_id")
@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler):
)
inference_infos = tuple(
InferenceMetadata(uid, prefix_length, tuple(handles))
InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
for uid, handles in zip(requested_uids, cache_handles)
)
@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
)
return runtime_pb2.ExpertResponse(
tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler):
self._log_request("rpc_forward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
*flat_inputs,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
)
# Split the serialized_output for streaming and respond to client
@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
)
return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler):
self._log_request("rpc_backward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
*flat_tensors,
requested_backends=requested_backends,
prioritizer=self._prioritizer,
active_adapter=active_adapter,
points=points,
)
# Split the serialized_grad_inputs for streaming and respond
for tensor in self._serialize_grads(grads, requested_backends, metadata):
@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
async def _rpc_forward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> torch.Tensor:
@ -585,6 +607,7 @@ async def _rpc_forward(
)
(hidden_states,) = await backend.forward_pool.submit_task(
hidden_states,
active_adapter,
priority=priority,
)
assert isinstance(hidden_states, torch.Tensor)
@ -598,6 +621,7 @@ async def _rpc_forward(
async def _rpc_backward(
*flat_tensors: torch.Tensor,
requested_backends: Sequence[TransformerBackend],
active_adapter: str = "",
prioritizer: TaskPrioritizerBase,
points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@ -623,7 +647,7 @@ async def _rpc_backward(
priority = prioritizer.prioritize(
inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
)
(inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
(inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)
assert isinstance(inputs, torch.Tensor)
@ -639,7 +663,7 @@ async def _rpc_backward(
priority = prioritizer.prioritize(
inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
(grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)
assert isinstance(grad_outputs, torch.Tensor)
if not is_dummy(prompt):

View File

@ -81,6 +81,7 @@ class Server:
dht_client_mode: Optional[bool] = None,
use_relay: bool = True,
use_auto_relay: bool = True,
adapters: Optional[List[str]] = None,
**kwargs,
):
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
@ -218,6 +219,8 @@ class Server:
self.mean_balance_check_period = mean_balance_check_period
self.mean_block_selection_delay = mean_block_selection_delay
self.adapters = adapters
self.stop = threading.Event()
def _choose_num_blocks(self) -> int:
@ -291,6 +294,7 @@ class Server:
quant_type=self.quant_type,
tensor_parallel_devices=self.tensor_parallel_devices,
should_validate_reachability=self.should_validate_reachability,
adapters=self.adapters,
start=True,
)
try:
@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
quant_type: QuantType,
tensor_parallel_devices: Sequence[torch.device],
should_validate_reachability: bool,
adapters: Optional[List[str]] = None,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@ -391,6 +396,7 @@ class ModuleContainer(threading.Thread):
module_uids,
dht,
ServerState.JOINING,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -415,7 +421,19 @@ class ModuleContainer(threading.Thread):
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
block = convert_block(
block,
block_index,
block_config,
tensor_parallel_devices,
device,
quant_type,
adapters=adapters,
freeze=True,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
)
blocks[module_uid] = TransformerBackend(
module_uid,
block,
@ -452,6 +470,7 @@ class ModuleContainer(threading.Thread):
expiration_time=get_dht_time() + expiration,
state=ServerState.OFFLINE,
throughput=throughput,
adapters=adapters,
)
logger.info(f"Announced that blocks {module_uids} are offline")
raise
@ -465,6 +484,7 @@ class ModuleContainer(threading.Thread):
dht,
dht_prefix,
blocks,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -480,6 +500,7 @@ class ModuleContainer(threading.Thread):
inference_max_length: int,
num_handlers: int,
throughput: float,
adapters: Optional[Sequence[str]],
update_period: float,
expiration: Optional[float] = None,
request_timeout: float,
@ -517,6 +538,7 @@ class ModuleContainer(threading.Thread):
list(self.module_backends.keys()),
dht,
ServerState.ONLINE,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
module_uids: List[str],
dht: DHT,
state: ServerState,
adapters: Optional[Sequence[str]],
*,
throughput: float,
update_period: float = 30,
@ -626,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.module_uids = module_uids
self.dht = dht
self.state = state
self.adapters = adapters
self.throughput = throughput
self.update_period = update_period
self.expiration = expiration
@ -639,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
expiration_time=get_dht_time() + self.expiration,
state=self.state,
throughput=self.throughput,
adapters=self.adapters,
)
if self.stop.wait(self.update_period):
break

View File

@ -172,7 +172,7 @@ def measure_compute_rps(
tensor_parallel_devices = (device,)
with torch.inference_mode():
block = config.block_class(config).to(dtype)
block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
cache = None
elapsed = 0

View File

@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
"""
import os
import re
from enum import Enum
from typing import Sequence
from typing import List, Optional, Sequence
import tensor_parallel as tp
import torch
@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig
from petals.utils.misc import QuantType
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
class QuantType(Enum):
NONE = 0
INT8 = 1 # 8-bit as in the LLM.int8() paper
NF4 = 2 # 4-bit as in the QLoRA paper
def convert_block(
block: nn.Module,
block_index: int,
config: PretrainedConfig,
tensor_parallel_devices: Sequence[torch.device],
output_device: torch.device,
quant_type: QuantType,
freeze: bool = True,
adapters: Optional[List[str]] = None,
**kwargs,
) -> tp.TensorParallel:
"""
Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@ -56,6 +55,16 @@ def convert_block(
for shard, device in zip(block.module_shards, block.devices):
shard.to(device)
if adapters:
create_lora_adapter(block, quant_type=quant_type)
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(
adapter_name,
block_idx=block_index,
**kwargs,
)
add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
return block

View File

@ -1,5 +1,14 @@
from enum import Enum
import torch
class QuantType(Enum):
NONE = 0
INT8 = 1 # 8-bit as in the LLM.int8() paper
NF4 = 2 # 4-bit as in the QLoRA paper
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters

208
src/petals/utils/peft.py Normal file
View File

@ -0,0 +1,208 @@
import re
import time
from typing import List, Optional
import bitsandbytes as bnb
import torch.nn as nn
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType
logger = get_logger(__name__)
def check_peft_repository(repo_id: str) -> bool:
fs = HfFileSystem()
list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
return len(list_of_files) > 0
def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
tensors = dict()
is_tensors_found = dict()
common_layer_patter_re = (
".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
)
with safe_open(filepath, framework=framework, device=device) as f:
for k in f.keys():
if re.match(common_layer_patter_re, k):
is_tensors_found[block_idx] = True
tensors[k] = f.get_tensor(k)
if not is_tensors_found.get(block_idx, False):
logger.warning(f"There is no peft weights for block {block_idx}")
return tensors
def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
if config_path is None:
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
config = PeftConfig.from_json_file(config_path)
weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
if weight_path is None:
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
if block_idx is None:
return config, load_file(weight_path)
return config, load_specific_module(block_idx, weight_path, device=device)
def load_peft(
repo_id: str,
block_idx: Optional[int] = None,
device: Optional[int] = None,
*,
revision: Optional[str] = None,
use_auth_token: Optional[str] = None,
cache_dir: str,
max_disk_space: Optional[int] = None,
delay: float = 30,
):
# TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
if not check_peft_repository(repo_id):
raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
try:
with allow_cache_reads(cache_dir):
return get_adapter_from_repo(
repo_id,
block_idx,
device,
revision=revision,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=False,
)
except Exception:
logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
while True:
try:
with allow_cache_writes(cache_dir):
config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
file_size = config_file_size + weight_file_size
if file_size is not None:
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
else:
logger.warning(f"Failed to fetch size from peft repo {repo_id}")
return get_adapter_from_repo(
repo_id,
block_idx,
device,
revision=revision,
use_auth_token=use_auth_token,
cache_dir=cache_dir,
local_files_only=False,
)
except Exception as e:
logger.warning(
f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
)
time.sleep(delay)
def create_lora_adapter(block, quant_type: QuantType):
for name, module in block.named_modules():
for child_name, child in module.named_children():
lora_wrapped_child = None
if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
continue
if quant_type == QuantType.INT8:
kwargs = {
"has_fp16_weights": False,
"threshold": 6.0,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = lora.Linear8bitLt(
child_name,
child.in_features,
child.out_features,
**kwargs,
)
elif quant_type == QuantType.NF4:
kwargs = {
"compress_statistics": True,
"quant_type": "nf4",
"blocksize": 64,
"bias": hasattr(child, "bias") and child.bias is not None,
}
lora_wrapped_child = lora.Linear4bit(
child_name,
child.in_features,
child.out_features,
**kwargs,
)
else:
bias = hasattr(child, "bias") and child.bias is not None
lora_wrapped_child = lora.Linear(
child_name,
child.in_features,
child.out_features,
bias=bias,
)
if lora_wrapped_child:
lora_wrapped_child.active_adapter = None
lora_wrapped_child.weight = child.weight
lora_wrapped_child.bias = child.bias
for p in lora_wrapped_child.parameters():
p.requires_grad = False
setattr(module, child_name, lora_wrapped_child)
def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
for name, module in block.named_modules():
for child_name, child in module.named_children():
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
continue
if child_name in peft_config["target_modules"] or (
isinstance(peft_config["target_modules"], str)
and re.fullmatch(peft_config["target_modules"], child_name)
):
is_lora_a_loaded = False
is_lora_b_loaded = False
for peft_key in peft_state_dict:
if peft_key.find(child_name) == -1:
continue
if adapter_name not in child.lora_A:
child.update_layer(
adapter_name,
peft_config["r"],
peft_config["lora_alpha"],
lora_dropout=peft_config["lora_dropout"],
init_lora_weights=peft_config["init_lora_weights"],
)
child.train(False)
if peft_config["lora_dropout"] > 0:
logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
for p in child.parameters():
p.requires_grad = False
if peft_key.endswith(".lora_A.weight"):
child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
is_lora_a_loaded = True
elif peft_key.endswith(".lora_A.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
elif peft_key.endswith(".lora_B.weight"):
child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
is_lora_b_loaded = True
elif peft_key.endswith(".lora_B.bias"):
raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
if is_lora_a_loaded and is_lora_b_loaded:
logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")

View File

@ -1,3 +1,4 @@
import peft
import pytest
import torch
import transformers
@ -12,11 +13,16 @@ logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
MODEL_NAME,
initial_peers=INITIAL_PEERS,
low_cpu_mem_usage=True,
torch_dtype=torch.float32,
active_adapter=ADAPTER_NAME if use_peft else None,
)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
ref_model = transformers.BloomForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
ref_model.train(False)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

66
tests/test_peft.py Normal file
View File

@ -0,0 +1,66 @@
import os
import shutil
import pytest
from huggingface_hub import snapshot_download
from petals.utils.peft import check_peft_repository, load_peft
UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
TMP_CACHE_DIR = "tmp_cache/"
def clear_dir(path_to_dir):
shutil.rmtree(path_to_dir)
os.mkdir(path_to_dir)
def dir_empty(path_to_dir):
files = os.listdir(path_to_dir)
return len(files) == 0
@pytest.mark.forked
def test_check_peft():
assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
@pytest.mark.forked
def test_load_noncached(tmpdir):
clear_dir(tmpdir)
with pytest.raises(Exception):
load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
@pytest.mark.forked
def test_load_cached(tmpdir):
clear_dir(tmpdir)
snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
@pytest.mark.forked
def test_load_layer_exists(tmpdir):
clear_dir(tmpdir)
load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
clear_dir(tmpdir)
load_peft(
SAFE_PEFT_REPO,
block_idx=1337,
cache_dir=tmpdir,
)

View File

@ -11,3 +11,5 @@ if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")
ADAPTER_NAME = os.environ.get("ADAPTER_NAME")