|
|
@ -10,8 +10,9 @@ import transformers
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
|
|
|
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
|
|
|
|
|
|
|
|
from peft.config import PeftConfig
|
|
|
|
from peft.tuners import lora
|
|
|
|
from peft.tuners import lora
|
|
|
|
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
|
|
|
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
|
|
|
|
from safetensors import safe_open
|
|
|
|
from safetensors import safe_open
|
|
|
|
from safetensors.torch import load_file
|
|
|
|
from safetensors.torch import load_file
|
|
|
|
from transformers.utils import get_file_from_repo
|
|
|
|
from transformers.utils import get_file_from_repo
|
|
|
@ -155,15 +156,15 @@ class AdapterContextMixin:
|
|
|
|
using_adapter = AdapterContextMixin.using_adapter
|
|
|
|
using_adapter = AdapterContextMixin.using_adapter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoraLinear(lora.Linear, AdapterContextMixin):
|
|
|
|
class LoraLinear(AdapterContextMixin, lora.Linear):
|
|
|
|
"""LoRA linear layer that uses adapter selected via using_adapter"""
|
|
|
|
"""LoRA linear layer that uses adapter selected via using_adapter"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoraLinear8bitLt(lora.Linear8bitLt, AdapterContextMixin):
|
|
|
|
class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
|
|
|
|
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
|
|
|
|
"""LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoraLinear4bit(lora.Linear4bit, AdapterContextMixin):
|
|
|
|
class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
|
|
|
|
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
|
|
|
|
"""LoRA linear 4-bit that uses adapter selected via using_adapter"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|