Fix generate kwargs for backward compat

Branch: pull/464/head
Author: Aleksandr Borzunov (10 months ago)
Parent: 41f11962b5
Commit: 5ed96a44b1

@@ -379,11 +379,11 @@ class InferenceSession:
         self.close()
 
     @property
-    def last_token_id(self) -> Optional[torch.Tensor]:  # For compatibility with Petals < 2.1.0
+    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
         return self.output_ids[:, -1:] if self.output_ids is not None else None
 
     @last_token_id.setter
-    def last_token_id(self, value: torch.Tensor):  # For compatibility with Petals < 2.1.0
+    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
         if self.output_ids is None:
             raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
         self.output_ids[:, -1:] = value
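
For readers skimming the diff: the property is just an alias for the last column of output_ids. A minimal, self-contained sketch of that behavior (the _SessionCompatDemo class and the driver code at the bottom are illustrative; only the property bodies come from the diff above):

import torch
from typing import Optional

class _SessionCompatDemo:  # hypothetical stand-in for InferenceSession
    def __init__(self):
        self.output_ids: Optional[torch.Tensor] = None

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
        return self.output_ids[:, -1:] if self.output_ids is not None else None

    @last_token_id.setter
    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
        if self.output_ids is None:
            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
        self.output_ids[:, -1:] = value

session = _SessionCompatDemo()
assert session.last_token_id is None                # nothing generated yet

session.output_ids = torch.tensor([[1, 2, 3]])
assert session.last_token_id.tolist() == [[3]]      # reads the last column of output_ids

session.last_token_id = torch.tensor([[7]])         # old-style write...
assert session.output_ids.tolist() == [[1, 2, 7]]   # ...lands in output_ids in place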

@@ -68,6 +68,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
     def generate(
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
+        self._fix_generate_kwargs(kwargs)
+
         if session is not None:
             # If a session specified explicitly, use it
             context_manager = self.use_session(session)
@@ -122,6 +124,19 @@ class RemoteGenerationMixin(_SkipTokensMixin):
 
         return result
 
+    @staticmethod
+    def _fix_generate_kwargs(kwargs: dict) -> dict:
+        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
+        if "max_length" in kwargs and kwargs["max_length"] is None:
+            del kwargs["max_length"]
+
+        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
+        do_sample = kwargs.get("do_sample")
+        if isinstance(do_sample, int):
+            kwargs["do_sample"] = bool(do_sample)
+
+        return kwargs
+
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
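
To see what the new shim does to a pre-2.1.0-style call such as model.generate(..., max_length=None, do_sample=1), here is a standalone copy that runs without Petals (the function body is verbatim from the diff; the sample kwargs are illustrative):

def _fix_generate_kwargs(kwargs: dict) -> dict:
    # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
    if "max_length" in kwargs and kwargs["max_length"] is None:
        del kwargs["max_length"]

    # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
    do_sample = kwargs.get("do_sample")
    if isinstance(do_sample, int):
        kwargs["do_sample"] = bool(do_sample)

    return kwargs

fixed = _fix_generate_kwargs({"max_length": None, "max_new_tokens": 8, "do_sample": 1})
assert "max_length" not in fixed     # None max_length is dropped, so HF never warns
assert fixed["do_sample"] is True    # int 0/1 is coerced to a real bool

Note that a genuine bool also satisfies isinstance(do_sample, int) in Python, so bool(True) is a harmless no-op for callers already passing do_sample=True.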
