check num block kwargs

pull/467/head
Your Name 9 months ago
parent 17d278e88a
commit 62e780c054

@ -49,11 +49,11 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert len(block_kwargs) in (
0,
1,
end_index - start_index,
assert (
len(block_kwargs) in (0, 1, end_index - start_index)
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * (end_index - start_index)
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids

Loading…
Cancel
Save