diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index fbb5b72..b98667e 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -33,10 +33,11 @@ jobs: run: | export MODEL_NAME=bigscience/bloom-560m export REF_NAME=bigscience/bloom-560m + export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \ --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \ - --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log & + --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server1.log & SERVER1_PID=$! sleep 5 # wait for the first server to initialize DHT @@ -45,17 +46,17 @@ jobs: # ^-- server 1 multiaddr is determined by --identity and --host_maddrs python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server2.log & + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server2.log & SERVER2_PID=$! sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve - python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:5 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log & + python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:15 \ + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server3.log & SERVER3_PID=$! python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \ - --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log & + --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --adapters $ADAPTER_NAME &> server4.log & SERVER4_PID=$! tail -n 100 -f server*.log & diff --git a/setup.cfg b/setup.cfg index 6242651..f56a7cc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,8 @@ install_requires = cpufeature>=0.2.0 packaging>=20.9 sentencepiece>=0.1.99 + peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735 + safetensors>=0.3.1 [options.extras_require] dev = diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py index 1d3c438..6b3fde8 100644 --- a/src/petals/cli/run_server.py +++ b/src/petals/cli/run_server.py @@ -146,6 +146,8 @@ def main(): help="Skip checking this server's reachability via health.petals.ml " "when connecting to the public swarm. If you connect to a private swarm, " "the check is skipped by default. 
Use this option only if you know what you are doing") + + parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.") # fmt:on args = vars(parser.parse_args()) diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py index 745b5c1..6ae664a 100644 --- a/src/petals/client/remote_sequential.py +++ b/src/petals/client/remote_sequential.py @@ -28,6 +28,7 @@ class RemoteSequential(nn.Module): dht: Optional[DHT] = None, start_block: Optional[int] = None, end_block: Optional[int] = None, + **kwargs, ): super().__init__() self.config = config @@ -41,7 +42,7 @@ class RemoteSequential(nn.Module): if end_block is None: end_block = self.config.num_hidden_layers block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block)) - sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht) + sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs) self.sequence_manager = sequence_manager def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY): diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py index 88d6d16..fc505cc 100644 --- a/src/petals/client/routing/sequence_manager.py +++ b/src/petals/client/routing/sequence_manager.py @@ -43,6 +43,7 @@ class SequenceManagerConfig: min_backoff: float = 1 # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1) max_backoff: float = 60 # limit maximal sleep time between retries to this value ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds + active_adapter: Optional[str] = None @dataclasses.dataclass @@ -78,6 +79,7 @@ class RemoteSequenceManager: *, dht: Optional[DHT] = None, state: Optional[SequenceManagerState] = None, + active_adapter: Optional[str] = None, ): assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`" assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..." @@ -115,7 +117,9 @@ class RemoteSequenceManager: if state.sequence_info.last_updated_time is None: # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records # in the first _update() instead of the latest ones. This makes the first .update() faster. 
- petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True) + petals.dht_utils.get_remote_module_infos( + self.dht, self.block_uids, active_adapter=active_adapter, latest=True, return_future=True + ) self._need_latest_infos = False else: assert block_uids == state.sequence_info.block_uids @@ -179,7 +183,7 @@ class RemoteSequenceManager: def _update(self): """Perform an immediate and synchronous refresh, may take time""" new_block_infos = petals.dht_utils.get_remote_module_infos( - self.dht, self.block_uids, latest=self._need_latest_infos + self.dht, self.block_uids, active_adapter=self.config.active_adapter, latest=self._need_latest_infos ) self._need_latest_infos = True # All future _update() should use latest infos @@ -307,7 +311,7 @@ class RemoteSequenceManager: :param kwargs: additional request context, such as remote peer ID :returns: msgpack-serialized metadata dict that will be passed alongside a given request """ - return dict(points=self.policy.get_points(protocol, *args, **kwargs)) + return dict(points=self.policy.get_points(protocol, *args, **kwargs), active_adapter=self.config.active_adapter) def shutdown(self): self._thread.shutdown() diff --git a/src/petals/data_structures.py b/src/petals/data_structures.py index 80b8f62..254faae 100644 --- a/src/petals/data_structures.py +++ b/src/petals/data_structures.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from hivemind import PeerID from hivemind.moe.expert_uid import ExpertUID @@ -57,3 +57,4 @@ class InferenceMetadata: uid: ExpertUID prefix_length: int cache_handles: Tuple[Handle, ...] 
+ active_adapter: Optional[str] diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py index 177b2f6..99316f2 100644 --- a/src/petals/dht_utils.py +++ b/src/petals/dht_utils.py @@ -22,6 +22,7 @@ def declare_active_modules( expiration_time: DHTExpiration, state: ServerState, throughput: float, + adapters: Optional[Sequence[str]] = None, wait: bool = True, ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]: """ @@ -39,6 +40,7 @@ def declare_active_modules( uids = list(uids) for uid in uids: assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid + return dht.run_coroutine( partial( _declare_active_modules, @@ -46,6 +48,7 @@ def declare_active_modules( expiration_time=expiration_time, state=state, throughput=throughput, + adapters=list(adapters or []), ), return_future=not wait, ) @@ -58,12 +61,13 @@ async def _declare_active_modules( expiration_time: DHTExpiration, state: ServerState, throughput: float, + adapters: List[str], ) -> Dict[ModuleUID, bool]: num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers) return await node.store_many( keys=uids, subkeys=[dht.peer_id.to_base58()] * len(uids), - values=[(state.value, throughput)] * len(uids), + values=[(state.value, throughput, dict(adapters=adapters))] * len(uids), expiration_time=expiration_time, num_workers=num_workers, ) @@ -73,18 +77,30 @@ def get_remote_module_infos( dht: DHT, uids: Sequence[ModuleUID], expiration_time: Optional[DHTExpiration] = None, + active_adapter: Optional[str] = None, *, latest: bool = False, return_future: bool = False, ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]: return dht.run_coroutine( - partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest), + partial( + _get_remote_module_infos, + uids=uids, + active_adapter=active_adapter, + expiration_time=expiration_time, + latest=latest, + ), return_future=return_future, ) async def _get_remote_module_infos( - dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool + dht: DHT, + node: DHTNode, + uids: List[ModuleUID], + active_adapter: Optional[str], + expiration_time: Optional[DHTExpiration], + latest: bool, ) -> List[Optional[RemoteModuleInfo]]: if latest: assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both" @@ -105,7 +121,13 @@ async def _get_remote_module_infos( for peer_id, server_info in metadata.value.items(): try: peer_id = PeerID.from_base58(peer_id) - state, throughput = server_info.value + state, throughput = server_info.value[:2] + extra_info = server_info.value[2] if len(server_info.value) > 2 else {} + adapters = extra_info.get("adapters", []) + if bool(active_adapter) and active_adapter not in adapters: + logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}") + continue + if not ( isinstance(state, int) and isinstance(throughput, float) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index adcd617..9e81170 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -2,8 +2,9 @@ from __future__ import annotations from collections import Counter from itertools import chain -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import peft import torch from hivemind import BatchTensorDescriptor, TensorDescriptor from hivemind.moe.expert_uid import ExpertUID @@ -80,6 +81,18 @@ class 
TransformerBackend(ModuleBackend): cache_tensors.extend((keys, values)) return cache_tensors + def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: + *inputs, active_adapter = inputs + if not self.load_adapter_(active_adapter): + raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + return super().forward(*inputs) + + def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]: + *inputs, active_adapter = inputs + if not self.load_adapter_(active_adapter): + raise KeyError(f"Could not find adapter {active_adapter}; perhaps it is not loaded") + return super().backward(*inputs) + @torch.inference_mode() def inference_step( self, @@ -88,6 +101,8 @@ class TransformerBackend(ModuleBackend): inference_info: InferenceMetadata, ) -> Tuple[torch.Tensor, ...]: assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]" + if not self.load_adapter_(inference_info.active_adapter): + raise KeyError(f"Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded") with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors: self._reorder_cache_inplace(cache_tensors, hypo_ids) layer_past = self._select_layer_past(cache_tensors, inference_info.prefix_length) @@ -139,6 +154,16 @@ class TransformerBackend(ModuleBackend): for p in self.module.parameters(): p.data = dummy + def load_adapter_(self, active_adapter: Optional[str] = None) -> bool: + """Activate a given adapter set if available. Return True if available (or no adapter), False if missing""" + adapter_was_loaded = False + for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter + if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)): + layer.active_adapter = active_adapter # empty string for no adapter + if active_adapter in layer.lora_A.keys(): + adapter_was_loaded = True + return adapter_was_loaded or not active_adapter + def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]): """Replace each backend's rpc_inference pools with a combined pool runs multiple blocks in one call""" diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py index 65ee5c6..d7295ca 100644 --- a/src/petals/server/handler.py +++ b/src/petals/server/handler.py @@ -141,6 +141,7 @@ class TransformerConnectionHandler(ConnectionHandler): metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) max_length = metadata.get("max_length") + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) session_id = metadata.get("session_id") @@ -201,7 +202,7 @@ class TransformerConnectionHandler(ConnectionHandler): ) inference_infos = tuple( - InferenceMetadata(uid, prefix_length, tuple(handles)) + InferenceMetadata(uid, prefix_length, tuple(handles), active_adapter) for uid, handles in zip(requested_uids, cache_handles) ) @@ -354,13 +355,18 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward should have number of points as number or None, got {points}" 
hidden_states = await _rpc_forward( - *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_inputs, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) return runtime_pb2.ExpertResponse( tensors=self._serialize_outputs(hidden_states, requested_backends, metadata) @@ -376,13 +382,18 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_forward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_forward_stream should have number of points as number or None, got {points}" hidden_states = await _rpc_forward( - *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_inputs, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) # Split the serialized_output for streaming and respond to client @@ -422,13 +433,18 @@ class TransformerConnectionHandler(ConnectionHandler): requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {} + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward should have number of points as number or None, got {points}" grads = await _rpc_backward( - *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_tensors, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) return runtime_pb2.ExpertResponse(tensors=self._serialize_grads(grads, requested_backends, metadata)) @@ -442,13 +458,18 @@ class TransformerConnectionHandler(ConnectionHandler): self._log_request("rpc_backward_stream", requested_uids, context) requested_backends = tuple(self.module_backends[uid] for uid in requested_uids) + active_adapter = metadata.get("active_adapter", "") points = metadata.get("points", 0) assert isinstance( points, (float, int) ), f"rpc_backward_stream should have number of points as number or None, got {points}" grads = await _rpc_backward( - *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points + *flat_tensors, + requested_backends=requested_backends, + prioritizer=self._prioritizer, + active_adapter=active_adapter, + points=points, ) # Split the serialized_grad_inputs for streaming and respond for tensor in self._serialize_grads(grads, requested_backends, metadata): @@ -553,6 +574,7 @@ class TransformerConnectionHandler(ConnectionHandler): async def _rpc_forward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> torch.Tensor: @@ -585,6 +607,7 @@ async def _rpc_forward( ) (hidden_states,) = await backend.forward_pool.submit_task( hidden_states, + active_adapter, priority=priority, ) assert isinstance(hidden_states, torch.Tensor) @@ -598,6 +621,7 @@ async def _rpc_forward( async def _rpc_backward( *flat_tensors: torch.Tensor, requested_backends: Sequence[TransformerBackend], + active_adapter: str = "", prioritizer: TaskPrioritizerBase, points: int = 0, ) -> 
Union[torch.Tensor, Sequence[torch.Tensor]]: @@ -623,7 +647,7 @@ async def _rpc_backward( priority = prioritizer.prioritize( inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward" ) - (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority) + (inputs,) = await backend.forward_pool.submit_task(inputs, active_adapter, priority=priority) assert isinstance(inputs, torch.Tensor) @@ -639,7 +663,7 @@ async def _rpc_backward( priority = prioritizer.prioritize( inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward" ) - (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority) + (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, active_adapter, priority=priority) assert isinstance(grad_outputs, torch.Tensor) if not is_dummy(prompt): diff --git a/src/petals/server/server.py b/src/petals/server/server.py index eddb76e..643bf1b 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -81,6 +81,7 @@ class Server: dht_client_mode: Optional[bool] = None, use_relay: bool = True, use_auto_relay: bool = True, + adapters: Optional[List[str]] = None, **kwargs, ): """Create a server with one or more bloom blocks. See run_server.py for documentation.""" @@ -218,6 +219,8 @@ class Server: self.mean_balance_check_period = mean_balance_check_period self.mean_block_selection_delay = mean_block_selection_delay + self.adapters = adapters + self.stop = threading.Event() def _choose_num_blocks(self) -> int: @@ -291,6 +294,7 @@ class Server: quant_type=self.quant_type, tensor_parallel_devices=self.tensor_parallel_devices, should_validate_reachability=self.should_validate_reachability, + adapters=self.adapters, start=True, ) try: @@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread): quant_type: QuantType, tensor_parallel_devices: Sequence[torch.device], should_validate_reachability: bool, + adapters: Optional[List[str]] = None, **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] @@ -391,6 +396,7 @@ class ModuleContainer(threading.Thread): module_uids, dht, ServerState.JOINING, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -415,7 +421,19 @@ class ModuleContainer(threading.Thread): cache_dir=cache_dir, max_disk_space=max_disk_space, ) - block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True) + block = convert_block( + block, + block_index, + block_config, + tensor_parallel_devices, + device, + quant_type, + adapters=adapters, + freeze=True, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + max_disk_space=max_disk_space, + ) blocks[module_uid] = TransformerBackend( module_uid, block, @@ -452,6 +470,7 @@ class ModuleContainer(threading.Thread): expiration_time=get_dht_time() + expiration, state=ServerState.OFFLINE, throughput=throughput, + adapters=adapters, ) logger.info(f"Announced that blocks {module_uids} are offline") raise @@ -465,6 +484,7 @@ class ModuleContainer(threading.Thread): dht, dht_prefix, blocks, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -480,6 +500,7 @@ class ModuleContainer(threading.Thread): inference_max_length: int, num_handlers: int, throughput: float, + adapters: Optional[Sequence[str]], update_period: float, expiration: Optional[float] = None, request_timeout: float, @@ -517,6 +538,7 
@@ class ModuleContainer(threading.Thread): list(self.module_backends.keys()), dht, ServerState.ONLINE, + adapters=adapters, throughput=throughput, update_period=update_period, expiration=expiration, @@ -616,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread): module_uids: List[str], dht: DHT, state: ServerState, + adapters: Optional[Sequence[str]], *, throughput: float, update_period: float = 30, @@ -626,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread): self.module_uids = module_uids self.dht = dht self.state = state + self.adapters = adapters self.throughput = throughput self.update_period = update_period self.expiration = expiration @@ -639,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread): expiration_time=get_dht_time() + self.expiration, state=self.state, throughput=self.throughput, + adapters=self.adapters, ) if self.stop.wait(self.update_period): break diff --git a/src/petals/server/throughput.py b/src/petals/server/throughput.py index 76bbc85..20625e6 100644 --- a/src/petals/server/throughput.py +++ b/src/petals/server/throughput.py @@ -172,7 +172,7 @@ def measure_compute_rps( tensor_parallel_devices = (device,) with torch.inference_mode(): block = config.block_class(config).to(dtype) - block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) + block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True) cache = None elapsed = 0 diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py index 6b129f5..b1c412e 100644 --- a/src/petals/utils/convert_block.py +++ b/src/petals/utils/convert_block.py @@ -3,8 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par """ import os import re -from enum import Enum -from typing import Sequence +from typing import List, Optional, Sequence import tensor_parallel as tp import torch @@ -13,23 +12,23 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler from tensor_parallel.slicing_configs import get_bloom_config from transformers import PretrainedConfig +from petals.utils.misc import QuantType +from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft + use_hivemind_log_handler("in_root_logger") logger = get_logger(__name__) -class QuantType(Enum): - NONE = 0 - INT8 = 1 # 8-bit as in the LLM.int8() paper - NF4 = 2 # 4-bit as in the QLoRA paper - - def convert_block( block: nn.Module, + block_index: int, config: PretrainedConfig, tensor_parallel_devices: Sequence[torch.device], output_device: torch.device, quant_type: QuantType, freeze: bool = True, + adapters: Optional[List[str]] = None, + **kwargs, ) -> tp.TensorParallel: """ Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization @@ -56,6 +55,16 @@ def convert_block( for shard, device in zip(block.module_shards, block.devices): shard.to(device) + if adapters: + create_lora_adapter(block, quant_type=quant_type) + for adapter_name in adapters: + adapter_config, adapter_state_dict = load_peft( + adapter_name, + block_idx=block_index, + **kwargs, + ) + add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict) + return block diff --git a/src/petals/utils/misc.py b/src/petals/utils/misc.py index 2f67202..99b246c 100644 --- a/src/petals/utils/misc.py +++ b/src/petals/utils/misc.py @@ -1,5 +1,14 @@ +from enum import Enum + import torch + +class QuantType(Enum): + NONE = 0 + INT8 = 1 # 8-bit as in 
the LLM.int8() paper + NF4 = 2 # 4-bit as in the QLoRA paper + + DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py new file mode 100644 index 0000000..c551f97 --- /dev/null +++ b/src/petals/utils/peft.py @@ -0,0 +1,208 @@ +import re +import time +from typing import List, Optional + +import bitsandbytes as bnb +import torch.nn as nn +from hivemind.utils.logging import get_logger +from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url +from peft.tuners import lora +from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig +from safetensors import safe_open +from safetensors.torch import load_file +from transformers.utils import get_file_from_repo + +from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for +from petals.utils.misc import QuantType + +logger = get_logger(__name__) + + +def check_peft_repository(repo_id: str) -> bool: + fs = HfFileSystem() + list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False) + return len(list_of_files) > 0 + + +def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None): + tensors = dict() + is_tensors_found = dict() + common_layer_patter_re = ( + ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"\.({block_idx})?\..+" + ) + with safe_open(filepath, framework=framework, device=device) as f: + for k in f.keys(): + if re.match(common_layer_patter_re, k): + is_tensors_found[block_idx] = True + tensors[k] = f.get_tensor(k) + if not is_tensors_found.get(block_idx, False): + logger.warning(f"There is no peft weights for block {block_idx}") + return tensors + + +def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs): + config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs) + if config_path is None: + raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}") + config = PeftConfig.from_json_file(config_path) + + weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs) + if weight_path is None: + raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}") + if block_idx is None: + return config, load_file(weight_path) + return config, load_specific_module(block_idx, weight_path, device=device) + + +def load_peft( + repo_id: str, + block_idx: Optional[int] = None, + device: Optional[int] = None, + *, + revision: Optional[str] = None, + use_auth_token: Optional[str] = None, + cache_dir: str, + max_disk_space: Optional[int] = None, + delay: float = 30, +): + # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here + + if not check_peft_repository(repo_id): + raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.") + + try: + with allow_cache_reads(cache_dir): + return get_adapter_from_repo( + repo_id, + block_idx, + device, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + except Exception: + logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True) + + while True: + try: + with allow_cache_writes(cache_dir): + config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision) + config_file_size = get_hf_file_metadata(config_url, 
token=use_auth_token).size + weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision) + weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size + + file_size = config_file_size + weight_file_size + if file_size is not None: + free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space) + else: + logger.warning(f"Failed to fetch size from peft repo {repo_id}") + + return get_adapter_from_repo( + repo_id, + block_idx, + device, + revision=revision, + use_auth_token=use_auth_token, + cache_dir=cache_dir, + local_files_only=False, + ) + except Exception as e: + logger.warning( + f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True + ) + time.sleep(delay) + + +def create_lora_adapter(block, quant_type: QuantType): + for name, module in block.named_modules(): + for child_name, child in module.named_children(): + lora_wrapped_child = None + if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)): + continue + if quant_type == QuantType.INT8: + kwargs = { + "has_fp16_weights": False, + "threshold": 6.0, + "bias": hasattr(child, "bias") and child.bias is not None, + } + lora_wrapped_child = lora.Linear8bitLt( + child_name, + child.in_features, + child.out_features, + **kwargs, + ) + elif quant_type == QuantType.NF4: + kwargs = { + "compress_statistics": True, + "quant_type": "nf4", + "blocksize": 64, + "bias": hasattr(child, "bias") and child.bias is not None, + } + lora_wrapped_child = lora.Linear4bit( + child_name, + child.in_features, + child.out_features, + **kwargs, + ) + else: + bias = hasattr(child, "bias") and child.bias is not None + lora_wrapped_child = lora.Linear( + child_name, + child.in_features, + child.out_features, + bias=bias, + ) + if lora_wrapped_child: + lora_wrapped_child.active_adapter = None + lora_wrapped_child.weight = child.weight + lora_wrapped_child.bias = child.bias + for p in lora_wrapped_child.parameters(): + p.requires_grad = False + setattr(module, child_name, lora_wrapped_child) + + +def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict): + assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters" + for name, module in block.named_modules(): + for child_name, child in module.named_children(): + if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)): + continue + + if child_name in peft_config["target_modules"] or ( + isinstance(peft_config["target_modules"], str) + and re.fullmatch(peft_config["target_modules"], child_name) + ): + is_lora_a_loaded = False + is_lora_b_loaded = False + for peft_key in peft_state_dict: + if peft_key.find(child_name) == -1: + continue + + if adapter_name not in child.lora_A: + child.update_layer( + adapter_name, + peft_config["r"], + peft_config["lora_alpha"], + lora_dropout=peft_config["lora_dropout"], + init_lora_weights=peft_config["init_lora_weights"], + ) + child.train(False) + if peft_config["lora_dropout"] > 0: + logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout") + for p in child.parameters(): + p.requires_grad = False + + if peft_key.endswith(".lora_A.weight"): + child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] + is_lora_a_loaded = True + elif peft_key.endswith(".lora_A.bias"): + raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") + elif peft_key.endswith(".lora_B.weight"): + child.lora_B[adapter_name].weight.data = 
peft_state_dict[peft_key] + is_lora_b_loaded = True + elif peft_key.endswith(".lora_B.bias"): + raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}") + + if is_lora_a_loaded and is_lora_b_loaded: + logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} finished successfully") diff --git a/tests/test_full_model.py b/tests/test_full_model.py index f2679f2..acd5e6a 100644 --- a/tests/test_full_model.py +++ b/tests/test_full_model.py @@ -1,3 +1,4 @@ +import peft import pytest import torch import transformers @@ -12,11 +13,16 @@ logger = get_logger(__name__) @pytest.mark.forked +@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,)) @pytest.mark.parametrize("pass_empty_tensors", (True, False)) -def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3): +def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3): tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME) model = DistributedBloomForCausalLM.from_pretrained( - MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32 + MODEL_NAME, + initial_peers=INITIAL_PEERS, + low_cpu_mem_usage=True, + torch_dtype=torch.float32, + active_adapter=ADAPTER_NAME if use_peft else None, ) config = model.config assert isinstance(model, DistributedBloomForCausalLM) @@ -54,6 +60,9 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato ref_model = transformers.BloomForCausalLM.from_pretrained( REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32 ) + if use_peft: + ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME) + ref_model.train(False) if config.vocab_size < ref_model.config.vocab_size: ref_model.resize_token_embeddings(config.vocab_size) logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}") diff --git a/tests/test_peft.py b/tests/test_peft.py new file mode 100644 index 0000000..7ac4f80 --- /dev/null +++ b/tests/test_peft.py @@ -0,0 +1,66 @@ +import os +import shutil + +import pytest +from huggingface_hub import snapshot_download + +from petals.utils.peft import check_peft_repository, load_peft + +UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft" +SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft" +TMP_CACHE_DIR = "tmp_cache/" + + +def clear_dir(path_to_dir): + shutil.rmtree(path_to_dir) + os.mkdir(path_to_dir) + + +def dir_empty(path_to_dir): + files = os.listdir(path_to_dir) + return len(files) == 0 + + +@pytest.mark.forked +def test_check_peft(): + assert not check_peft_repository(UNSAFE_PEFT_REPO), "UNSAFE_PEFT_REPO is safe to load." + assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
+ + +@pytest.mark.forked +def test_load_noncached(tmpdir): + clear_dir(tmpdir) + with pytest.raises(Exception): + load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir) + + assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded" + + load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) + + assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded" + + +@pytest.mark.forked +def test_load_cached(tmpdir): + clear_dir(tmpdir) + snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir) + + load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir) + + +@pytest.mark.forked +def test_load_layer_exists(tmpdir): + clear_dir(tmpdir) + + load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir) + + +@pytest.mark.forked +def test_load_layer_nonexists(tmpdir): + clear_dir(tmpdir) + + load_peft( + SAFE_PEFT_REPO, + block_idx=1337, + cache_dir=tmpdir, + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ee440d6..e40d235 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,3 +11,5 @@ if not MODEL_NAME: raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested") REF_NAME = os.environ.get("REF_NAME") + +ADAPTER_NAME = os.environ.get("ADAPTER_NAME")
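
Usage sketch (not part of the patch): a minimal end-to-end example of the new adapter support, assembled from the changes above. The model and adapter names are the test fixtures from run-tests.yaml, INITIAL_PEERS is a placeholder for your swarm's multiaddrs, and the client import path is assumed to match what tests/test_full_model.py relies on.

# Server side: preload a LoRA adapter so clients can request it (new --adapters flag in run_server.py)
#   python -m petals.cli.run_server bigscience/bloom-560m \
#       --adapters artek0chumak/bloom-560m-safe-peft --initial_peers <multiaddr>

# Client side: select the adapter for the whole remote sequence (mirrors tests/test_full_model.py)
import torch

from petals import DistributedBloomForCausalLM  # import path assumed

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/<peer_id>"]  # placeholder: multiaddrs of an existing swarm

model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    initial_peers=INITIAL_PEERS,
    torch_dtype=torch.float32,
    active_adapter="artek0chumak/bloom-560m-safe-peft",  # servers that do not announce this adapter are skipped during routing
)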