|
|
|
@ -100,6 +100,7 @@ class Server(threading.Thread):
|
|
|
|
|
custom_module_path=None,
|
|
|
|
|
update_period: float = 30,
|
|
|
|
|
expiration: Optional[float] = None,
|
|
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
|
|
*,
|
|
|
|
|
start: bool,
|
|
|
|
|
**kwargs,
|
|
|
|
@ -121,12 +122,12 @@ class Server(threading.Thread):
|
|
|
|
|
|
|
|
|
|
if block_indices is not None:
|
|
|
|
|
try:
|
|
|
|
|
start, end = block_indices.split(":")
|
|
|
|
|
start, end = map(int, map(str.strip, (start, end)))
|
|
|
|
|
first_block_index, last_block_index = block_indices.split(":")
|
|
|
|
|
first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:33)")
|
|
|
|
|
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
|
|
|
|
|
raise
|
|
|
|
|
block_indices = range(start, end)
|
|
|
|
|
block_indices = range(first_block_index, last_block_index)
|
|
|
|
|
else:
|
|
|
|
|
assert num_blocks is not None
|
|
|
|
|
block_indices = range(num_blocks) # TODO replace with proper load balancing
|
|
|
|
@ -137,7 +138,13 @@ class Server(threading.Thread):
|
|
|
|
|
blocks = {}
|
|
|
|
|
for block_index in block_indices:
|
|
|
|
|
module_uid = f"{prefix}.{block_index}"
|
|
|
|
|
block = load_pretrained_block(converted_model_name_or_path, block_index, block_config, torch_dtype)
|
|
|
|
|
block = load_pretrained_block(
|
|
|
|
|
converted_model_name_or_path,
|
|
|
|
|
block_index,
|
|
|
|
|
block_config,
|
|
|
|
|
torch_dtype=torch_dtype,
|
|
|
|
|
use_auth_token=use_auth_token
|
|
|
|
|
)
|
|
|
|
|
for param in block.parameters():
|
|
|
|
|
param.requires_grad = False
|
|
|
|
|
|
|
|
|
|