|
|
@ -69,6 +69,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
|
|
|
|
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
|
|
|
|
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
|
|
|
|
):
|
|
|
|
):
|
|
|
|
self._fix_generate_kwargs(kwargs)
|
|
|
|
self._fix_generate_kwargs(kwargs)
|
|
|
|
|
|
|
|
if inputs is None:
|
|
|
|
|
|
|
|
inputs = kwargs.pop("input_ids", None)
|
|
|
|
|
|
|
|
|
|
|
|
if session is not None:
|
|
|
|
if session is not None:
|
|
|
|
# If a session specified explicitly, use it
|
|
|
|
# If a session specified explicitly, use it
|
|
|
@ -125,7 +127,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
|
|
|
|
return result
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _fix_generate_kwargs(kwargs: dict) -> dict:
|
|
|
|
def _fix_generate_kwargs(kwargs: dict):
|
|
|
|
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
|
|
|
|
# Suppress inappropriate "Both max_new_tokens and max_length" HF warning
|
|
|
|
if "max_length" in kwargs and kwargs["max_length"] is None:
|
|
|
|
if "max_length" in kwargs and kwargs["max_length"] is None:
|
|
|
|
del kwargs["max_length"]
|
|
|
|
del kwargs["max_length"]
|
|
|
@ -135,8 +137,6 @@ class RemoteGenerationMixin(_SkipTokensMixin):
|
|
|
|
if isinstance(do_sample, int):
|
|
|
|
if isinstance(do_sample, int):
|
|
|
|
kwargs["do_sample"] = bool(do_sample)
|
|
|
|
kwargs["do_sample"] = bool(do_sample)
|
|
|
|
|
|
|
|
|
|
|
|
return kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@staticmethod
|
|
|
|
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
|
|
|
|
def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
|
|
|
|
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
|
|
|
|
return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
|
|
|
|