Add first version

pull/506/head
Artem Chumachenko 8 months ago
parent 1ebd88ae7b
commit 01c3cf8d15

@@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
     def __init__(
         self,
         *args,
+        block_index: int,
         config: PretrainedConfig,
         memory_cache: MemoryCache,
        backend_dtype: torch.dtype,
         max_chunk_size_bytes: int,
+        cache_dir: str,
+        max_disk_space: int,
         **kwargs,
     ):
         import petals.utils.peft as _peft_module
@@ -41,9 +44,12 @@ class TransformerBackend(ModuleBackend):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
+        self.block_index = block_index
         self.config = config
         self.memory_cache = memory_cache
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
 
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

@@ -152,6 +152,7 @@ class TransformerConnectionHandler(ConnectionHandler):
                 session_id = metadata.get("session_id")
                 alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
+                active_adapter = self._get_active_adapter(metadata)
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -169,12 +170,12 @@ class TransformerConnectionHandler(ConnectionHandler):
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles:
+                ) as cache_handles, self._load_peft_module(requested_backends, active_adapter):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
-                        active_adapter=self._get_active_adapter(metadata),
+                        active_adapter=active_adapter,
                         input_iterator=self._iterate_inference_steps(
                             request, requests, session_id, requested_uids, context
                         ),
@@ -546,6 +547,33 @@ class TransformerConnectionHandler(ConnectionHandler):
         async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
             yield nested_pack(handles, descriptors)
 
+    @contextlib.asynccontextmanager
+    async def _load_peft_module(self, backends: Sequence[TransformerBackend], active_adapter: str):
+        if active_adapter == "":
+            yield
+        elif active_adapter in self.adapters:
+            yield
+        else:
+            try:
+                _peft_module = backends[0]._peft_module
+                token = None  # TODO: Provide token from user request maybe?
+
+                for backend in backends:
+                    adapter_config, adapter_state_dict = _peft_module.load_peft(
+                        active_adapter,
+                        block_idx=backend.block_index,
+                        token=token,
+                        cache_dir=backend.cache_dir,
+                        max_disk_space=backend.max_disk_space,
+                    )
+                    _peft_module.add_adapter_to_block(
+                        backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
+                    )
+                yield
+            finally:
+                for backend in backends:
+                    _peft_module.remove_adapter_from_block(backend.module, active_adapter)
+
     def _log_request(
         self,
         method: str,

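For illustration only, not part of the commit: a sketch of what the new _load_peft_module context manager does for one request, written as a non-runnable usage fragment. It assumes `self.adapters` is the set of adapter names preloaded at server start, and that `requested_backends` and `active_adapter` are the values already extracted from the request metadata above.

# Sketch: the three paths the context manager can take around an inference call.
async with self._load_peft_module(requested_backends, active_adapter):
    # 1) active_adapter == "":            no adapter requested, plain inference
    # 2) active_adapter in self.adapters: the adapter was preloaded at startup
    # 3) otherwise: load_peft() + add_adapter_to_block() ran for every requested
    #    backend on entry, and remove_adapter_from_block() is guaranteed on exit
    ...  # iterate_rpc_inference(..., active_adapter=active_adapter) runs here
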
@@ -512,6 +512,7 @@ class ModuleContainer(threading.Thread):
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    block_index=block_index,
                     config=block_config,
                     memory_cache=memory_cache,
                     backend_dtype=torch_dtype,

@@ -58,10 +58,11 @@ def convert_block(
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
 
-    if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+    from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
-        create_lora_adapter(block, quant_type=quant_type)
+    create_lora_adapter(block, quant_type=quant_type)
+
+    if adapters:
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,
@@ -267,6 +267,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
     logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
+def remove_adapter_from_block(block, adapter_name):
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if adapter_name in child.lora_A:
+                del child.lora_A[adapter_name]
+            if adapter_name in child.lora_B:
+                del child.lora_B[adapter_name]
+
+    # TODO: check if this is needed
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def estimate_adapter_memory_per_block(
     block_config: transformers.PretrainedConfig,
     torch_dtype: Optional[torch.dtype],

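For illustration only, not part of the commit: the per-block round trip that the handler performs with these petals.utils.peft functions, shown as a standalone sketch. `block`, `block_index`, `cache_dir`, and `max_disk_space` are placeholders, and the adapter id is hypothetical; the keyword arguments mirror the handler code above.

import petals.utils.peft as peft

adapter_name = "some-user/some-lora-adapter"  # hypothetical adapter repo id

adapter_config, adapter_state_dict = peft.load_peft(
    adapter_name,
    block_idx=block_index,
    token=None,
    cache_dir=cache_dir,
    max_disk_space=max_disk_space,
)
peft.add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
try:
    ...  # run inference steps with this adapter active
finally:
    # New in this commit: detach the adapter and drop its lora_A / lora_B weights.
    peft.remove_adapter_from_block(block, adapter_name)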