mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
standardize checking block_kwargs
This commit is contained in:
parent
aacd8b2f9d
commit
056cd77f11
@ -49,11 +49,11 @@ async def sequential_forward(
|
||||
|
||||
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
|
||||
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
|
||||
assert (
|
||||
len(block_kwargs) in (0, 1, end_index - start_index)
|
||||
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
||||
if len(block_kwargs) == 1:
|
||||
block_kwargs = block_kwargs * (end_index - start_index)
|
||||
assert (
|
||||
len(block_kwargs) in (0, end_index - start_index)
|
||||
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
|
||||
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
|
||||
assert is_dummy(prompts) or len(prompts) == len(
|
||||
sequence_manager.block_uids
|
||||
|
Loading…
Reference in New Issue
Block a user