Make default --attn_cache_tokens 2x smaller

pull/530/head
Aleksandr Borzunov 7 months ago
parent ae19b65095
commit 46aef2cf2e

@ -70,17 +70,17 @@ def main():
parser.add_argument('--inference_max_length', type=int, default=None,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens. '
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=None,
help='The total number of tokens in the same batch will not exceed this value. '
'Default: 2048 for most models, 8192 for models with multi-query attention (e.g., Llama-2-70b)')
'Default: 8192 for models with multi-query attention (based on Llama 2, Falcon), 2048 for others')
parser.add_argument('--max_chunk_size_bytes', type=int, default=256 * 1024 * 1024,
help='Maximum size of activation tensor processed in one go; larger tensors are split into chunks')
parser.add_argument('--attn_cache_tokens', type=int, default=None,
help='The number of past attention key/value pairs that will be stored between inference steps. '
'Default: 8192 for most models, 32768 for models with multi-query attention (e.g., Llama-2-70b)')
'Default: 16384 for models with multi-query attention (based on Llama 2, Falcon), 4096 for others')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')

@ -203,7 +203,7 @@ class Server:
# For attention cache in GPU or RAM
if attn_cache_tokens is None:
attn_cache_tokens = 32768 if is_multiquery_attn else 8192
attn_cache_tokens = 16384 if is_multiquery_attn else 4096
cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
cache_values_per_block //= self.block_config.num_key_value_groups
self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)

Loading…
Cancel
Save