From 6db11c6483df60bda84243cbd0aaf886a6b8e4d6 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Wed, 12 Jul 2023 04:02:31 +0300
Subject: [PATCH] add minimalistic peft test

---
 src/petals/client/remote_sequential.py        |  3 +-
 src/petals/client/routing/sequence_manager.py |  8 +++--
 src/petals/dht_utils.py                       | 29 ++++++++++++++++---
 src/petals/server/backend.py                  | 17 ++++++-----
 src/petals/server/handler.py                  |  2 ++
 src/petals/server/server.py                   |  8 +++++
 src/petals/utils/peft.py                      |  2 ++
 tests/test_full_model.py                      |  9 +++++-
 tests/test_utils.py                           |  2 ++
 9 files changed, 65 insertions(+), 15 deletions(-)

diff --git a/src/petals/client/remote_sequential.py b/src/petals/client/remote_sequential.py
index 745b5c1..c7b65ea 100644
--- a/src/petals/client/remote_sequential.py
+++ b/src/petals/client/remote_sequential.py
@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
         dht: Optional[DHT] = None,
         start_block: Optional[int] = None,
         end_block: Optional[int] = None,
+        **kwargs
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
         if end_block is None:
             end_block = self.config.num_hidden_layers
         block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-        sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+        sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
diff --git a/src/petals/client/routing/sequence_manager.py b/src/petals/client/routing/sequence_manager.py
index 88d6d16..624d0b9 100644
--- a/src/petals/client/routing/sequence_manager.py
+++ b/src/petals/client/routing/sequence_manager.py
@@ -78,6 +78,7 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
+        extra_metadata: Optional[Dict[str, Any]] = None,
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -98,6 +99,7 @@ class RemoteSequenceManager:
             )
         assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
         self.dht = dht
+        self.extra_metadata = extra_metadata if extra_metadata is not None else {}
 
         if state.p2p is None:
             state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
@@ -167,7 +169,9 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
+        return type(self)(
+            self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix], extra_metadata=self.extra_metadata
+        )
 
     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
@@ -307,7 +311,7 @@ class RemoteSequenceManager:
         :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+        return dict(**self.extra_metadata, points=self.policy.get_points(protocol, *args, **kwargs))
 
     def shutdown(self):
         self._thread.shutdown()
diff --git a/src/petals/dht_utils.py b/src/petals/dht_utils.py
index 177b2f6..42b8c94 100644
--- a/src/petals/dht_utils.py
+++ b/src/petals/dht_utils.py
@@ -22,6 +22,7 @@ def declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -39,6 +40,7 @@ def declare_active_modules(
     uids = list(uids)
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
     return dht.run_coroutine(
         partial(
             _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
             expiration_time=expiration_time,
             state=state,
             throughput=throughput,
+            adapters=list(adapters or []),
         ),
         return_future=not wait,
     )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput)] * len(uids),
+        values=[(state.value, throughput, adapters)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
     *,
     latest: bool = False,
     return_future: bool = False,
 ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
     return dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
         return_future=return_future,
     )
 
 
 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
 ) -> List[Optional[RemoteModuleInfo]]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,12 @@ async def _get_remote_module_infos(
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value
+                state, throughput = server_info.value[:2]
+                available_adapters = server_info.value[2] if len(server_info.value) > 2 else []
+                if active_adapter is not None and active_adapter not in available_adapters:
+                    logger.warning(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
+
                 if not (
                     isinstance(state, int)
                     and isinstance(throughput, float)
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index aad7f2d..ca3b938 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -83,13 +83,14 @@ class TransformerBackend(ModuleBackend):
 
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        logger.debug(f"Forward pass with active adapter {active_adapter!r}")
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)
 
     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)
 
@@ -102,7 +103,8 @@ class TransformerBackend(ModuleBackend):
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
-        if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
+        logger.debug(f"Inference step with active adapter {inference_info.active_adapter!r}")
+        if not self.load_adapter_(inference_info.active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
@@ -156,14 +158,15 @@ class TransformerBackend(ModuleBackend):
             p.data = dummy
 
     def load_adapter_(self, active_adapter: str = "") -> bool:
-        """Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
-        adapter_is_loaded = False
+        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        logger.debug(f"Loading adapter {active_adapter!r}")
+        adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
             if isinstance(layer, peft.tuners.lora.Linear):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
-                    adapter_is_loaded = True
-        return adapter_is_loaded
+                    adapter_was_loaded = True
+        return adapter_was_loaded or active_adapter == ""
 
 
 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d7295ca..e3c8330 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -356,6 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             active_adapter = metadata.get("active_adapter", "")
+            logger.debug(f"Request metadata specifies active_adapter={active_adapter!r}")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -383,6 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
 
             active_adapter = metadata.get("active_adapter", "")
+            logger.debug(f"Request metadata specifies active_adapter={active_adapter!r}")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index e71034c..643bf1b 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -396,6 +396,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             ServerState.JOINING,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -469,6 +470,7 @@ class ModuleContainer(threading.Thread):
                 expiration_time=get_dht_time() + expiration,
                 state=ServerState.OFFLINE,
                 throughput=throughput,
+                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -482,6 +484,7 @@ class ModuleContainer(threading.Thread):
             dht,
             dht_prefix,
             blocks,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -497,6 +500,7 @@ class ModuleContainer(threading.Thread):
         inference_max_length: int,
         num_handlers: int,
         throughput: float,
+        adapters: Optional[Sequence[str]],
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -534,6 +538,7 @@ class ModuleContainer(threading.Thread):
             list(self.module_backends.keys()),
             dht,
             ServerState.ONLINE,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -633,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         state: ServerState,
+        adapters: Optional[Sequence[str]],
         *,
         throughput: float,
         update_period: float = 30,
@@ -643,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.state = state
+        self.adapters = adapters
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -656,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
             expiration_time=get_dht_time() + self.expiration,
             state=self.state,
             throughput=self.throughput,
+            adapters=self.adapters,
         )
         if self.stop.wait(self.update_period):
             break
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index 6ac7fd9..1af9a51 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -154,6 +154,8 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():
                     p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index f2679f2..8fdcd33 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 import transformers
+import peft
 from hivemind import get_logger
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
@@ -12,12 +13,16 @@ logger = get_logger(__name__)
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    if use_peft:
+        model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter=ADAPTER_NAME)
+
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers
@@ -54,6 +59,8 @@ def test_full_model_exact_match(pass_empty_tensors: bool, ato
         ref_model = transformers.BloomForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
+        if use_peft:
+            ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+
         if config.vocab_size < ref_model.config.vocab_size:
             ref_model.resize_token_embeddings(config.vocab_size)
             logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index ee440d6..4a0936c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -11,3 +11,5 @@ if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("PEFT_NAME")
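
Usage sketch (not part of the diff above): a minimal client-side example of the `extra_metadata`
hook this patch introduces, mirroring what the test does. The model name, initial peers, and
adapter repository below are placeholders, and the `from petals import ...` path is assumed; the
attribute path `model.transformer.h.sequence_manager` is taken from the test in this patch.

    import torch

    from petals import DistributedBloomForCausalLM  # import path assumed

    MODEL_NAME = "bigscience/bloom-560m"          # placeholder model
    INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337"]  # placeholder multiaddr of a swarm peer
    ADAPTER_NAME = "my-org/my-lora-adapter"       # placeholder PEFT adapter repo

    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )
    # Same mechanism as in test_full_model_exact_match: every remote request now carries
    # {"active_adapter": ADAPTER_NAME} in its metadata, and a server that has not loaded
    # that adapter rejects the call with a KeyError.
    model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter=ADAPTER_NAME)

    input_ids = torch.tensor([[1, 2, 3]])
    logits = model(input_ids).logits  # forward pass runs with the LoRA adapter selected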