Draft repetition penalty

commit e191ce2f4e (parent 50fb8205de)
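This draft wires a CTRL-style repetition penalty through Petals' client-side generation: `InferenceSession` now keeps the full history of generated token ids, `RemoteGenerationMixin.generate()` accepts a `repetition_penalty` argument, every `DecodingAlgorithm` receives the generated history alongside the logits, and a new `RepetitionPenaltyAlgorithm` rescales the logits of already-generated tokens before sampling.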
@@ -171,7 +171,7 @@ class InferenceSession:
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
-        self.last_token_id = None
+        self.token_ids = []

     @property
     def position(self) -> int:
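The session now records the complete history of generated token ids rather than only the last one: the repetition penalty must know every previously generated token, not just the most recent, to decide which logits to rescale. Where the old `last_token_id` was used to resume a session, `token_ids[-1]` takes over that role below.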
@@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
+    RepetitionPenaltyAlgorithm,
     SamplingAlgorithm,
     TopKAlgorithm,
 )
@@ -48,6 +49,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
         num_beams: Optional[int] = 1,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
@@ -69,6 +71,7 @@ class RemoteGenerationMixin:
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of results to return.
         :param top_p: The cumulative probability of results to return.
+        :param repetition_penalty: Repetition penalty (1.0 means no penalty). See https://arxiv.org/pdf/1909.05858.pdf
         :param num_beams: The number of beams to use for beam search.
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
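For reference, the penalized sampling rule from the CTRL paper linked above: given logits $x$, temperature $T$, and the set $g$ of previously generated tokens,

$$p_i = \frac{\exp\bigl(x_i / (T \cdot I(i \in g))\bigr)}{\sum_j \exp\bigl(x_j / (T \cdot I(j \in g))\bigr)}, \qquad I(c) = \theta \ \text{if}\ c \in g \ \text{else}\ 1.$$

The paper reports $\theta \approx 1.2$ as a good balance between truthful generation and lack of repetition. Note that the implementation added below deviates slightly: when a logit is negative it multiplies by the penalty instead of dividing, so penalized tokens always become less likely.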
@@ -111,11 +114,11 @@ class RemoteGenerationMixin:

         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
-            if session is not None and session.last_token_id is not None:
-                inputs = torch.cat([session.last_token_id, inputs], dim=1)
+            if session is not None and session.token_ids:
+                inputs = torch.cat([session.token_ids[-1], inputs], dim=1)
         else:
-            if session is not None and session.last_token_id is not None:
-                inputs = session.last_token_id
+            if session is not None and session.token_ids:
+                inputs = session.token_ids[-1]
             else:
                 assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
                 inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
@@ -123,12 +126,14 @@ class RemoteGenerationMixin:

         if decoding_algorithm is None:
             if do_sample:
-                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p, repetition_penalty)
             elif num_beams is not None and num_beams > 1:
                 decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
             else:
-                if top_k is not None or top_p is not None:
-                    logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling")
+                if top_k is not None or top_p is not None or repetition_penalty is not None:
+                    logger.warning(
+                        "You passed top_k, top_p, or repetition_penalty but did not pass do_sample=True. Running greedy sampling"
+                    )
                 decoding_algorithm = GreedyAlgorithm()

         if num_beams > 1:
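A minimal usage sketch of the new argument. The model class, import path, and checkpoint name are assumptions for illustration (petals' distributed BLOOM client around the time of this commit), not part of the diff:

```python
import torch
from transformers import BloomTokenizerFast
from petals.client import DistributedBloomForCausalLM  # import path assumed for this sketch

MODEL_NAME = "bigscience/bloom-petals"  # hypothetical checkpoint name
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
# do_sample=True is required: without it the penalty falls through to greedy decoding (see the warning above)
outputs = model.generate(inputs, do_sample=True, repetition_penalty=1.2, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```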
@@ -160,6 +165,12 @@ class RemoteGenerationMixin:
         else:
             context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
         with context_manager as session:
+            if session.token_ids:
+                if inputs.shape[1] >= 2:
+                    session.token_ids.append(inputs[:, 1:])
+            else:
+                session.token_ids.append(inputs)
+
             outputs = []
             # Find samples with padded inputs.
             # They will be changed before all of the samples have right length.
@@ -183,7 +194,8 @@ class RemoteGenerationMixin:

                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
-                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+                token_ids = torch.cat(session.token_ids, dim=1) if session.token_ids else torch.empty(batch_size, 0, dtype=torch.int64)
+                last_token_id, hypo_ids = decoding_algorithm(token_ids, lm_logits)

                 # If some samples were padded, change only these samples
                 if seq_idx < inputs.size(1):
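The bookkeeping keeps `session.token_ids` as a list of `[batch, length]` chunks rather than one growing tensor, and flattens it with `torch.cat` just before each decoding step. A standalone toy illustration of that invariant (dummy tensors only, no petals dependency):

```python
import torch

token_ids = []                         # plays the role of session.token_ids
prompt = torch.tensor([[5, 8, 2]])     # [batch=1, length=3]

token_ids.append(prompt)               # first call: no history yet, record the whole prompt

for sampled in (4, 8, 8):              # each decoding step appends a fresh [batch, 1] token
    token_ids.append(torch.tensor([[sampled]]))

history = torch.cat(token_ids, dim=1)  # tensor([[5, 8, 2, 4, 8, 8]]) -- what the algorithm sees
```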
@@ -198,7 +210,7 @@ class RemoteGenerationMixin:
                     outputs[i - 1] = outputs[i - 1][hypo_ids]

                 outputs.append(last_token_id)
-                session.last_token_id = last_token_id
+                session.token_ids.append(last_token_id)
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break
@@ -342,6 +354,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
     ) -> DecodingAlgorithm:
         if (top_k is not None) and (top_p is not None):
             raise ValueError("You have to provide only top_k or top_p for sampling")
@@ -349,6 +362,8 @@ class RemoteGenerationMixin:
             return TopKAlgorithm(top_k, temperature)
         elif top_p is not None:
             return NucleusAlgorithm(top_p, temperature)
+        elif repetition_penalty is not None:
+            return RepetitionPenaltyAlgorithm(repetition_penalty, temperature)
         else:
            return SamplingAlgorithm(temperature)

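Note that in this draft the penalty is a separate branch of the `elif` chain in `_choose_sample_algorithm`: it composes with temperature sampling, but passing it together with `top_k` or `top_p` silently ignores the penalty, since the earlier branches win.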
@@ -14,7 +14,7 @@ class DecodingAlgorithm(ABC):
     """

     @abstractmethod
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
         :return: A tuple of selected token ids and corresponding hypotheses.
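With the widened interface, any custom decoding algorithm receives the generated history alongside the logits. A hypothetical subclass (not part of this commit; it assumes the `TokenIds`/`HypoIds` aliases defined in this module are importable) showing the shape of the contract:

```python
import torch
from typing import Tuple
from petals.utils.generation_algorithms import DecodingAlgorithm, HypoIds, TokenIds

class NoImmediateRepeatAlgorithm(DecodingAlgorithm):
    """Hypothetical example: never sample the most recently generated token again."""

    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
        if token_ids.numel() > 0:
            logits.scatter_(-1, token_ids[:, -1:], float("-inf"))  # ban each sample's last token
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
```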
@@ -28,7 +28,7 @@ class GreedyAlgorithm(DecodingAlgorithm):
     The simplest algorithm for decoding. It selects the most probable token.
     """

-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         Returns the most probable token. The second returned object is always a range of integers
         from 0 to batch_size - 1.
@@ -51,7 +51,7 @@ class SamplingAlgorithm(DecodingAlgorithm):
         probs = torch.softmax(logits / self.temperature, -1)
         return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
         return self.sample(logits, indices_to_remove)

@@ -61,7 +61,7 @@ class TopKAlgorithm(SamplingAlgorithm):
         super().__init__(temperature=temperature)
         self.top_k = top_k

-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
         return self.sample(logits, indices_to_remove)

@@ -71,7 +71,7 @@ class NucleusAlgorithm(SamplingAlgorithm):
         super().__init__(temperature=temperature)
         self.top_p = top_p

-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -82,6 +82,20 @@ class NucleusAlgorithm(SamplingAlgorithm):
         return self.sample(logits, indices_to_remove)


+class RepetitionPenaltyAlgorithm(SamplingAlgorithm):
+    def __init__(self, repetition_penalty: float, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
+        self.repetition_penalty = repetition_penalty
+
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        score = torch.gather(logits, -1, token_ids)
+        # If score < 0, multiply by the penalty instead of dividing, so the previous token's probability is still reduced
+        score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
+        logits.scatter_(-1, token_ids, score)
+
+        return super().__call__(token_ids, logits)
+
+
 class BeamSearchAlgorithm(DecodingAlgorithm):
     def __init__(self, num_beams: int, batch_size: int) -> None:
         self.num_beams = num_beams
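To make the gather/where/scatter sequence concrete, here is a standalone numeric walk-through of the same three lines (toy values, no petals dependency):

```python
import torch

repetition_penalty = 2.0
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])  # [batch=1, vocab=5]
token_ids = torch.tensor([[0, 3]])                    # tokens 0 and 3 were generated before

score = torch.gather(logits, -1, token_ids)           # tensor([[ 2., -1.]])
# Positive logits are divided, negative ones multiplied, so both move toward lower probability
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(-1, token_ids, score)
print(logits)  # tensor([[ 1.0000, 1.0000, 0.5000, -2.0000, 0.0000]])
```

Token 0's logit drops from 2.0 to 1.0 and token 3's from -1.0 to -2.0, so both previously generated tokens become less likely at the next sampling step.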
@@ -90,7 +104,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):

         self._batch_beams = [list() for _ in range(batch_size)]

-    def __call__(self, logits: torch.Tensor):
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
