Add skeleton for peft init

declare_adapters
artek0chumak 11 months ago
parent da204f1285
commit e452df25cc

@@ -46,7 +46,7 @@ install_requires =
    cpufeature>=0.2.0
    packaging>=20.9
    sentencepiece>=0.1.99
    peft @ git+https://github.com/huggingface/peft
    peft @ git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
    safetensors>=0.3.1

[options.extras_require]

@@ -10,6 +10,7 @@ import tensor_parallel as tp
import torch
import torch.nn as nn
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from petals.utils.peft import create_lora_adapter, add_adapter_to_block, load_peft
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig
@@ -30,6 +31,7 @@ def convert_block(
    output_device: torch.device,
    quant_type: QuantType,
    freeze: bool = True,
    adapters: Optional[List[str]] = None,
) -> tp.TensorParallel:
    """
    Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -55,6 +57,12 @@ def convert_block(
    for shard, device in zip(block.module_shards, block.devices):
        shard.to(device)

    if adapters:
        create_lora_adapter(block)
        for adapter in adapters:
            adapter_config, adapter_state_dict = load_peft(adapter)
            add_adapter_to_block(block, adapter_config, adapter_state_dict)

    return block
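
For context, here is a minimal sketch (not part of this commit) of how a server might call the extended convert_block with adapters enabled. Only the quant_type, freeze, and adapters parameters are confirmed by the hunk above; the leading positional arguments, the import path, the QuantType value, and the adapter repo name are assumptions for illustration.

import torch

from petals.utils.convert_block import QuantType, convert_block  # assumed module path

# `block` is one transformer layer of the served model and `config` its transformers
# PretrainedConfig; both are prepared by the caller and omitted here for brevity.
converted = convert_block(
    block,
    config,
    tensor_parallel_devices=[torch.device("cuda:0")],
    output_device=torch.device("cuda:0"),
    quant_type=QuantType.NONE,                   # no 8-bit / 4-bit quantization in this sketch
    freeze=True,                                 # base weights stay frozen on the server
    adapters=["artek0chumak/bloom-560m-lora"],   # hypothetical LoRA repo(s) on the HF Hub
)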

@@ -108,3 +108,52 @@ def load_peft(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)


def create_lora_adapter(block):
    # Wrap every supported linear layer in the block with an (empty) LoRA shell;
    # adapter weights are attached later via add_adapter_to_block.
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if isinstance(child, nn.Linear):
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = peft.tuners.lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            elif isinstance(child, bnb.nn.Linear8bitLt):
                kwargs = {
                    "has_fp16_weights": child.state.has_fp16_weights,
                    "memory_efficient_backward": child.state.memory_efficient_backward,
                    "threshold": child.state.threshold,
                    "index": child.index,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif isinstance(child, bnb.nn.Linear4bit):
                kwargs = {
                    "compute_dtype": child.compute_dtype,
                    "compress_statistics": child.weight.compress_statistics,
                    "quant_type": child.weight.quant_type,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = peft.tuners.lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            if lora_wrapped_child:
                lora_wrapped_child.active_adapter = None
                setattr(module, child_name, lora_wrapped_child)


def add_adapter_to_block(block, peft_config, peft_state_dict):
    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
    pass  # skeleton: loading the adapter weights into the LoRA layers is not implemented yet
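
To make create_lora_adapter's behavior concrete, here is a small sanity-check sketch (not part of this commit). It runs the function on a toy nn.Sequential instead of a real Petals transformer block and lists the layers that were wrapped; the toy module, the printout, and the petals.utils.peft import path are illustrative assumptions.

import torch.nn as nn
import peft  # the pinned HF peft from setup.cfg

from petals.utils.peft import create_lora_adapter  # assumed module path, matching the import above

# Toy stand-in for a transformer block; real blocks may also contain bitsandbytes
# Linear8bitLt / Linear4bit layers, which take the other isinstance branches above.
toy_block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
create_lora_adapter(toy_block)

for name, module in toy_block.named_modules():
    if isinstance(module, peft.tuners.lora.Linear):
        # Each wrapped layer starts with no active adapter; add_adapter_to_block is where
        # weights loaded by load_peft would later be attached.
        print(name, type(module).__name__, "active_adapter =", module.active_adapter)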
