diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index b7a068b..9b26b55 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -171,7 +171,7 @@ class InferenceSession: self._server_inputs = [] # Used in case of server failures to regenerate attention caches on new servers self._position = 0 self._max_length = max_length - self.last_token_id = None + self.token_ids = [] @property def position(self) -> int: diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py index 053e209..5e94673 100644 --- a/src/petals/client/remote_generation.py +++ b/src/petals/client/remote_generation.py @@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import ( DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, + RepetitionPenaltyAlgorithm, SamplingAlgorithm, TopKAlgorithm, ) @@ -48,6 +49,7 @@ class RemoteGenerationMixin: temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, num_beams: Optional[int] = 1, bos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None, @@ -69,6 +71,7 @@ class RemoteGenerationMixin: :param temperature: The temperature to use for sampling. :param top_k: The number of results to return. :param top_p: The cumulative probability of results to return. + :param repetition_penalty: Repetition penalty (1.0 means no penalty). See https://arxiv.org/pdf/1909.05858.pdf :param num_beams: The number of beams to use for beam search. :param bos_token_id: The id of the beginning of sentence token. :param eos_token_id: The id of the end of sentence token. @@ -111,11 +114,11 @@ class RemoteGenerationMixin: if inputs is not None: assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]" - if session is not None and session.last_token_id is not None: - inputs = torch.cat([session.last_token_id, inputs], dim=1) + if session is not None and session.token_ids: + inputs = torch.cat([session.token_ids[-1], inputs], dim=1) else: - if session is not None and session.last_token_id is not None: - inputs = session.last_token_id + if session is not None and session.token_ids: + inputs = session.token_ids[-1] else: assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs" inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device) @@ -123,12 +126,14 @@ class RemoteGenerationMixin: if decoding_algorithm is None: if do_sample: - decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p) + decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p, repetition_penalty) elif num_beams is not None and num_beams > 1: decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size) else: - if top_k is not None or top_p is not None: - logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling") + if top_k is not None or top_p is not None or repetition_penalty is not None: + logger.warning( + "You passed top_k, top_p, or repetition_penalty but did pass do_sample=True. Running greedy sampling" + ) decoding_algorithm = GreedyAlgorithm() if num_beams > 1: @@ -160,6 +165,12 @@ class RemoteGenerationMixin: else: context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it with context_manager as session: + if session.token_ids: + if inputs.shape[1] >= 2: + session.token_ids.append(inputs[:, 1:]) + else: + session.token_ids.append(inputs) + outputs = [] # Find samples with padded inputs. # They will be changed before all of the samples have right length. @@ -183,7 +194,8 @@ class RemoteGenerationMixin: for constraint in constraints: lm_logits = constraint(last_token_id, lm_logits, hypo_ids) - last_token_id, hypo_ids = decoding_algorithm(lm_logits) + token_ids = torch.cat(session.token_ids, dim=1) if session.token_ids else torch.empty(batch_size, 0, dtype=torch.int64) + last_token_id, hypo_ids = decoding_algorithm(token_ids, lm_logits) # If some samples were padded, change only these samples if seq_idx < inputs.size(1): @@ -198,7 +210,7 @@ class RemoteGenerationMixin: outputs[i - 1] = outputs[i - 1][hypo_ids] outputs.append(last_token_id) - session.last_token_id = last_token_id + session.token_ids.append(last_token_id) seq_idx += 1 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens: break @@ -342,6 +354,7 @@ class RemoteGenerationMixin: temperature: float = 1.0, top_k: Optional[int] = None, top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, ) -> DecodingAlgorithm: if (top_k is not None) and (top_p is not None): raise ValueError("You have to provide only top_k or top_p for sampling") @@ -349,6 +362,8 @@ class RemoteGenerationMixin: return TopKAlgorithm(top_k, temperature) elif top_p is not None: return NucleusAlgorithm(top_p, temperature) + elif repetition_penalty is not None: + return RepetitionPenaltyAlgorithm(repetition_penalty, temperature) else: return SamplingAlgorithm(temperature) diff --git a/src/petals/utils/generation_algorithms.py b/src/petals/utils/generation_algorithms.py index 9033371..d01e1da 100644 --- a/src/petals/utils/generation_algorithms.py +++ b/src/petals/utils/generation_algorithms.py @@ -14,7 +14,7 @@ class DecodingAlgorithm(ABC): """ @abstractmethod - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: """ :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size) :return: A tuple of selected token ids and corresponding hypotheses. @@ -28,7 +28,7 @@ class GreedyAlgorithm(DecodingAlgorithm): The simplest algorithm for decoding. It selects the most probable token. """ - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: """ Returns the most probable token. The second returned object is always a range of integers from 0 to batch_size - 1. @@ -51,7 +51,7 @@ class SamplingAlgorithm(DecodingAlgorithm): probs = torch.softmax(logits / self.temperature, -1) return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0)) - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: indices_to_remove = torch.full_like(logits, False, dtype=torch.bool) return self.sample(logits, indices_to_remove) @@ -61,7 +61,7 @@ class TopKAlgorithm(SamplingAlgorithm): super().__init__(temperature=temperature) self.top_k = top_k - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None] return self.sample(logits, indices_to_remove) @@ -71,7 +71,7 @@ class NucleusAlgorithm(SamplingAlgorithm): super().__init__(temperature=temperature) self.top_p = top_p - def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1) probs = torch.softmax(sorted_logits / self.temperature, -1) cumulative_probs = torch.cumsum(probs, dim=-1) @@ -82,6 +82,20 @@ class NucleusAlgorithm(SamplingAlgorithm): return self.sample(logits, indices_to_remove) +class RepetitionPenaltyAlgorithm(SamplingAlgorithm): + def __init__(self, repetition_penalty: float, temperature: float = 1.0) -> None: + super().__init__(temperature=temperature) + self.repetition_penalty = repetition_penalty + + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]: + score = torch.gather(logits, -1, token_ids) + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty) + logits.scatter_(-1, token_ids, score) + + return super().__call__(token_ids, logits) + + class BeamSearchAlgorithm(DecodingAlgorithm): def __init__(self, num_beams: int, batch_size: int) -> None: self.num_beams = num_beams @@ -90,7 +104,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm): self._batch_beams = [list() for _ in range(batch_size)] - def __call__(self, logits: torch.Tensor): + def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor): sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) probs = torch.log_softmax(sorted_logits, -1)