@@ -68,6 +68,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
    def generate(
        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
    ):
        self._fix_generate_kwargs(kwargs)

        if session is not None:
            # If a session is specified explicitly, use it
            context_manager = self.use_session(session)
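
# --- A minimal usage sketch for the explicit `session` argument above.
# Assumptions: the public Petals loader AutoDistributedModelForCausalLM,
# its inference_session() helper, and a placeholder checkpoint name.
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM

model_name = "bigscience/bloom-560m-petals"  # placeholder; any Petals-served model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
# Passing the session explicitly takes the `session is not None` branch,
# so the servers' attention caches persist across generate() calls.
with model.inference_session(max_length=128) as sess:
    outputs = model.generate(inputs, session=sess, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))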
@@ -122,6 +124,19 @@ class RemoteGenerationMixin(_SkipTokensMixin):
        return result

    @staticmethod
    def _fix_generate_kwargs(kwargs: dict) -> dict:
        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
        if "max_length" in kwargs and kwargs["max_length"] is None:
            del kwargs["max_length"]

        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
        do_sample = kwargs.get("do_sample")
        if isinstance(do_sample, int):
            kwargs["do_sample"] = bool(do_sample)

        return kwargs

    @staticmethod
    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
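
# --- A small sketch of what _fix_generate_kwargs() above normalizes,
# using throwaway example kwargs: a None max_length is dropped (so HF's
# "Both max_new_tokens and max_length" warning never fires), and an
# integer do_sample from pre-2.1.0 Petals callers becomes a bool.
kwargs = {"max_length": None, "max_new_tokens": 8, "do_sample": 1}
kwargs = RemoteGenerationMixin._fix_generate_kwargs(kwargs)
assert kwargs == {"max_new_tokens": 8, "do_sample": True}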
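
# --- A sketch of the dataclasses.replace() pattern behind _reorder_cache():
# rather than physically shuffling remote attention caches during beam
# search, the beam permutation is recorded as hypo_ids on a copy of the
# cache object. Illustrated with a stand-in dataclass, since the real
# RemotePastKeyValues fields beyond hypo_ids are assumptions here.
import dataclasses
from typing import Optional

import torch

@dataclasses.dataclass
class StubPastKeyValues:  # stand-in for RemotePastKeyValues
    seen_tokens: int
    hypo_ids: Optional[torch.LongTensor] = None

past = StubPastKeyValues(seen_tokens=5)
beam_idx = torch.LongTensor([2, 0, 1])  # beam search permuted the hypotheses
reordered = dataclasses.replace(past, hypo_ids=beam_idx)
assert reordered.seen_tokens == 5 and past.hypo_ids is None  # original untouched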