mirror of
https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00
Draft repetition penalty
This commit is contained in:
parent
50fb8205de
commit
e191ce2f4e
@ -171,7 +171,7 @@ class InferenceSession:
|
||||
self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers
|
||||
self._position = 0
|
||||
self._max_length = max_length
|
||||
self.last_token_id = None
|
||||
self.token_ids = []
|
||||
|
||||
@property
|
||||
def position(self) -> int:
|
||||
|
@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
|
||||
DecodingAlgorithm,
|
||||
GreedyAlgorithm,
|
||||
NucleusAlgorithm,
|
||||
RepetitionPenaltyAlgorithm,
|
||||
SamplingAlgorithm,
|
||||
TopKAlgorithm,
|
||||
)
|
||||
@ -48,6 +49,7 @@ class RemoteGenerationMixin:
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
num_beams: Optional[int] = 1,
|
||||
bos_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
@ -69,6 +71,7 @@ class RemoteGenerationMixin:
|
||||
:param temperature: The temperature to use for sampling.
|
||||
:param top_k: The number of results to return.
|
||||
:param top_p: The cumulative probability of results to return.
|
||||
:param repetition_penalty: Repetition penalty (1.0 means no penalty). See https://arxiv.org/pdf/1909.05858.pdf
|
||||
:param num_beams: The number of beams to use for beam search.
|
||||
:param bos_token_id: The id of the beginning of sentence token.
|
||||
:param eos_token_id: The id of the end of sentence token.
|
||||
@ -111,11 +114,11 @@ class RemoteGenerationMixin:
|
||||
|
||||
if inputs is not None:
|
||||
assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
||||
if session is not None and session.last_token_id is not None:
|
||||
inputs = torch.cat([session.last_token_id, inputs], dim=1)
|
||||
if session is not None and session.token_ids:
|
||||
inputs = torch.cat([session.token_ids[-1], inputs], dim=1)
|
||||
else:
|
||||
if session is not None and session.last_token_id is not None:
|
||||
inputs = session.last_token_id
|
||||
if session is not None and session.token_ids:
|
||||
inputs = session.token_ids[-1]
|
||||
else:
|
||||
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
||||
inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
|
||||
@ -123,12 +126,14 @@ class RemoteGenerationMixin:
|
||||
|
||||
if decoding_algorithm is None:
|
||||
if do_sample:
|
||||
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
|
||||
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p, repetition_penalty)
|
||||
elif num_beams is not None and num_beams > 1:
|
||||
decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
|
||||
else:
|
||||
if top_k is not None or top_p is not None:
|
||||
logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling")
|
||||
if top_k is not None or top_p is not None or repetition_penalty is not None:
|
||||
logger.warning(
|
||||
"You passed top_k, top_p, or repetition_penalty but did pass do_sample=True. Running greedy sampling"
|
||||
)
|
||||
decoding_algorithm = GreedyAlgorithm()
|
||||
|
||||
if num_beams > 1:
|
||||
@ -160,6 +165,12 @@ class RemoteGenerationMixin:
|
||||
else:
|
||||
context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
|
||||
with context_manager as session:
|
||||
if session.token_ids:
|
||||
if inputs.shape[1] >= 2:
|
||||
session.token_ids.append(inputs[:, 1:])
|
||||
else:
|
||||
session.token_ids.append(inputs)
|
||||
|
||||
outputs = []
|
||||
# Find samples with padded inputs.
|
||||
# They will be changed before all of the samples have right length.
|
||||
@ -183,7 +194,8 @@ class RemoteGenerationMixin:
|
||||
|
||||
for constraint in constraints:
|
||||
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
|
||||
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
|
||||
token_ids = torch.cat(session.token_ids, dim=1) if session.token_ids else torch.empty(batch_size, 0, dtype=torch.int64)
|
||||
last_token_id, hypo_ids = decoding_algorithm(token_ids, lm_logits)
|
||||
|
||||
# If some samples were padded, change only these samples
|
||||
if seq_idx < inputs.size(1):
|
||||
@ -198,7 +210,7 @@ class RemoteGenerationMixin:
|
||||
outputs[i - 1] = outputs[i - 1][hypo_ids]
|
||||
|
||||
outputs.append(last_token_id)
|
||||
session.last_token_id = last_token_id
|
||||
session.token_ids.append(last_token_id)
|
||||
seq_idx += 1
|
||||
if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
|
||||
break
|
||||
@ -342,6 +354,7 @@ class RemoteGenerationMixin:
|
||||
temperature: float = 1.0,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
) -> DecodingAlgorithm:
|
||||
if (top_k is not None) and (top_p is not None):
|
||||
raise ValueError("You have to provide only top_k or top_p for sampling")
|
||||
@ -349,6 +362,8 @@ class RemoteGenerationMixin:
|
||||
return TopKAlgorithm(top_k, temperature)
|
||||
elif top_p is not None:
|
||||
return NucleusAlgorithm(top_p, temperature)
|
||||
elif repetition_penalty is not None:
|
||||
return RepetitionPenaltyAlgorithm(repetition_penalty, temperature)
|
||||
else:
|
||||
return SamplingAlgorithm(temperature)
|
||||
|
||||
|
@ -14,7 +14,7 @@ class DecodingAlgorithm(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
"""
|
||||
:param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
|
||||
:return: A tuple of selected token ids and corresponding hypotheses.
|
||||
@ -28,7 +28,7 @@ class GreedyAlgorithm(DecodingAlgorithm):
|
||||
The simplest algorithm for decoding. It selects the most probable token.
|
||||
"""
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
"""
|
||||
Returns the most probable token. The second returned object is always a range of integers
|
||||
from 0 to batch_size - 1.
|
||||
@ -51,7 +51,7 @@ class SamplingAlgorithm(DecodingAlgorithm):
|
||||
probs = torch.softmax(logits / self.temperature, -1)
|
||||
return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
@ -61,7 +61,7 @@ class TopKAlgorithm(SamplingAlgorithm):
|
||||
super().__init__(temperature=temperature)
|
||||
self.top_k = top_k
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
@ -71,7 +71,7 @@ class NucleusAlgorithm(SamplingAlgorithm):
|
||||
super().__init__(temperature=temperature)
|
||||
self.top_p = top_p
|
||||
|
||||
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
|
||||
probs = torch.softmax(sorted_logits / self.temperature, -1)
|
||||
cumulative_probs = torch.cumsum(probs, dim=-1)
|
||||
@ -82,6 +82,20 @@ class NucleusAlgorithm(SamplingAlgorithm):
|
||||
return self.sample(logits, indices_to_remove)
|
||||
|
||||
|
||||
class RepetitionPenaltyAlgorithm(SamplingAlgorithm):
|
||||
def __init__(self, repetition_penalty: float, temperature: float = 1.0) -> None:
|
||||
super().__init__(temperature=temperature)
|
||||
self.repetition_penalty = repetition_penalty
|
||||
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
||||
score = torch.gather(logits, -1, token_ids)
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
|
||||
logits.scatter_(-1, token_ids, score)
|
||||
|
||||
return super().__call__(token_ids, logits)
|
||||
|
||||
|
||||
class BeamSearchAlgorithm(DecodingAlgorithm):
|
||||
def __init__(self, num_beams: int, batch_size: int) -> None:
|
||||
self.num_beams = num_beams
|
||||
@ -90,7 +104,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
|
||||
|
||||
self._batch_beams = [list() for _ in range(batch_size)]
|
||||
|
||||
def __call__(self, logits: torch.Tensor):
|
||||
def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
probs = torch.log_softmax(sorted_logits, -1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user