pull/506/merge
Artem Chumachenko 9 months ago committed by GitHub
commit 51af621c51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
def __init__(
self,
*args,
block_index: int,
config: PretrainedConfig,
memory_cache: MemoryCache,
backend_dtype: torch.dtype,
max_chunk_size_bytes: int,
cache_dir: str,
max_disk_space: int,
**kwargs,
):
import petals.utils.peft as _peft_module
@ -41,9 +44,12 @@ class TransformerBackend(ModuleBackend):
super().__init__(*args, **kwargs)
assert isinstance(self.module, TensorParallel)
self.block_index = block_index
self.config = config
self.memory_cache = memory_cache
self.max_chunk_size_bytes = max_chunk_size_bytes
self.cache_dir = cache_dir
self.max_disk_space = max_disk_space
for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
@ -51,15 +57,15 @@ class TransformerBackend(ModuleBackend):
assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
)
self.dtype = backend_dtype

@ -15,6 +15,7 @@ from hivemind import (
MSGPackSerializer,
P2PContext,
PeerID,
TensorDescriptor,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
@ -152,6 +153,7 @@ class TransformerConnectionHandler(ConnectionHandler):
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
active_adapter = self._get_active_adapter(metadata)
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
@ -169,12 +171,14 @@ class TransformerConnectionHandler(ConnectionHandler):
async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
) as cache_handles, self._load_peft_module(
requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
):
background_tasks = set()
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
active_adapter=active_adapter,
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
@ -489,9 +493,9 @@ class TransformerConnectionHandler(ConnectionHandler):
def _get_active_adapter(self, metadata: dict) -> str:
active_adapter = metadata.get("active_adapter", "")
if active_adapter and (active_adapter not in self.adapters):
raise KeyError(f"adapter {active_adapter} not found")
return active_adapter
if active_adapter:
return active_adapter
return ""
def _serialize_grads(
self,
@ -546,6 +550,51 @@ class TransformerConnectionHandler(ConnectionHandler):
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
yield nested_pack(handles, descriptors)
@contextlib.asynccontextmanager
async def _load_peft_module(
self,
backends: Sequence[TransformerBackend],
*,
active_adapter: str,
timeout: float,
):
if active_adapter == "":
yield
elif active_adapter in self.adapters:
yield
else:
_peft_module = backends[0]._peft_module
token = None # TODO: Provide token from user request maybe?
estimated_peft_size = _peft_module.get_estimated_peft_module_size(
active_adapter,
token=token,
)
fake_descriptor = TensorDescriptor(
size=(estimated_peft_size,),
dtype=torch.int8,
device=backends[0].device,
)
async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
try:
for backend in backends:
adapter_config, adapter_state_dict = _peft_module.load_peft(
active_adapter,
block_idx=backend.block_index,
token=token,
cache_dir=backend.cache_dir,
max_disk_space=backend.max_disk_space,
)
_peft_module.add_adapter_to_block(
backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
)
finally:
for backend in backends:
_peft_module.remove_adapter_from_block(backend.module, active_adapter)
def _log_request(
self,
method: str,

@ -231,6 +231,8 @@ class Server:
gib = 1024**3
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
self.adapters_cache_bytes = self.attn_cache_bytes
logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")
assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
if throughput in ["auto", "eval", "dry_run"]:
@ -335,6 +337,7 @@ class Server:
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
adapters_cache_bytes=self.adapters_cache_bytes,
server_info=self.server_info,
model_info=self.model_info,
block_indices=block_indices,
@ -442,6 +445,7 @@ class ModuleContainer(threading.Thread):
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
adapters_cache_bytes: int,
server_info: ServerInfo,
model_info: ModelInfo,
block_indices: List[int],
@ -464,7 +468,7 @@ class ModuleContainer(threading.Thread):
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)
server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
@ -512,10 +516,13 @@ class ModuleContainer(threading.Thread):
blocks[module_uid] = TransformerBackend(
module_uid,
block,
block_index=block_index,
config=block_config,
memory_cache=memory_cache,
backend_dtype=torch_dtype,
max_chunk_size_bytes=max_chunk_size_bytes,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression

@ -58,10 +58,11 @@ def convert_block(
for shard, device in zip(block.module_shards, block.devices):
shard.to(device)
if adapters:
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
create_lora_adapter(block, quant_type=quant_type)
create_lora_adapter(block, quant_type=quant_type)
if adapters:
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(
adapter_name,

@ -111,11 +111,6 @@ async def _get_remote_module_infos(
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)
if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue
servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

@ -128,6 +128,15 @@ def load_peft(
time.sleep(delay)
def get_estimated_peft_module_size(
repo_id: str,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
):
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
return get_hf_file_metadata(weight_url, token=token).size
class AdapterContextMixin:
"""A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
@ -267,6 +276,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
def remove_adapter_from_block(block, adapter_name):
for _, module in block.named_modules():
for child_name, child in module.named_children():
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
continue
if adapter_name in child.lora_A:
del child.lora_A[adapter_name]
if adapter_name in child.lora_B:
del child.lora_B[adapter_name]
# TODO: check is this needed
if torch.cuda.is_available():
torch.cuda.empty_cache()
def estimate_adapter_memory_per_block(
block_config: transformers.PretrainedConfig,
torch_dtype: Optional[torch.dtype],

Loading…
Cancel
Save