add minimalistic peft test

Branch: declare_adapters
Author: Your Name, 11 months ago
Parent: 03d2b05166
Commit: 6db11c6483

@@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
         dht: Optional[DHT] = None,
         start_block: Optional[int] = None,
         end_block: Optional[int] = None,
+        **kwargs
     ):
         super().__init__()
         self.config = config
@@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
             if end_block is None:
                 end_block = self.config.num_hidden_layers
             block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
-            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager

     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

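Note: `RemoteSequential` now forwards arbitrary keyword arguments to `RemoteSequenceManager`, and the sequence manager (next hunks) gains an `extra_metadata` dict that is attached to every outgoing request. A minimal client-side sketch, assuming a `model` already created with `DistributedBloomForCausalLM.from_pretrained(...)`; the adapter name is a placeholder:

    # Sketch (not part of this commit): ask remote servers to activate a LoRA adapter for all calls.
    # "my-lora-adapter" is a placeholder; "active_adapter" is the metadata key the server handler reads.
    model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter="my-lora-adapter")

    # Equivalent at construction time, since **kwargs is forwarded to RemoteSequenceManager:
    # blocks = RemoteSequential(model.config, dht=dht, extra_metadata=dict(active_adapter="my-lora-adapter"))
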
@@ -78,6 +78,7 @@ class RemoteSequenceManager:
         *,
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
+        extra_metadata: Optional[Dict[str, Any]] = None
     ):
         assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@@ -98,6 +99,7 @@ class RemoteSequenceManager:
             )
         assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
         self.dht = dht
+        self.extra_metadata = extra_metadata if extra_metadata is not None else {}

         if state.p2p is None:
             state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
@@ -167,7 +169,9 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
+        return type(self)(
+            self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix], extra_metadata=self.extra_metadata
+        )

     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
@@ -307,7 +311,7 @@ class RemoteSequenceManager:
         :param kwargs: additional request context, such as remote peer ID
         :returns: msgpack-serialized metadata dict that will be passed alongside a given request
         """
-        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+        return dict(**self.extra_metadata, points=self.policy.get_points(protocol, *args, **kwargs))

     def shutdown(self):
         self._thread.shutdown()

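Note: the `get_request_metadata` change merges the client-supplied `extra_metadata` into the per-request dict before it is msgpack-serialized. A tiny self-contained illustration of the merge (values are placeholders):

    # Plain-Python illustration of the metadata merge performed above.
    extra_metadata = {"active_adapter": "my-lora-adapter"}  # placeholder adapter name
    points = 0.0                                            # whatever self.policy.get_points(...) returns
    metadata = dict(**extra_metadata, points=points)
    assert metadata == {"active_adapter": "my-lora-adapter", "points": 0.0}
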
@@ -22,6 +22,7 @@ def declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: Optional[Sequence[str]] = None,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
@@ -39,6 +40,7 @@ def declare_active_modules(
         uids = list(uids)
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
+
     return dht.run_coroutine(
         partial(
             _declare_active_modules,
@@ -46,6 +48,7 @@ def declare_active_modules(
             expiration_time=expiration_time,
             state=state,
             throughput=throughput,
+            adapters=list(adapters or []),
         ),
         return_future=not wait,
     )
@@ -58,12 +61,13 @@ async def _declare_active_modules(
     expiration_time: DHTExpiration,
     state: ServerState,
     throughput: float,
+    adapters: List[str],
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[(state.value, throughput)] * len(uids),
+        values=[(state.value, throughput, adapters)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -73,18 +77,30 @@ def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: Optional[DHTExpiration] = None,
+    active_adapter: Optional[str] = None,
     *,
     latest: bool = False,
     return_future: bool = False,
 ) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
     return dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
+        partial(
+            _get_remote_module_infos,
+            uids=uids,
+            active_adapter=active_adapter,
+            expiration_time=expiration_time,
+            latest=latest,
+        ),
         return_future=return_future,
     )


 async def _get_remote_module_infos(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    active_adapter: Optional[str],
+    expiration_time: Optional[DHTExpiration],
+    latest: bool,
 ) -> List[Optional[RemoteModuleInfo]]:
     if latest:
         assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@@ -105,7 +121,12 @@ async def _get_remote_module_infos(
         for peer_id, server_info in metadata.value.items():
             try:
                 peer_id = PeerID.from_base58(peer_id)
-                state, throughput = server_info.value
+                state, throughput = server_info.value[:2]
+                available_adapters = server_info.value[2] if len(server_info.value) > 2 else []
+                if active_adapter is not None and active_adapter not in available_adapters:
+                    logger.warning(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
+                    continue
                 if not (
                     isinstance(state, int)
                     and isinstance(throughput, float)

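Note: with these changes, each server's DHT entry grows from `(state, throughput)` to `(state, throughput, adapters)`, and `_get_remote_module_infos` skips peers that do not list the requested adapter. A small sketch of the unpacking logic, kept backward-compatible with old two-element entries (values are made up):

    # Backward-compatible unpacking of a server's DHT value, mirroring the hunk above.
    value = (2, 123.4, ["my-lora-adapter"])   # new format; older servers store only (state, throughput)
    state, throughput = value[:2]
    available_adapters = value[2] if len(value) > 2 else []

    active_adapter = "my-lora-adapter"        # placeholder: the adapter the client asked for
    if active_adapter is not None and active_adapter not in available_adapters:
        print("skipping this server: requested adapter is not served")
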
@@ -83,13 +83,14 @@ class TransformerBackend(ModuleBackend):
     def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        print("--forward...")
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().forward(*inputs)

     def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
         *inputs, active_adapter = inputs
-        if active_adapter and not self.load_adapter_(active_adapter):
+        if not self.load_adapter_(active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         return super().backward(*inputs)
@@ -102,7 +103,8 @@ class TransformerBackend(ModuleBackend):
     ) -> Tuple[torch.Tensor, ...]:
         assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-        if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
+        print("--inference...")
+        if not self.load_adapter_(inference_info.active_adapter):
             raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
         with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
             self._reorder_cache_inplace(cache_tensors, hypo_ids)
@@ -156,14 +158,15 @@ class TransformerBackend(ModuleBackend):
             p.data = dummy

     def load_adapter_(self, active_adapter: str = "") -> bool:
-        """Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
-        adapter_is_loaded = False
+        """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+        print("LOADING ADAPTER [", active_adapter, "]")
+        adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
             if isinstance(layer, peft.tuners.lora.Linear):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
-                    adapter_is_loaded = True
-        return adapter_is_loaded
+                    adapter_was_loaded = True
+        return adapter_was_loaded or active_adapter == ""


 def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):

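Note: the reworked `load_adapter_` switches every `peft.tuners.lora.Linear` layer in the block to the requested adapter set and now treats the empty string (no adapter) as success. A standalone sketch of the same behavior, using a plain module tree instead of the real `TransformerBackend` (the function name and `module` argument are made up for illustration):

    import peft  # assumes peft is installed, as in the hunk above

    def adapter_available(module, active_adapter: str = "") -> bool:
        """Flip the active LoRA adapter on every lora.Linear layer; report whether it exists anywhere."""
        found = False
        for layer in module.modules():
            if isinstance(layer, peft.tuners.lora.Linear):
                layer.active_adapter = active_adapter      # empty string selects "no adapter"
                if active_adapter in layer.lora_A:
                    found = True
        return found or active_adapter == ""               # requesting no adapter always succeeds
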
@@ -356,6 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
             active_adapter = metadata.get("active_adapter", "")
+            print("ACTIVE_ADAPTER: [", active_adapter, "]")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)
@@ -383,6 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
             requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
             active_adapter = metadata.get("active_adapter", "")
+            print("ACTIVE_ADAPTER: [", active_adapter, "]")
             points = metadata.get("points", 0)
             assert isinstance(
                 points, (float, int)

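Note: on the server side, the adapter name travels in the same msgpack metadata blob that already carries `points`. A round-trip sketch, assuming hivemind's `MSGPackSerializer`; the adapter name is a placeholder:

    from hivemind.utils.serializer import MSGPackSerializer

    # Client side: RemoteSequenceManager.get_request_metadata() produces this dict, then it is serialized.
    raw = MSGPackSerializer.dumps(dict(active_adapter="my-lora-adapter", points=1.0))

    # Server side: TransformerConnectionHandler reads it back, as in the hunks above.
    metadata = MSGPackSerializer.loads(raw)
    active_adapter = metadata.get("active_adapter", "")
    points = metadata.get("points", 0)
    assert active_adapter == "my-lora-adapter" and points == 1.0
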
@@ -396,6 +396,7 @@ class ModuleContainer(threading.Thread):
             module_uids,
             dht,
             ServerState.JOINING,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -469,6 +470,7 @@ class ModuleContainer(threading.Thread):
                 expiration_time=get_dht_time() + expiration,
                 state=ServerState.OFFLINE,
                 throughput=throughput,
+                adapters=adapters,
             )
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
@@ -482,6 +484,7 @@
             dht,
             dht_prefix,
             blocks,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -497,6 +500,7 @@
         inference_max_length: int,
         num_handlers: int,
         throughput: float,
+        adapters: Optional[Sequence[str]],
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -534,6 +538,7 @@
             list(self.module_backends.keys()),
             dht,
             ServerState.ONLINE,
+            adapters=adapters,
             throughput=throughput,
             update_period=update_period,
             expiration=expiration,
@@ -633,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         state: ServerState,
+        adapters: Optional[Sequence[str]],
         *,
         throughput: float,
         update_period: float = 30,
@@ -643,6 +649,7 @@
         self.module_uids = module_uids
         self.dht = dht
         self.state = state
+        self.adapters = adapters
         self.throughput = throughput
         self.update_period = update_period
         self.expiration = expiration
@@ -656,6 +663,7 @@
                 expiration_time=get_dht_time() + self.expiration,
                 state=self.state,
                 throughput=self.throughput,
+                adapters=self.adapters,
             )
             if self.stop.wait(self.update_period):
                 break

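Note: the server now threads its list of available LoRA adapters all the way down to the announcer, which republishes them with every heartbeat. A hedged sketch of the equivalent one-off announcement (import paths follow this repo's layout at the time and are assumptions; `dht` and `module_uids` are placeholders for a running hivemind DHT and the served block UIDs of the form f"{dht_prefix}{UID_DELIMITER}{index}"):

    from hivemind.utils import get_dht_time

    from petals.data_structures import ServerState
    from petals.dht_utils import declare_active_modules

    declare_active_modules(
        dht,                                   # a running hivemind.DHT
        module_uids,                           # block UIDs this server hosts
        expiration_time=get_dht_time() + 300,
        state=ServerState.ONLINE,
        throughput=1.0,
        adapters=["my-lora-adapter"],          # placeholder adapter repository name
    )
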
@@ -154,6 +154,8 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
                 for p in lora_wrapped_child.parameters():
                     p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)

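Note: the two added lines make the freshly created LoRA wrapper reuse the original layer's weight and bias tensors instead of its own newly initialized ones, and the existing loop then freezes them. A toy illustration of the same tensor-sharing idea, with plain `nn.Linear` modules standing in for `peft.tuners.lora.Linear`:

    import torch.nn as nn

    child = nn.Linear(16, 16)          # the original (pretrained) layer
    wrapper = nn.Linear(16, 16)        # stand-in for the LoRA-wrapped replacement
    wrapper.weight = child.weight      # share the pretrained parameters instead of random init
    wrapper.bias = child.bias
    for p in wrapper.parameters():
        p.requires_grad = False        # base weights stay frozen; only LoRA matrices would train

    assert wrapper.weight is child.weight
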
@@ -1,6 +1,7 @@
 import pytest
 import torch
 import transformers
+import peft
 from hivemind import get_logger
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
@@ -12,12 +13,16 @@ logger = get_logger(__name__)
 @pytest.mark.forked
+@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    if use_peft:
+        model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter=ADAPTER_NAME)
+
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.num_hidden_layers
@@ -54,6 +59,8 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
         ref_model = transformers.BloomForCausalLM.from_pretrained(
             REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
         )
+        if use_peft:
+            ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
         if config.vocab_size < ref_model.config.vocab_size:
             ref_model.resize_token_embeddings(config.vocab_size)
             logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

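Note: with `use_peft=True` the test points every remote block at `ADAPTER_NAME` via `extra_metadata` and wraps the local reference model in `peft.PeftModel`, so the remote and local outputs should still match. `ADAPTER_NAME` comes from the `PEFT_NAME` environment variable (next hunk); when it is unset, only the `use_peft=False` case is parametrized. A hedged sketch of enabling the PEFT case (the adapter name is a placeholder, and the variable must be set before the tests are collected, since the constant is read at import time):

    import os

    os.environ["PEFT_NAME"] = "my-lora-adapter"   # enables the use_peft=True parametrization
    # then run the full-model test with pytest as usual (MODEL_NAME, REF_NAME, and a test swarm are assumed)
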
@@ -11,3 +11,5 @@ if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")

 REF_NAME = os.environ.get("REF_NAME")
+
+ADAPTER_NAME = os.environ.get("PEFT_NAME")
