|
|
|
@ -108,3 +108,52 @@ def load_peft(
|
|
|
|
|
f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
|
|
|
|
|
)
|
|
|
|
|
time.sleep(delay)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_lora_adapter(block):
    """Replace every linear layer inside ``block`` with an adapter-less LoRA wrapper.

    Walks all submodules of ``block`` and, for each direct child that is a
    ``bnb.nn.Linear8bitLt``, ``bnb.nn.Linear4bit`` or plain ``nn.Linear``,
    builds the matching ``peft.tuners.lora`` layer with the same
    ``in_features`` / ``out_features`` / bias setting and installs it under
    the same attribute name. The wrapper's ``active_adapter`` is set to
    ``None``. The block is modified in place.

    Args:
        block: a module exposing ``named_modules()`` whose submodules expose
            ``named_children()``.

    Returns:
        None.
    """
    for _, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            # Bias presence is read from the layer being replaced.
            has_bias = hasattr(child, "bias") and child.bias is not None

            # The quantized bitsandbytes layers derive from nn.Linear, so they
            # must be tested before the generic nn.Linear case; otherwise they
            # would be wrapped as ordinary (non-quantized) LoRA layers.
            if isinstance(child, bnb.nn.Linear8bitLt):
                kwargs = {
                    "has_fp16_weights": child.state.has_fp16_weights,
                    "memory_efficient_backward": child.state.memory_efficient_backward,
                    "threshold": child.state.threshold,
                    "index": child.index,
                    "bias": has_bias,
                }
                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif isinstance(child, bnb.nn.Linear4bit):
                kwargs = {
                    "compute_dtype": child.compute_dtype,
                    "compress_statistics": child.weight.compress_statistics,
                    "quant_type": child.weight.quant_type,
                    "bias": has_bias,
                }
                lora_wrapped_child = peft.tuners.lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif isinstance(child, nn.Linear):
                lora_wrapped_child = peft.tuners.lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=has_bias,
                )

            if lora_wrapped_child is not None:
                # NOTE(review): the wrapper is created with its own freshly
                # constructed parameters; nothing here copies child.weight /
                # child.bias into it. Confirm the caller restores the weights.
                lora_wrapped_child.active_adapter = None
                setattr(module, child_name, lora_wrapped_child)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_adapter_to_block(block, peft_config, peft_state_dict):
    """Check that the adapter described by ``peft_config`` is a LoRA adapter.

    Only the adapter type is validated here; ``block`` and
    ``peft_state_dict`` are accepted but not used by the current body.

    Raises:
        AssertionError: if ``peft_config.peft_type`` is not
            ``peft.PeftType.LORA``.
    """
    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
|
|
|
|
|