mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
Englesh erars
This commit is contained in:
parent
5ef7fea883
commit
ee1f56b492
@ -1,7 +1,7 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from hivemind import get_logger
|
||||
from hivemind.utils.logging import get_logger
|
||||
|
||||
from src.utils.generation_algorithms import (
|
||||
BeamSearchAlgorithm,
|
||||
@ -108,14 +108,14 @@ class RemoteGenerationMixin:
|
||||
if batch_size > 1:
|
||||
# TODO: resolve padding problem
|
||||
logger.warning(
|
||||
f"You set batch_size {batch_size} within beam search generation. Be carefull, results on sequences with different length may be padded wrong way"
|
||||
f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
|
||||
)
|
||||
|
||||
if num_return_sequences is None:
|
||||
num_return_sequences = 1
|
||||
|
||||
assert num_return_sequences <= num_beams, (
|
||||
f"You want more sequences that beam will have."
|
||||
f"You want more sequences than the beam has."
|
||||
" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
|
||||
)
|
||||
|
||||
|
@ -116,4 +116,4 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
|
||||
hf_outputs = BloomForCausalLM.beam_search(
|
||||
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
|
||||
)
|
||||
assert torch.allclose(remote_outputs, hf_outputs), "Beam search are not identical to HF"
|
||||
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
|
||||
|
Loading…
Reference in New Issue
Block a user