add minimalistic peft test

declare_adapters
Your Name 10 months ago
parent 03d2b05166
commit 6db11c6483

@ -28,6 +28,7 @@ class RemoteSequential(nn.Module):
dht: Optional[DHT] = None,
start_block: Optional[int] = None,
end_block: Optional[int] = None,
**kwargs
):
super().__init__()
self.config = config
@ -41,7 +42,7 @@ class RemoteSequential(nn.Module):
if end_block is None:
end_block = self.config.num_hidden_layers
block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
self.sequence_manager = sequence_manager
def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):

@ -78,6 +78,7 @@ class RemoteSequenceManager:
*,
dht: Optional[DHT] = None,
state: Optional[SequenceManagerState] = None,
extra_metadata: Optional[Dict[str, Any]] = None
):
assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
@ -98,6 +99,7 @@ class RemoteSequenceManager:
)
assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
self.dht = dht
self.extra_metadata = extra_metadata if extra_metadata is not None else {}
if state.p2p is None:
state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
@ -167,7 +169,9 @@ class RemoteSequenceManager:
assert isinstance(ix, (int, slice))
if not isinstance(ix, slice):
ix = slice(int(ix), int(ix) + 1, 1)
return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
return type(self)(
self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix], extra_metadata=self.extra_metadata
)
def update(self, *, wait: bool):
"""Run an asynchronous update in background as soon as possible"""
@ -307,7 +311,7 @@ class RemoteSequenceManager:
:param kwargs: additional request context, such as remote peer ID
:returns: msgpack-serialized metadata dict that will be passed alongside a given request
"""
return dict(points=self.policy.get_points(protocol, *args, **kwargs))
return dict(**self.extra_metadata, points=self.policy.get_points(protocol, *args, **kwargs))
def shutdown(self):
self._thread.shutdown()

@ -22,6 +22,7 @@ def declare_active_modules(
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: Optional[Sequence[str]] = None,
wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
"""
@ -39,6 +40,7 @@ def declare_active_modules(
uids = list(uids)
for uid in uids:
assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
return dht.run_coroutine(
partial(
_declare_active_modules,
@ -46,6 +48,7 @@ def declare_active_modules(
expiration_time=expiration_time,
state=state,
throughput=throughput,
adapters=list(adapters or []),
),
return_future=not wait,
)
@ -58,12 +61,13 @@ async def _declare_active_modules(
expiration_time: DHTExpiration,
state: ServerState,
throughput: float,
adapters: List[str],
) -> Dict[ModuleUID, bool]:
num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
return await node.store_many(
keys=uids,
subkeys=[dht.peer_id.to_base58()] * len(uids),
values=[(state.value, throughput)] * len(uids),
values=[(state.value, throughput, adapters)] * len(uids),
expiration_time=expiration_time,
num_workers=num_workers,
)
@ -73,18 +77,30 @@ def get_remote_module_infos(
dht: DHT,
uids: Sequence[ModuleUID],
expiration_time: Optional[DHTExpiration] = None,
active_adapter: Optional[str] = None,
*,
latest: bool = False,
return_future: bool = False,
) -> Union[List[Optional[RemoteModuleInfo]], MPFuture]:
return dht.run_coroutine(
partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time, latest=latest),
partial(
_get_remote_module_infos,
uids=uids,
active_adapter=active_adapter,
expiration_time=expiration_time,
latest=latest,
),
return_future=return_future,
)
async def _get_remote_module_infos(
dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration], latest: bool
dht: DHT,
node: DHTNode,
uids: List[ModuleUID],
active_adapter: Optional[str],
expiration_time: Optional[DHTExpiration],
latest: bool,
) -> List[Optional[RemoteModuleInfo]]:
if latest:
assert expiration_time is None, "You should define either `expiration_time` or `latest`, not both"
@ -105,7 +121,12 @@ async def _get_remote_module_infos(
for peer_id, server_info in metadata.value.items():
try:
peer_id = PeerID.from_base58(peer_id)
state, throughput = server_info.value
state, throughput = server_info.value[:2]
available_adapters = server_info.value[2] if len(server_info.value) > 2 else []
if active_adapter is not None and active_adapter not in available_adapters:
logger.warning(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
if not (
isinstance(state, int)
and isinstance(throughput, float)

@ -83,13 +83,14 @@ class TransformerBackend(ModuleBackend):
def forward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
if active_adapter and not self.load_adapter_(active_adapter):
print("--forward...")
if not self.load_adapter_(active_adapter):
raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
return super().forward(*inputs)
def backward(self, *inputs: Union[torch.Tensor, str]) -> Tuple[torch.Tensor, ...]:
*inputs, active_adapter = inputs
if active_adapter and not self.load_adapter_(active_adapter):
if not self.load_adapter_(active_adapter):
raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
return super().backward(*inputs)
@ -102,7 +103,8 @@ class TransformerBackend(ModuleBackend):
) -> Tuple[torch.Tensor, ...]:
assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
if inference_info.active_adapter and not self.load_adapter_(inference_info.active_adapter):
print("--inference...")
if not self.load_adapter_(inference_info.active_adapter):
raise KeyError("Could not find adapter {inference_info.active_adapter}; perhaps it is not loaded")
with self.memory_cache.use_cache(*inference_info.cache_handles) as cache_tensors:
self._reorder_cache_inplace(cache_tensors, hypo_ids)
@ -156,14 +158,15 @@ class TransformerBackend(ModuleBackend):
p.data = dummy
def load_adapter_(self, active_adapter: str = "") -> bool:
"""Try to make a given adapter set active if it was loaded. Return True if loaded, False if no such adapter"""
adapter_is_loaded = False
"""Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
print("LOADING ADAPTER [", active_adapter, "]")
adapter_was_loaded = False
for layer in self.module.modules(): # select adapter set -- leave empty string for no adapter
if isinstance(layer, peft.tuners.lora.Linear):
layer.active_adapter = active_adapter # empty string for no adapter
if active_adapter in layer.lora_A.keys():
adapter_is_loaded = True
return adapter_is_loaded
adapter_was_loaded = True
return adapter_was_loaded or active_adapter == ""
def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend]):

@ -356,6 +356,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
active_adapter = metadata.get("active_adapter", "")
print("ACTIVE_ADAPTER: [", active_adapter, "]")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
@ -383,6 +384,7 @@ class TransformerConnectionHandler(ConnectionHandler):
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
active_adapter = metadata.get("active_adapter", "")
print("ACTIVE_ADAPTER: [", active_adapter, "]")
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)

@ -396,6 +396,7 @@ class ModuleContainer(threading.Thread):
module_uids,
dht,
ServerState.JOINING,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -469,6 +470,7 @@ class ModuleContainer(threading.Thread):
expiration_time=get_dht_time() + expiration,
state=ServerState.OFFLINE,
throughput=throughput,
adapters=adapters,
)
logger.info(f"Announced that blocks {module_uids} are offline")
raise
@ -482,6 +484,7 @@ class ModuleContainer(threading.Thread):
dht,
dht_prefix,
blocks,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -497,6 +500,7 @@ class ModuleContainer(threading.Thread):
inference_max_length: int,
num_handlers: int,
throughput: float,
adapters: Optional[Sequence[str]],
update_period: float,
expiration: Optional[float] = None,
request_timeout: float,
@ -534,6 +538,7 @@ class ModuleContainer(threading.Thread):
list(self.module_backends.keys()),
dht,
ServerState.ONLINE,
adapters=adapters,
throughput=throughput,
update_period=update_period,
expiration=expiration,
@ -633,6 +638,7 @@ class ModuleAnnouncerThread(threading.Thread):
module_uids: List[str],
dht: DHT,
state: ServerState,
adapters: Optional[Sequence[str]],
*,
throughput: float,
update_period: float = 30,
@ -643,6 +649,7 @@ class ModuleAnnouncerThread(threading.Thread):
self.module_uids = module_uids
self.dht = dht
self.state = state
self.adapters = adapters
self.throughput = throughput
self.update_period = update_period
self.expiration = expiration
@ -656,6 +663,7 @@ class ModuleAnnouncerThread(threading.Thread):
expiration_time=get_dht_time() + self.expiration,
state=self.state,
throughput=self.throughput,
adapters=self.adapters,
)
if self.stop.wait(self.update_period):
break

@ -154,6 +154,8 @@ def create_lora_adapter(block):
)
if lora_wrapped_child:
lora_wrapped_child.active_adapter = None
lora_wrapped_child.weight = child.weight
lora_wrapped_child.bias = child.bias
for p in lora_wrapped_child.parameters():
p.requires_grad = False
setattr(module, child_name, lora_wrapped_child)

@ -1,6 +1,7 @@
import pytest
import torch
import transformers
import peft
from hivemind import get_logger
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM
@ -12,12 +13,16 @@ logger = get_logger(__name__)
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
model.transformer.h.sequence_manager.extra_metadata = dict(active_adapter=ADAPTER_NAME)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.num_hidden_layers
@ -54,6 +59,8 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
ref_model = transformers.BloomForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if use_peft:
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

@ -11,3 +11,5 @@ if not MODEL_NAME:
raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
REF_NAME = os.environ.get("REF_NAME")
ADAPTER_NAME = os.environ.get("PEFT_NAME")

Loading…
Cancel
Save