standardize checking block_kwargs

This commit is contained in:
Your Name 2023-09-06 03:16:15 +03:00
parent aacd8b2f9d
commit 056cd77f11

View File

@ -49,11 +49,11 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert (
len(block_kwargs) in (0, 1, end_index - start_index)
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * (end_index - start_index)
assert (
len(block_kwargs) in (0, end_index - start_index)
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids