|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind import get_logger
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
|
|
|
|
|
from src.utils.generation_algorithms import (
|
|
|
|
|
BeamSearchAlgorithm,
|
|
|
|
@ -108,14 +108,14 @@ class RemoteGenerationMixin:
|
|
|
|
|
if batch_size > 1:
|
|
|
|
|
# TODO: resolve padding problem
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"You set batch_size {batch_size} within beam search generation. Be carefull, results on sequences with different length may be padded wrong way"
|
|
|
|
|
f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if num_return_sequences is None:
|
|
|
|
|
num_return_sequences = 1
|
|
|
|
|
|
|
|
|
|
assert num_return_sequences <= num_beams, (
|
|
|
|
|
f"You want more sequences that beam will have."
|
|
|
|
|
f"You want more sequences than the beam has."
|
|
|
|
|
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|