Use number of tokens for attn_cache_size (#286)

* Use number of tokens for attn_cache_size

* Fix cache_bytes_per_block

* Rename attn_cache_size to attn_cache_tokens
Max Ryabinin 11 months ago committed by GitHub
parent c839173e57
commit 5c0733711a

@@ -86,7 +86,7 @@ jobs:
python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
- --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
+ --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
SERVER1_PID=$!
sleep 5 # wait for the first server to initialize DHT
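For reference, the new token-based flag lands in roughly the same budget as the old byte-based cap for this CI job. A minimal sketch of the conversion, assuming a bloom-560m-like test model behind $MODEL_NAME (hidden_size = 1024 is a guess, not part of the PR) and the float32 dtype set above:

# Sketch only: hidden_size and the model behind $MODEL_NAME are assumptions.
hidden_size = 1024         # assumed bloom-560m-like test model
dtype_bytes = 4            # float32, as set by --torch_dtype float32
attn_cache_tokens = 2048   # new CI setting
num_blocks = 12            # --block_indices 0:12

cache_bytes_per_block = 2 * hidden_size * attn_cache_tokens * dtype_bytes  # keys + values
total_bytes = cache_bytes_per_block * num_blocks
print(f"{cache_bytes_per_block / 2**20:.0f} MiB per block, {total_bytes / 2**30:.2f} GiB total")
# ~16 MiB per block, ~0.19 GiB total -- in the same ballpark as the old flat 0.2 GiB cap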

@@ -7,7 +7,7 @@ from hivemind.utils.logging import get_logger
from humanfriendly import parse_size
from petals.constants import PUBLIC_INITIAL_PEERS
- from petals.server.server import Server
+ from petals.server.server import DTYPE_MAP, Server
from petals.utils.version import validate_version
logger = get_logger(__name__)
@@ -78,14 +78,12 @@ def main():
parser.add_argument('--device', type=str, default=None, required=False,
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
- parser.add_argument('--attn_cache_size', type=str, default=None,
- help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
- 'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
- 'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
- 'The latter is the hidden size of the bigscience/bloom-petals model.')
+ parser.add_argument('--attn_cache_tokens', type=int, default=8192,
+ help='The number of past attention key/value pairs that will be stored between inference steps. '
+ 'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
parser.add_argument('--alloc_timeout', type=float, default=60,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
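The --attn_cache_tokens default above follows directly from the product in its help text: 4 simultaneous sessions of up to 2048 tokens each. A tiny sizing helper along those lines (hypothetical, not part of the PR):

def suggest_attn_cache_tokens(concurrent_sessions: int, max_seq_len: int = 2048) -> int:
    # Pick --attn_cache_tokens so that `concurrent_sessions` inference sessions
    # of up to `max_seq_len` tokens each fit into the attention cache.
    return concurrent_sessions * max_seq_len

print(suggest_attn_cache_tokens(4))  # 8192, the default in the help text above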
@@ -178,13 +176,6 @@ def main():
compression_type = args.pop("compression").upper()
compression = getattr(CompressionType, compression_type)
- attn_cache_size = args.pop("attn_cache_size")
- if attn_cache_size is not None:
- attn_cache_size = parse_size(attn_cache_size)
- assert isinstance(
- attn_cache_size, (int, type(None))
- ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
max_disk_space = args.pop("max_disk_space")
if max_disk_space is not None:
max_disk_space = parse_size(max_disk_space)
@@ -207,7 +198,6 @@ def main():
announce_maddrs=announce_maddrs,
compression=compression,
max_disk_space=max_disk_space,
- attn_cache_size=attn_cache_size,
)
try:
server.run()

@@ -48,7 +48,6 @@ class TransformerBackend(ModuleBackend):
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
)
- assert backend_dtype is not None
self.dtype = backend_dtype
self.shard_num_heads = []
for shard in self.module.module_shards:

@@ -56,7 +56,7 @@ class Server:
revision: str = "main",
cache_dir: Optional[str] = None,
max_disk_space: Optional[int] = None,
- attn_cache_size: Optional[int] = None,
+ attn_cache_tokens: int = 8192,
alloc_timeout: float = 60,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
@@ -148,9 +148,7 @@ class Server:
device = torch.device(device.type, index=0)
self.device = device
- if isinstance(torch_dtype, str):
- torch_dtype = DTYPE_MAP[torch_dtype]
- assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+ torch_dtype = DTYPE_MAP[torch_dtype]
self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
if tensor_parallel_devices is None:
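With choices=DTYPE_MAP.keys() enforced at argument-parsing time, the string is always a valid key, so a single dictionary lookup replaces the old isinstance()/assert pair. A minimal sketch of the idea; the exact contents of DTYPE_MAP are an assumption here (see petals.server.server for the real mapping):

import torch

# Assumed shape of the mapping (illustrative only):
DTYPE_MAP = dict(auto="auto", float32=torch.float32, float16=torch.float16, bfloat16=torch.bfloat16)

torch_dtype = DTYPE_MAP["bfloat16"]  # argparse guarantees the key exists
# "auto" is resolved later by resolve_block_dtype() against the pre-trained state dict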
@@ -165,6 +163,9 @@ class Server:
self.load_in_8bit = load_in_8bit
logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+ max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
+ self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
if num_blocks is None and block_indices is None:
num_blocks = self._choose_num_blocks()
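To make the per-block budget concrete, here is the same arithmetic evaluated by hand for a bigscience/bloom-petals-sized model (hidden_size = 14336, the figure cited in the old help text) stored in bfloat16; the numbers are an illustration, not output from the PR:

import torch

hidden_size = 14336        # bigscience/bloom-petals, per the old --attn_cache_size help text
attn_cache_tokens = 8192   # new default
torch_dtype = torch.bfloat16

max_values_in_cache = 2 * hidden_size * attn_cache_tokens  # keys + values for every cached token
cache_bytes_per_block = max_values_in_cache * torch.finfo(torch_dtype).bits // 8
print(f"{cache_bytes_per_block / 2**30:.2f} GiB per block")  # ~0.44 GiB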
@@ -179,13 +180,10 @@ class Server:
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
gib = 1024**3
- if attn_cache_size is None:
- # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
- attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
- self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
- logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+ self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+ logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+ self.alloc_timeout = alloc_timeout
if cache_dir is None:
cache_dir = DEFAULT_CACHE_DIR
self.cache_dir = cache_dir
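Continuing the sketch above, the total budget now simply scales with however many blocks the server ends up hosting (the block count here is a hypothetical value):

gib = 1024**3
cache_bytes_per_block = 469_762_048  # ~0.44 GiB from the previous sketch
num_blocks = 10                      # hypothetical result of _choose_num_blocks()

attn_cache_bytes = cache_bytes_per_block * num_blocks
print(f"Attention cache for all blocks will consume up to {attn_cache_bytes / gib:.2f} GiB")  # ~4.38 GiB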
@@ -236,10 +234,9 @@ class Server:
# The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
gib = 1024**3
- attn_cache_per_block = 0.5 * gib * num_devices # TODO: This does not account for manually set --attn_cache_size
autograd_memory = 2 * gib * num_devices # GPU memory used for intermediate tensors in rpc_backward
- num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
+ num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
logger.info(
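A rough walk-through of the corrected estimate; every figure below is an assumption for illustration (a 24 GiB GPU, a ~2.5 GiB per-block weight footprint, and the ~0.44 GiB per-block cache from the earlier sketch):

import math

gib = 1024**3
total_memory = 24 * gib              # hypothetical GPU
autograd_memory = 2 * gib            # as in the code above, for a single device
block_size = 2.5 * gib               # hypothetical per-block weight footprint
cache_bytes_per_block = 0.44 * gib   # 8192 tokens, bfloat16 (see earlier sketch)

num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
print(num_blocks)  # 7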
@@ -256,7 +253,7 @@ class Server:
prefix=self.prefix,
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
- attn_cache_size=self.attn_cache_size,
+ attn_cache_bytes=self.attn_cache_bytes,
alloc_timeout=self.alloc_timeout,
throughput=self.throughput,
block_indices=block_indices,
@@ -356,7 +353,7 @@ class ModuleContainer(threading.Thread):
prefix: str,
converted_model_name_or_path: str,
block_config: BloomConfig,
- attn_cache_size: int,
+ attn_cache_bytes: int,
alloc_timeout: float,
throughput: float,
block_indices: List[int],
@@ -390,7 +387,7 @@ class ModuleContainer(threading.Thread):
assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
- memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
+ memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
blocks = {}
try:
for module_uid, block_index in zip(module_uids, block_indices):
