Support Llama 2 (#379)

Alexander Borzunov authored 10 months ago · committed by GitHub
commit 057a2fb5de (parent 3218534745)

@ -11,7 +11,7 @@ from petals.models import *
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs
__version__ = "1.2.0.dev3"
__version__ = "1.2.0.dev4"
if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):

@ -25,6 +25,8 @@ def main():
help="path or name of a pretrained model, converted with cli/convert_model.py")
group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
parser.add_argument("--public_name", type=str, default=None, help="Public name to be reported in the leaderboard")
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
group.add_argument("--use_auth_token", action="store_true", dest="token",
@ -59,16 +61,22 @@ def main():
parser.add_argument('--num_handlers', type=int, default=8, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=2048,
help='The total number of tokens in the same batch will not exceed this value')
parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
help='Pre-form this many subsequent batches while GPU is processing the current one')
parser.add_argument('--sender_threads', type=int, default=1, required=False,
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=2048,
help='Maximum total sequence length permitted per inference')
parser.add_argument('--inference_max_length', type=int, default=None,
help='Maximum total sequence length permitted per inference. '
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=None,
help='The total number of tokens in the same batch will not exceed this value. '
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--attn_cache_tokens', type=int, default=None,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@ -86,9 +94,6 @@ def main():
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
parser.add_argument('--attn_cache_tokens', type=int, default=8192,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
parser.add_argument('--alloc_timeout', type=float, default=5,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')

@ -27,12 +27,14 @@ class ServerInfo:
state: ServerState
throughput: RPS
public_name: Optional[str] = None
version: Optional[str] = None
network_rps: Optional[RPS] = None
forward_rps: Optional[RPS] = None
inference_rps: Optional[RPS] = None
adapters: Sequence[str] = ()
version: Optional[str] = None
torch_dtype: Optional[str] = None
quant_type: Optional[str] = None
using_relay: Optional[bool] = None

@ -18,6 +18,8 @@ class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LM
attn_class = BloomAttention
block_prefix = "h"
num_key_value_groups = 1
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs

@ -73,7 +73,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
) -> Tuple[torch.Tensor]:
key_states, value_states = key_value
key_states = key_states.permute(0, 2, 1)
key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
key_states = key_states.view(
batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
value_states = value_states.view(*key_states.shape)
return (key_states, value_states)
@ -81,7 +83,9 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
) -> Tuple[torch.Tensor]:
key_states, value_states = key_value
value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
value_states = value_states.view(
batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
)
key_states = key_states.view(*value_states.shape)
key_states = key_states.permute(0, 2, 1)
return (key_states, value_states)
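For illustration, a minimal round trip between the two cache layouts, assuming hypothetical Llama-2-70B attention geometry (8 key/value heads, head_dim 128); the tensor names and sizes below are illustrative, not taken from the code above:

import torch

# Hypothetical grouped-query shapes: 8 key/value heads, head_dim 128 (Llama-2-70B-like)
batch_size, num_key_value_heads, seq_length, head_dim = 1, 8, 4, 128

# Llama-side layout: [batch, num_kv_heads, seq_length, head_dim]
key_llama = torch.randn(batch_size, num_key_value_heads, seq_length, head_dim)

# Flat cache layout implied by the reordering above: keys as [batch * num_kv_heads, head_dim, seq_length]
key_cache = key_llama.view(batch_size * num_key_value_heads, seq_length, head_dim).permute(0, 2, 1)

# Reordering back for the Llama block, mirroring the permute + view in the diff above
key_back = key_cache.permute(0, 2, 1).view(batch_size, num_key_value_heads, seq_length, head_dim)
assert torch.equal(key_llama, key_back)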

@ -18,13 +18,17 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
attn_class = LlamaAttention
block_prefix = "model.layers"
@property
def num_key_value_groups(self):
return self.num_attention_heads // self.num_key_value_heads
@classmethod
def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
):
logger.info(
"LLaMA is available solely for non-commercial research purposes. "
"Make sure you follow the terms of use: https://bit.ly/llama-license"
"Make sure you follow the LLaMA's terms of use: "
"https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1"
)
loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
@ -34,4 +38,8 @@ class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LM
if not dht_prefix.endswith("-hf"):
dht_prefix += "-hf"
logger.info(f"Using DHT prefix: {dht_prefix}")
return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
config = result[0] if isinstance(result, tuple) else result
config.pretraining_tp = 1 # This may give less accurate results but it doesn't matter if we use quantization
return result
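As a point of reference, a quick sanity check of this property with Llama-2-70B's attention geometry (64 query heads, 8 key/value heads; these numbers are assumptions stated for illustration, not read from the diff):

num_attention_heads = 64   # Llama-2-70B (assumed for illustration)
num_key_value_heads = 8    # grouped-query attention: 8 shared key/value heads
num_key_value_groups = num_attention_heads // num_key_value_heads  # -> 8 query heads per KV head
# With 8 query heads sharing each cached key/value head, the attention cache is roughly 8x
# smaller than without grouped-query attention, which is what lets the server raise the
# default batch and cache limits for such models.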

@ -81,6 +81,7 @@ class TransformerBackend(ModuleBackend):
head_dim = self.config.hidden_size // self.config.num_attention_heads
cache_tensors = []
for device, num_heads in zip(self.module.devices, self.shard_num_heads):
num_heads //= self.config.num_key_value_groups
keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
cache_tensors.extend((keys, values))
@ -123,8 +124,10 @@ class TransformerBackend(ModuleBackend):
"""Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
key_cache, value_cache = list(cache_tensors[0::2]), list(cache_tensors[1::2])
for i in range(len(key_cache)):
key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length] # [batch * num_heads, head_dim, kv_length]
value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length] # [batch * num_heads, kv_length, head_dim]
key_cache[i] = key_cache[i].flatten(0, 1)[:, :, :prefix_length]
# shape: [batch * num_kv_heads, head_dim, kv_length]
value_cache[i] = value_cache[i].flatten(0, 1)[:, :prefix_length]
# shape: [batch * num_kv_heads, kv_length, head_dim]
layer_past = tuple(chain(*zip(key_cache, value_cache)))
return PerDeviceTensors(*layer_past) if len(self.module.module_shards) > 1 else layer_past
@ -132,7 +135,7 @@ class TransformerBackend(ModuleBackend):
self, cache_tensors: Sequence[torch.Tensor], new_kvs: Sequence[torch.Tensor], prefix_length: int
):
"""Writes new key/value tensors back into cache, works in-place"""
_batch_size_times_num_heads, head_dim, new_length = new_kvs[0].shape
_batch_size_times_num_kv_heads, head_dim, new_length = new_kvs[0].shape
for cache_key, new_key in zip(cache_tensors[0::2], new_kvs[0::2]):
new_key = new_key.view(*cache_key.shape[:3], new_length)
cache_key[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
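To make the allocation above concrete, here is the shape arithmetic for a hypothetical single-GPU shard of Llama-2-70B (hidden_size 8192, 64 attention heads, num_key_value_groups 8) with batch_size 1 and max_length 8192; the values are illustrative only:

hidden_size, num_attention_heads, num_key_value_groups = 8192, 64, 8
batch_size, max_length = 1, 8192

head_dim = hidden_size // num_attention_heads   # 128
num_heads = num_attention_heads                 # all heads live on this single shard
num_heads //= num_key_value_groups              # only 8 key/value heads are actually cached
keys_shape = (batch_size, num_heads, head_dim, max_length)    # (1, 8, 128, 8192)
values_shape = (batch_size, num_heads, max_length, head_dim)  # (1, 8, 8192, 128)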

@ -23,6 +23,7 @@ from petals.constants import DTYPE_MAP
from petals.server.block_utils import resolve_block_dtype
from petals.utils.auto_config import AutoDistributedConfig
from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.hf_auth import always_needs_auth
logger = get_logger(__name__)
@ -86,6 +87,9 @@ def _load_state_dict_from_repo(
cache_dir: str,
max_disk_space: Optional[int] = None,
) -> StateDict:
if always_needs_auth(model_name) and token is None:
token = True
index_file = get_file_from_repo(
model_name, filename="pytorch_model.bin.index.json", use_auth_token=token, cache_dir=cache_dir
)
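Passing token=True asks huggingface_hub to fall back to the token saved by huggingface-cli login, which the gated Llama 2 repositories require. A minimal sketch of the same call, with a hypothetical gated repo name:

from transformers.utils import get_file_from_repo

# use_auth_token=True -> use the token stored by huggingface-cli login (needed for gated repos)
index_file = get_file_from_repo(
    "meta-llama/Llama-2-70b-hf",   # example gated repo; replace with the model being served
    filename="pytorch_model.bin.index.json",
    use_auth_token=True,
)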

@ -145,8 +145,7 @@ class ReachabilityProtocol(ServicerBase):
async with protocol.serve(common_p2p):
await protocol._stop.wait()
except Exception as e:
logger.warning(f"Reachability service failed: {repr(e)}")
logger.debug("See detailed traceback below:", exc_info=True)
logger.debug("Reachability service failed:", exc_info=True)
if not ready.done():
ready.set_exception(e)

@ -50,18 +50,19 @@ class Server:
initial_peers: List[str],
dht_prefix: Optional[str],
converted_model_name_or_path: str,
public_name: Optional[str] = None,
throughput: Union[float, str],
num_blocks: Optional[int] = None,
block_indices: Optional[str] = None,
num_handlers: int = 8,
inference_max_length: Optional[int] = None,
min_batch_size: int = 1,
max_batch_size: int = 2048,
inference_max_length: int = 2048,
max_batch_size: Optional[int] = None,
attn_cache_tokens: Optional[int] = None,
torch_dtype: str = "auto",
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
attn_cache_tokens: int = 8192,
alloc_timeout: float = 5,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
@ -93,8 +94,6 @@ class Server:
self.converted_model_name_or_path = converted_model_name_or_path
self.num_handlers = num_handlers
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
self.compression = compression
self.stats_report_interval, self.update_period = stats_report_interval, update_period
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
@ -177,8 +176,19 @@ class Server:
self.quant_type = quant_type
logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
is_multiquery_attn = self.block_config.num_key_value_groups > 1
if max_batch_size is None:
max_batch_size = 8192 if is_multiquery_attn else 2048
if inference_max_length is None:
inference_max_length = 8192 if is_multiquery_attn else 2048
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
self.inference_max_length = inference_max_length
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
attn_cache_tokens = 32768 if is_multiquery_attn else 8192
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
cache_values_per_block //= self.block_config.num_key_value_groups
self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
# For disk cache
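A back-of-the-envelope check of the resulting per-block cache size with the new multi-query defaults, assuming Llama-2-70B-like numbers (hidden_size 8192, num_key_value_groups 8) and bfloat16; the figures are illustrative, not taken from the diff:

import torch

hidden_size, num_key_value_groups = 8192, 8
attn_cache_tokens = 32768                      # new default for multi-query models
torch_dtype = torch.bfloat16

cache_values_per_block = 2 * hidden_size * attn_cache_tokens   # keys + values
cache_values_per_block //= num_key_value_groups                # KV heads are shared across query groups
cache_bytes_per_block = cache_values_per_block * torch.finfo(torch_dtype).bits // 8
print(cache_bytes_per_block / 2**20)           # ~128 MiB of attention cache per transformer block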
@ -222,8 +232,9 @@ class Server:
throughput_info = {"throughput": throughput}
self.server_info = ServerInfo(
state=ServerState.JOINING,
adapters=tuple(adapters),
public_name=public_name,
version=petals.__version__,
adapters=tuple(adapters),
torch_dtype=str(torch_dtype).replace("torch.", ""),
quant_type=quant_type.name.lower(),
using_relay=self.dht.client_mode,
@ -642,7 +653,10 @@ class ModuleAnnouncerThread(threading.Thread):
self.dht = dht
self.server_info = server_info
self.memory_cache = memory_cache
self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
self.bytes_per_token //= block_config.num_key_value_groups
self.update_period = update_period
self.expiration = expiration
self.trigger = threading.Event()
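With the same assumed Llama-2-70B geometry and bfloat16 weights, the announced cache footprint works out to about 2 KiB per token per block:

import torch

hidden_size, num_key_value_groups, dtype = 8192, 8, torch.bfloat16  # assumed for illustration
bytes_per_token = hidden_size * torch.finfo(dtype).bits // 8        # 16384 bytes
bytes_per_token //= num_key_value_groups                            # -> 2048 bytes per token per block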

@ -1,8 +1,12 @@
import os
import re
from dataclasses import dataclass
from typing import Optional, Type
from typing import Optional, Type, Union
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
from petals.utils.hf_auth import always_needs_auth
@dataclass
class _ModelClasses:
@ -26,8 +30,11 @@ class _AutoDistributedBase:
_mapping_field = None # Should be defined in child classes
@classmethod
def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
config = AutoConfig.from_pretrained(*args, **kwargs)
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike, None], *args, **kwargs) -> PretrainedConfig:
if always_needs_auth(model_name_or_path) and "token" not in kwargs and "use_auth_token" not in kwargs:
kwargs["token"] = True
config = AutoConfig.from_pretrained(model_name_or_path, *args, **kwargs)
if config.model_type not in _CLASS_MAPPING:
raise ValueError(f"Petals does not support model type {config.model_type}")
@ -35,7 +42,7 @@ class _AutoDistributedBase:
if proper_cls is None:
raise ValueError(f"Petals does not have {cls.__name__} for model type {config.model_type}")
return proper_cls.from_pretrained(*args, **kwargs)
return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
class AutoDistributedConfig(_AutoDistributedBase):

@ -2,6 +2,7 @@
Tools for converting transformer blocks, applying quantization and/or tensor parallelism
"""
import re
from enum import Enum
from typing import Optional, Sequence
import tensor_parallel as tp
@ -11,12 +12,16 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from tensor_parallel.slicing_configs import get_bloom_config
from transformers import PretrainedConfig
from petals.utils.misc import QuantType
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
class QuantType(Enum):
NONE = 0
INT8 = 1 # 8-bit as in the LLM.int8() paper
NF4 = 2 # 4-bit as in the QLoRA paper
def convert_block(
block: nn.Module,
block_index: int,

@ -0,0 +1,7 @@
import os
from typing import Union
def always_needs_auth(model_name: Union[str, os.PathLike, None]) -> bool:
loading_from_repo = model_name is not None and not os.path.isdir(model_name)
return loading_from_repo and model_name.startswith("meta-llama/Llama-2-")
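A quick usage sketch of the new helper (assuming the package from this PR is installed; the repo IDs are ordinary Hugging Face model names used only as examples):

from petals.utils.hf_auth import always_needs_auth

assert always_needs_auth("meta-llama/Llama-2-70b-chat-hf")  # official Llama 2 repos are gated
assert not always_needs_auth("bigscience/bloom-560m")       # public repos need no forced token
assert not always_needs_auth(None)                          # no repo name, nothing to gate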

@ -1,14 +1,5 @@
from enum import Enum
import torch
class QuantType(Enum):
NONE = 0
INT8 = 1 # 8-bit as in the LLM.int8() paper
NF4 = 2 # 4-bit as in the QLoRA paper
DUMMY = torch.empty(0) # dummy tensor that replaces empty prompt or adapter parameters

@ -17,8 +17,8 @@ from safetensors.torch import load_file
from transformers.utils import get_file_from_repo
from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType
from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
from petals.utils.misc import QuantType
logger = get_logger(__name__)
