pull/467/head
Your Name 9 months ago
parent 49474e5477
commit 4393d99e78

@ -4,7 +4,7 @@ import asyncio
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple, Sequence
from typing import AsyncIterator, List, Optional, Sequence, Tuple
import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor

@ -49,8 +49,11 @@ async def sequential_forward(
end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
assert len(block_kwargs) in (0, 1, end_index - start_index), \
f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert len(block_kwargs) in (
0,
1,
end_index - start_index,
), f"got {end_index - start_index} blocks but {len(block_kwargs)} sets of kwargs"
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
assert is_dummy(prompts) or len(prompts) == len(
sequence_manager.block_uids

Loading…
Cancel
Save