diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml
index 3d48d37..37edb8f 100644
--- a/.github/workflows/run-tests.yaml
+++ b/.github/workflows/run-tests.yaml
@@ -86,7 +86,7 @@ jobs:
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
-            --torch_dtype float32 --compression NONE --attn_cache_size 0.2GiB &> server1.log &
+            --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 &> server1.log &
           SERVER1_PID=$!
 
           sleep 5 # wait for the first server to initialize DHT
diff --git a/src/petals/cli/run_server.py b/src/petals/cli/run_server.py
index 5e7efb5..fb521ef 100644
--- a/src/petals/cli/run_server.py
+++ b/src/petals/cli/run_server.py
@@ -7,7 +7,7 @@ from hivemind.utils.logging import get_logger
 from humanfriendly import parse_size
 
 from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.server.server import Server
+from petals.server.server import DTYPE_MAP, Server
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
@@ -78,14 +78,12 @@ def main():
     parser.add_argument('--device', type=str, default=None, required=False,
                         help='all blocks will use this device in torch notation; default: cuda if available else cpu')
-    parser.add_argument("--torch_dtype", type=str, default="auto",
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
-    parser.add_argument('--attn_cache_size', type=str, default=None,
-                        help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
-                             'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
-                             'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
-                             'The latter is the hidden size of the bigscience/bloom-petals model.')
+    parser.add_argument('--attn_cache_tokens', type=int, default=8192,
+                        help='The number of past attention key/value pairs that will be stored between inference steps. '
+                             'Default: 8192 (4 simultaneous sessions of up to 2048 tokens).')
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
@@ -178,13 +176,6 @@ def main():
     compression_type = args.pop("compression").upper()
     compression = getattr(CompressionType, compression_type)
 
-    attn_cache_size = args.pop("attn_cache_size")
-    if attn_cache_size is not None:
-        attn_cache_size = parse_size(attn_cache_size)
-    assert isinstance(
-        attn_cache_size, (int, type(None))
-    ), "Unrecognized value for --attn_cache_size. Correct examples: 1.5GB or 1500MB or 1572864000 (bytes)"
-
     max_disk_space = args.pop("max_disk_space")
     if max_disk_space is not None:
         max_disk_space = parse_size(max_disk_space)
@@ -207,7 +198,6 @@ def main():
         announce_maddrs=announce_maddrs,
         compression=compression,
         max_disk_space=max_disk_space,
-        attn_cache_size=attn_cache_size,
     )
     try:
         server.run()
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index aae181e..76dc52b 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -48,7 +48,6 @@ class TransformerBackend(ModuleBackend):
             self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
         )
 
-        assert backend_dtype is not None
         self.dtype = backend_dtype
         self.shard_num_heads = []
         for shard in self.module.module_shards:
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 2d666ae..e424fb5 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -56,7 +56,7 @@ class Server:
         revision: str = "main",
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_size: Optional[int] = None,
+        attn_cache_tokens: int = 8192,
         alloc_timeout: float = 60,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
@@ -148,9 +148,7 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        torch_dtype = DTYPE_MAP[torch_dtype]
         self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
 
         if tensor_parallel_devices is None:
@@ -165,6 +163,9 @@ class Server:
         self.load_in_8bit = load_in_8bit
         logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
+        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
@@ -179,13 +180,10 @@ class Server:
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
 
         gib = 1024**3
-        if attn_cache_size is None:
-            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
-            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
-
-        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
-        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+        self.alloc_timeout = alloc_timeout
 
         if cache_dir is None:
             cache_dir = DEFAULT_CACHE_DIR
         self.cache_dir = cache_dir
@@ -236,10 +234,9 @@ class Server:
 
         # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        attn_cache_per_block = 0.5 * gib * num_devices  # TODO: This does not account for manually set --attn_cache_size
         autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
 
-        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + attn_cache_per_block))
+        num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
         logger.info(
@@ -256,7 +253,7 @@ class Server:
             prefix=self.prefix,
             converted_model_name_or_path=self.converted_model_name_or_path,
             block_config=self.block_config,
-            attn_cache_size=self.attn_cache_size,
+            attn_cache_bytes=self.attn_cache_bytes,
             alloc_timeout=self.alloc_timeout,
             throughput=self.throughput,
             block_indices=block_indices,
@@ -356,7 +353,7 @@ class ModuleContainer(threading.Thread):
         prefix: str,
         converted_model_name_or_path: str,
         block_config: BloomConfig,
-        attn_cache_size: int,
+        attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
         block_indices: List[int],
@@ -390,7 +387,7 @@ class ModuleContainer(threading.Thread):
 
         assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
 
-        memory_cache = MemoryCache(attn_cache_size, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
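
Note (not part of the patch): a quick sanity check of the sizing math that Server.__init__ now performs. The snippet below only mirrors the arithmetic added above; the hidden size of 14336 (bigscience/bloom) and the bfloat16 dtype are illustrative assumptions.

    # Illustrative sketch of the per-block attention cache estimate.
    import torch

    attn_cache_tokens = 8192      # new default in run_server.py
    hidden_size = 14336           # assumption: bigscience/bloom hidden size
    dtype = torch.bfloat16        # assumption: typical serving dtype

    max_values_in_cache = 2 * hidden_size * attn_cache_tokens               # keys + values per block
    cache_bytes_per_block = max_values_in_cache * torch.finfo(dtype).bits // 8
    print(f"{cache_bytes_per_block / 1024**3:.2f} GiB per block")           # -> 0.44 GiB

For these assumed values, the per-block cache is about 0.44 GiB, close to the old hard-coded 0.5 GiB default, but now it scales with the model's hidden size and dtype instead of a fixed byte budget.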