mirror of
https://github.com/bigscience-workshop/petals
synced 2024-11-18 03:25:33 +00:00
Support peft LoRA adapters (#335)
Implement an option to deploy PEFT adapters to a server. Clients can set active_adapter=... to use these adapters.

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
This commit is contained in:
parent dfc6578c8e
commit b9f0a5467f
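For orientation, here is a minimal end-to-end sketch of the new option, pieced together from the CI and test changes below (the initial peer address is a placeholder, and the top-level import mirrors the one used in the Petals test suite):

import torch
from petals import DistributedBloomForCausalLM  # client class exercised in tests/test_full_model.py

# A server launched with `--adapters artek0chumak/bloom-560m-safe-peft` loads that LoRA adapter
# for each of its blocks and announces it in the DHT. A client that wants the adapter applied
# names the same repo when building the distributed model:
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    initial_peers=["/ip4/127.0.0.1/tcp/31337/p2p/..."],  # placeholder multiaddr of a known swarm peer
    torch_dtype=torch.float32,
    active_adapter="artek0chumak/bloom-560m-safe-peft",  # must match an adapter the servers were started with
)

Requests are then routed only through servers that announced this adapter, and each server activates the corresponding LoRA weights before running its blocks.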
11  .github/workflows/run-tests.yaml  (vendored)
@@ -33,10 +33,11 @@ jobs:
      run: |
        export MODEL_NAME=bigscience/bloom-560m
        export REF_NAME=bigscience/bloom-560m
+       export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft

        python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
          --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-         --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
+         --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log &
        SERVER1_PID=$!

        sleep 5 # wait for the first server to initialize DHT
@@ -45,17 +46,17 @@ jobs:
        # ^-- server 1 multiaddr is determined by --identity and --host_maddrs

        python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
-         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log &
+         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log &
        SERVER2_PID=$!

        sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve

-       python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \
-         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
+       python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \
+         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log &
        SERVER3_PID=$!

        python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
+         --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log &
        SERVER4_PID=$!

        tail -n 100 -f server*.log &
@@ -46,6 +46,8 @@ install_requires =
    cpufeature>=0.2.0
    packaging>=20.9
    sentencepiece>=0.1.99
+   peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
+   safetensors>=0.3.1

[options.extras_require]
dev =
@@ -146,6 +146,8 @@ def main():
        help="Skip checking this server's reachability via health.petals.ml "
             "when connecting to the public swarm. If you connect to a private swarm, "
             "the check is skipped by default. Use this option only if you know what you are doing")

+   parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")

    # fmt:on
    args = vars(parser.parse_args())
@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
        dht: Optional[DHT] = None,
        start_block: Optional[int] = None,
        end_block: Optional[int] = None,
+       **kwargs,
    ):
        super().__init__()
        self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
        if end_block is None:
            end_block = self.config.num_hidden_layers
        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-       sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+       sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
        self.sequence_manager = sequence_manager

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
@@ -43,6 +43,7 @@ class SequenceManagerConfig:
    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
    max_backoff: float = 60  # limit maximal sleep time between retries to this value
    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+   active_adapter: Optional[str] = None


@dataclasses.dataclass
@@ -78,6 +79,7 @@ class RemoteSequenceManager:
        *,
        dht: Optional[DHT] = None,
        state: Optional[SequenceManagerState] = None,
+       active_adapter: Optional[str] = None,
    ):
        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -115,7 +117,9 @@ class RemoteSequenceManager:
        if state.sequence_info.last_updated_time is None:
            # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
            # in the first _update() instead of the latest ones. This makes the first .update() faster.
-           petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
+           petals.dht_utils.get_remote_module_infos(
+               self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True
+           )
            self._need_latest_infos = False
        else:
            assert block_uids == state.sequence_info.block_uids
@@ -179,7 +183,7 @@ class RemoteSequenceManager:
    def _update(self):
        """Perform an immediate and synchronous refresh, may take time"""
        new_block_infos = petals.dht_utils.get_remote_module_infos(
-           self.dht, self.block_uids, latest=self._need_latest_infos
+           self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos
        )
        self._need_latest_infos = True  # All future _update() should use latest infos

@@ -307,7 +311,7 @@ class RemoteSequenceManager:
        :param kwargs: additional request context, such as remote peer ID
        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
        """
-       return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+       return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter)

    def shutdown(self):
        self._thread.shutdown()
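The get_request_metadata() change above is how the adapter choice reaches servers: every RPC's msgpack metadata now carries an active_adapter key next to points. A standalone sketch of that round trip, assuming hivemind's MSGPackSerializer (the serializer the connection handler below uses); the literal values are illustrative:

from hivemind.utils.serializer import MSGPackSerializer

# Client side: roughly what RemoteSequenceManager.get_request_metadata() now returns
request_metadata = dict(points=0, active_adapter="artek0chumak/bloom-560m-safe-peft")
payload = MSGPackSerializer.dumps(request_metadata)

# Server side: how TransformerConnectionHandler reads it back; an empty string means "no adapter"
metadata = MSGPackSerializer.loads(payload) if payload else {}
active_adapter = metadata.get("active_adapter", "")
points = metadata.get("points", 0)
assert active_adapter == "artek0chumak/bloom-560m-safe-peft" and points == 0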
@@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from dataclasses import dataclass
from enum import Enum
- from typing import Any, Dict, Tuple
+ from typing import Any, Dict, Optional, Tuple

from hivemind import PeerID
from hivemind.moe.expert_uid import ExpertUID
@@ -57,3 +57,4 @@ class InferenceMetadata:
    uid: ExpertUID
    prefix_length: int
    cache_handles: Tuple[Handle, ...]
+   active_adapter: Optional[str]
@@ -22,6 +22,7 @@ def declare_active_modules(
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
+   adapters: Optional[Sequence[str]] = None,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
@@ -39,6 +40,7 @@ def declare_active_modules(
    uids = list(uids)
    for uid in uids:
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    return dht.run_coroutine(
        partial(
            _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
            expiration_time=expiration_time,
            state=state,
            throughput=throughput,
+           adapters=list(adapters or []),
        ),
        return_future=not wait,
    )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
+   adapters: List[str],
) -> Dict[ModuleUID, bool]:
    num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
    return await node.store_many(
        keys=uids,
        subkeys=[dht.peer_id.to_base58()] * len(uids),
-       values=[(state.value, throughput)] * len(uids),
+       values=[(state.value, throughput, dict(adapters=adapters))] * len(uids),
        expiration_time=expiration_time,
        num_workers=num_workers,
    )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: Optional[DHTExpiration] = None,
+   active_adapter: Optional[str] = None,
    *,
    latest: bool = False,
    return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
    return dht.run_coroutine(
-       partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+       partial(
+           _get_remote_module_infos,
+           uids=uids,
+           active_adapter=active_adapter,
+           expiration_time=expiration_time,
+           latest=latest,
+       ),
        return_future=return_future,
    )


async def _get_remote_module_infos(
-   dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+   dht: DHT,
+   node: DHTNode,
+   uids: List[ModuleUID],
+   active_adapter: Optional[str],
+   expiration_time: Optional[DHTExpiration],
+   latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
    if latest:
        assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,13 @@ async def _get_remote_module_infos(
    for peer_id, server_info in metadata.value.items():
        try:
            peer_id = PeerID.from_base58(peer_id)
-           state, throughput = server_info.value
+           state, throughput = server_info.value[:2]
+           extra_info = server_info.value[2] if len(server_info.value) > 2 else {}
+           adapters = extra_info.get("adapters", [])
+           if bool(active_adapter) and active_adapter not in adapters:
+               logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+               continue

            if not (
                isinstance(state, int)
                and isinstance(throughput, float)
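To summarize the storage change above: each block's DHT record (one subkey per peer) grows from a (state, throughput) pair to a triple whose third element is an extra-info dict listing the server's adapters, and readers stay backward compatible by slicing. A self-contained sketch of that parsing and filtering logic (the sample values are illustrative):

# Old-style and new-style values stored under a block UID, keyed by the server's peer ID
old_record = (2, 100.0)  # (ServerState value, throughput) from a server without adapter support
new_record = (2, 100.0, dict(adapters=["artek0chumak/bloom-560m-safe-peft"]))

def server_matches(value, active_adapter):
    state, throughput = value[:2]                    # works for both record formats
    extra_info = value[2] if len(value) > 2 else {}  # absent in records from older servers
    adapters = extra_info.get("adapters", [])
    # A server is skipped only when the client explicitly requested an adapter it does not host
    return not active_adapter or active_adapter in adapters

assert server_matches(old_record, active_adapter=None)
assert not server_matches(old_record, active_adapter="artek0chumak/bloom-560m-safe-peft")
assert server_matches(new_record, active_adapter="artek0chumak/bloom-560m-safe-peft")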
@@ -2,8 +2,9 @@ from __future__ import annotations

from collections import Counter
from itertools import chain
- from typing import Any, Dict, Optional, Sequence, Tuple
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union

+ import peft
import torch
from hivemind import BatchTensorDescriptor, TensorDescriptor
from hivemind.moe.expert_uid import ExpertUID
@@ -80,6 +81,18 @@ class TransformerBackend(ModuleBackend):
        cache_tensors.extend((keys, values))
        return cache_tensors

+   def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+       *inputs, active_adapter = inputs
+       if not self.load_adapter_(active_adapter):
+           raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+       return super().forward(*inputs)
+
+   def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
+       *inputs, active_adapter = inputs
+       if not self.load_adapter_(active_adapter):
+           raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded")
+       return super().backward(*inputs)
+
    @torch.inference_mode()
    def inference_step(
        self,
@@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend):
        inference_info: InferenceMetadata,
    ) -> Tuple[torch.Tensor, ...]:
        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+       if not self.load_adapter_(inference_info.active_adapter):
+           raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
        with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
            self._reorder_cache_inplace(cache_tensors, hypo_ids)
            layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length)
@@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend):
        for p in self.module.parameters():
            p.data = dummy

+   def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
+       """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+       adapter_was_loaded = False
+       for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
+           if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+               layer.active_adapter = active_adapter  # empty string for no adapter
+               if active_adapter in layer.lora_A.keys():
+                   adapter_was_loaded = True
+       return adapter_was_loaded or not active_adapter
+

def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
    """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call"""
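The backend changes above establish a small calling convention: the adapter name rides along as the last positional argument of every pooled forward/backward call and is peeled off before the wrapped block runs (the RPC handlers below append it via submit_task(hidden_states, active_adapter, priority=...)). A toy illustration of just that unpacking, not the real Petals backend:

import torch

class ToyBackend:
    def forward(self, *inputs):
        # Mirrors TransformerBackend.forward: the last item is the adapter name ("" selects no adapter)
        *tensor_inputs, active_adapter = inputs
        assert all(isinstance(t, torch.Tensor) for t in tensor_inputs)
        return f"ran {len(tensor_inputs)} tensor(s) with adapter {active_adapter!r}"

backend = ToyBackend()
hidden_states = torch.zeros(1, 4, 8)
print(backend.forward(hidden_states, "artek0chumak/bloom-560m-safe-peft"))
print(backend.forward(hidden_states, ""))  # empty string = run the base model without any adapter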
@@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler):
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            max_length = metadata.get("max_length")
+           active_adapter = metadata.get("active_adapter", "")
            points = metadata.get("points", 0)
            session_id = metadata.get("session_id")

@@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                )

                inference_infos = tuple(
-                   InferenceMetadata(uid, prefix_length, tuple(handles))
+                   InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter)
                    for uid, handles in zip(requested_uids, cache_handles)
                )

@@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler):

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+       active_adapter = metadata.get("active_adapter", "")
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
-           *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+           *flat_inputs,
+           requested_backends=requested_backends,
+           prioritizer=self._prioritizer,
+           active_adapter=active_adapter,
+           points=points,
        )
        return runtime_pb2.ExpertResponse(
            tensors=self._serialize_outputs(hidden_states, requested_backends, metadata)
@@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler):
        self._log_request("rpc_forward_stream", requested_uids, context)

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+       active_adapter = metadata.get("active_adapter", "")
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward_stream should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
-           *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+           *flat_inputs,
+           requested_backends=requested_backends,
+           prioritizer=self._prioritizer,
+           active_adapter=active_adapter,
+           points=points,
        )

        # Split the serialized_output for streaming and respond to client
@@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler):

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
+       active_adapter = metadata.get("active_adapter", "")
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
-           *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+           *flat_tensors,
+           requested_backends=requested_backends,
+           prioritizer=self._prioritizer,
+           active_adapter=active_adapter,
+           points=points,
        )

        return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata))
@@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler):
        self._log_request("rpc_backward_stream", requested_uids, context)

        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
+       active_adapter = metadata.get("active_adapter", "")
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward_stream should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
-           *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
+           *flat_tensors,
+           requested_backends=requested_backends,
+           prioritizer=self._prioritizer,
+           active_adapter=active_adapter,
+           points=points,
        )
        # Split the serialized_grad_inputs for streaming and respond
        for tensor in self._serialize_grads(grads, requested_backends, metadata):
@@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler):
async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
+   active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
@@ -585,6 +607,7 @@ async def _rpc_forward(
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
+           active_adapter,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
@@ -598,6 +621,7 @@ async def _rpc_forward(
async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
+   active_adapter: str = "",
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
@@ -623,7 +647,7 @@ async def _rpc_backward(
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
-       (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)
+       (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority)

        assert isinstance(inputs, torch.Tensor)

@@ -639,7 +663,7 @@ async def _rpc_backward(
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
-       (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)
+       (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
@@ -81,6 +81,7 @@ class Server:
        dht_client_mode: Optional[bool] = None,
        use_relay: bool = True,
        use_auto_relay: bool = True,
+       adapters: Optional[List[str]] = None,
        **kwargs,
    ):
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -218,6 +219,8 @@ class Server:
        self.mean_balance_check_period = mean_balance_check_period
        self.mean_block_selection_delay = mean_block_selection_delay

+       self.adapters = adapters
+
        self.stop = threading.Event()

    def _choose_num_blocks(self) -> int:
@@ -291,6 +294,7 @@ class Server:
            quant_type=self.quant_type,
            tensor_parallel_devices=self.tensor_parallel_devices,
            should_validate_reachability=self.should_validate_reachability,
+           adapters=self.adapters,
            start=True,
        )
        try:
@@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
        quant_type: QuantType,
        tensor_parallel_devices: Sequence[torch.device],
        should_validate_reachability: bool,
+       adapters: Optional[List[str]] = None,
        **kwargs,
    ) -> ModuleContainer:
        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@@ -391,6 +396,7 @@ class ModuleContainer(threading.Thread):
            module_uids,
            dht,
            ServerState.JOINING,
+           adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -415,7 +421,19 @@ class ModuleContainer(threading.Thread):
                cache_dir=cache_dir,
                max_disk_space=max_disk_space,
            )
-           block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
+           block = convert_block(
+               block,
+               block_index,
+               block_config,
+               tensor_parallel_devices,
+               device,
+               quant_type,
+               adapters=adapters,
+               freeze=True,
+               use_auth_token=use_auth_token,
+               cache_dir=cache_dir,
+               max_disk_space=max_disk_space,
+           )
            blocks[module_uid] = TransformerBackend(
                module_uid,
                block,
@@ -452,6 +470,7 @@ class ModuleContainer(threading.Thread):
                expiration_time=get_dht_time() + expiration,
                state=ServerState.OFFLINE,
                throughput=throughput,
+               adapters=adapters,
            )
            logger.info(f"Announced that blocks {module_uids} are offline")
            raise
@@ -465,6 +484,7 @@ class ModuleContainer(threading.Thread):
            dht,
            dht_prefix,
            blocks,
+           adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -480,6 +500,7 @@ class ModuleContainer(threading.Thread):
        inference_max_length: int,
        num_handlers: int,
        throughput: float,
+       adapters: Optional[Sequence[str]],
        update_period: float,
        expiration: Optional[float] = None,
        request_timeout: float,
@@ -517,6 +538,7 @@ class ModuleContainer(threading.Thread):
            list(self.module_backends.keys()),
            dht,
            ServerState.ONLINE,
+           adapters=adapters,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
@@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
        module_uids: List[str],
        dht: DHT,
        state: ServerState,
+       adapters: Optional[Sequence[str]],
        *,
        throughput: float,
        update_period: float = 30,
@@ -626,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
        self.module_uids = module_uids
        self.dht = dht
        self.state = state
+       self.adapters = adapters
        self.throughput = throughput
        self.update_period = update_period
        self.expiration = expiration
@@ -639,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
                expiration_time=get_dht_time() + self.expiration,
                state=self.state,
                throughput=self.throughput,
+               adapters=self.adapters,
            )
            if self.stop.wait(self.update_period):
                break
@@ -172,7 +172,7 @@ def measure_compute_rps(
        tensor_parallel_devices = (device,)
    with torch.inference_mode():
        block = config.block_class(config).to(dtype)
-       block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+       block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

        cache = None
        elapsed = 0
@@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import os
import re
- from enum import Enum
- from typing import Sequence
+ from typing import List, Optional, Sequence

import tensor_parallel as tp
import torch
@@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig

+ from petals.utils.misc import QuantType
+ from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


- class QuantType(Enum):
-     NONE = 0
-     INT8 = 1  # 8-bit as in the LLM.int8() paper
-     NF4 = 2  # 4-bit as in the QLoRA paper
-
-
def convert_block(
    block: nn.Module,
+   block_index: int,
    config: PretrainedConfig,
    tensor_parallel_devices: Sequence[torch.device],
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
+   adapters: Optional[List[str]] = None,
+   **kwargs,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -56,6 +55,16 @@ def convert_block(
    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

+   if adapters:
+       create_lora_adapter(block, quant_type=quant_type)
+       for adapter_name in adapters:
+           adapter_config, adapter_state_dict = load_peft(
+               adapter_name,
+               block_idx=block_index,
+               **kwargs,
+           )
+           add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
+
    return block
@@ -1,5 +1,14 @@
+ from enum import Enum
+
import torch


+ class QuantType(Enum):
+     NONE = 0
+     INT8 = 1  # 8-bit as in the LLM.int8() paper
+     NF4 = 2  # 4-bit as in the QLoRA paper
+
+
DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter parameters
208  src/petals/utils/peft.py  (new file)

@@ -0,0 +1,208 @@
import re
import time
from typing import List, Optional

import bitsandbytes as bnb
import torch.nn as nn
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType

logger = get_logger(__name__)


def check_peft_repository(repo_id: str) -> bool:
    fs = HfFileSystem()
    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return len(list_of_files) > 0


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    tensors = dict()
    is_tensors_found = dict()
    common_layer_patter_re = (
        ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_patter_re, k):
                is_tensors_found[block_idx] = True
                tensors[k] = f.get_tensor(k)
        if not is_tensors_found.get(block_idx, False):
            logger.warning(f"There is no peft weights for block {block_idx}")
        return tensors


def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)


def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here

    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size

                file_size = config_file_size + weight_file_size
                if file_size is not None:
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch size from peft repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception as e:
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)


def create_lora_adapter(block, quant_type: QuantType):
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
                continue
            if quant_type == QuantType.INT8:
                kwargs = {
                    "has_fp16_weights": False,
                    "threshold": 6.0,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif quant_type == QuantType.NF4:
                kwargs = {
                    "compress_statistics": True,
                    "quant_type": "nf4",
                    "blocksize": 64,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            else:
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            if lora_wrapped_child:
                lora_wrapped_child.active_adapter = None
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)


def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    if peft_key.find(child_name) == -1:
                        continue

                    if adapter_name not in child.lora_A:
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            lora_dropout=peft_config["lora_dropout"],
                            init_lora_weights=peft_config["init_lora_weights"],
                        )
                        child.train(False)
                        if peft_config["lora_dropout"] > 0:
                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                        for p in child.parameters():
                            p.requires_grad = False

                    if peft_key.endswith(".lora_A.weight"):
                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_a_loaded = True
                    elif peft_key.endswith(".lora_A.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
                    elif peft_key.endswith(".lora_B.weight"):
                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_b_loaded = True
                    elif peft_key.endswith(".lora_B.bias"):
                        raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")

                if is_lora_a_loaded and is_lora_b_loaded:
                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
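The new tests further below exercise these helpers directly; for reference, a typical server-side call loads only the slice of adapter weights that matches one block index. A minimal sketch (the repo name comes from the CI config above, the cache path is an arbitrary placeholder, and the call downloads from the HF Hub):

import os

from petals.utils.peft import check_peft_repository, load_peft

repo_id = "artek0chumak/bloom-560m-safe-peft"
assert check_peft_repository(repo_id)  # only adapter repos that ship safetensors weights are accepted

# Returns the adapter config (as a dict) plus a state dict restricted to block 3's LoRA tensors
cache_dir = "./petals_cache"  # placeholder local cache directory
os.makedirs(cache_dir, exist_ok=True)
adapter_config, adapter_state_dict = load_peft(repo_id, block_idx=3, cache_dir=cache_dir)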
@@ -1,3 +1,4 @@
+ import peft
import pytest
import torch
import transformers
@@ -12,11 +13,16 @@ logger = get_logger(__name__)


@pytest.mark.forked
+ @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
- def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
-       MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+       MODEL_NAME,
+       initial_peers=INITIAL_PEERS,
+       low_cpu_mem_usage=True,
+       torch_dtype=torch.float32,
+       active_adapter=ADAPTER_NAME if use_peft else None,
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
@@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
    ref_model = transformers.BloomForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
+   if use_peft:
+       ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+       ref_model.train(False)
    if config.vocab_size < ref_model.config.vocab_size:
        ref_model.resize_token_embeddings(config.vocab_size)
        logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
66  tests/test_peft.py  (new file)

@@ -0,0 +1,66 @@
import os
import shutil

import pytest
from huggingface_hub import snapshot_download

from petals.utils.peft import check_peft_repository, load_peft

UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
TMP_CACHE_DIR = "tmp_cache/"


def clear_dir(path_to_dir):
    shutil.rmtree(path_to_dir)
    os.mkdir(path_to_dir)


def dir_empty(path_to_dir):
    files = os.listdir(path_to_dir)
    return len(files) == 0


@pytest.mark.forked
def test_check_peft():
    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
    assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."


@pytest.mark.forked
def test_load_noncached(tmpdir):
    clear_dir(tmpdir)
    with pytest.raises(Exception):
        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)

    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)

    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"


@pytest.mark.forked
def test_load_cached(tmpdir):
    clear_dir(tmpdir)
    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)

    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)


@pytest.mark.forked
def test_load_layer_exists(tmpdir):
    clear_dir(tmpdir)

    load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)


@pytest.mark.forked
def test_load_layer_nonexists(tmpdir):
    clear_dir(tmpdir)

    load_peft(
        SAFE_PEFT_REPO,
        block_idx=1337,
        cache_dir=tmpdir,
    )
@@ -11,3 +11,5 @@ if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")

REF_NAME = os.environ.get("REF_NAME")
+
+ ADAPTER_NAME = os.environ.get("ADAPTER_NAME")