|
|
|
@ -49,11 +49,11 @@ async def sequential_forward(
|
|
|
|
|
|
|
|
|
|
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
|
|
|
|
|
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
|
|
|
|
|
assert len(block_kwargs) in (
|
|
|
|
|
0,
|
|
|
|
|
1,
|
|
|
|
|
end_index - start_index,
|
|
|
|
|
assert (
|
|
|
|
|
len(block_kwargs) in (0, 1, end_index - start_index)
|
|
|
|
|
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
|
|
|
|
if len(block_kwargs) == 1:
|
|
|
|
|
block_kwargs = block_kwargs * (end_index - start_index)
|
|
|
|
|
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
|
|
|
|
|
assert is_dummy(prompts) or len(prompts) == len(
|
|
|
|
|
sequence_manager.block_uids
|
|
|
|
|