You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
2.0 KiB
Python
import os
from typing import Optional, Union

from hivemind import get_logger
from transformers.models.llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention

from petals.client.config import ClientConfig
from petals.client.lm_head import LMHeadConfig
from petals.client.ptune import PTuneConfig
from petals.models.llama.block import WrappedLlamaBlock

logger = get_logger(__name__)
|
|
|
|
|
|
class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
    """LLaMA model config extended with Petals client settings for distributed inference.

    Combines the upstream ``LlamaConfig`` with Petals-specific mixins
    (client networking, prompt tuning, LM-head options) and declares which
    block/attention classes the client uses for the remotely hosted layers.
    """

    # Wrapper class used for the transformer blocks served over the swarm
    block_class = WrappedLlamaBlock
    # Attention implementation matching the wrapped blocks
    attn_class = LlamaAttention
    # Prefix of the block weight names inside the HF checkpoint
    block_prefix = "model.layers"

    @property
    def num_key_value_groups(self) -> int:
        """Number of query heads sharing each key/value head (grouped-query attention)."""
        return self.num_attention_heads // self.num_key_value_heads

    @classmethod
    def from_pretrained(
        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
    ):
        """Load the config, deriving a DHT prefix from the repo name when none is given.

        The derived prefix uses only the final path component so that blocks of the
        same model hosted under different accounts merge into one swarm. Also forces
        ``pretraining_tp=1`` and ``use_cache=True``, which Petals requires.
        """
        logger.info(
            "Make sure you follow the LLaMA's terms of use: "
            "https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
        )

        # A non-None path that is not a local directory is treated as a Hub repo id
        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
        if loading_from_repo and dht_prefix is None:
            dht_prefix = str(model_name_or_path)
            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
            dht_prefix = dht_prefix.replace(".", "-")  # DHT prefixes follow topic-like naming (no dots)
            if not dht_prefix.endswith("-hf"):
                dht_prefix += "-hf"
            logger.info(f"Using DHT prefix: {dht_prefix}")

        # Delegates to transformers' PretrainedConfig.from_pretrained; may return
        # either the config alone or a (config, extra) tuple depending on kwargs
        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
        config = result[0] if isinstance(result, tuple) else result
        config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
        config.use_cache = True  # use_cache=False leads to identical results but is slower and not supported by Petals
        return result