@@ -41,16 +41,37 @@ class RemoteGenerationMixin:
         return self.transformer.h.inference_session(**kwargs)
 
-    def use_session(self, session: InferenceSession) -> ContextManager[InferenceSession]:
+    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
         return self.transformer.h.use_session(session)
 
-    def generate(self, *args, session: Optional[InferenceSession] = None, **kwargs):
-        if session is None:
-            context_manager = self.inference_session(max_length=2048)  # FIXME: Provide actual length
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        *args,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        session: Optional[InferenceSession] = None,
+        **kwargs
+    ):
+        if session is not None:
+            # If a session specified explicitly, use it
+            context_manager = self.use_session(session)
+        elif self.active_session is not None:
+            # If there's an active session, don't do anything
+            context_manager = contextlib.nullcontext()
         else:
-            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
+            # If there's no active session, create a new one
+
+            assert (max_length is None) != (
+                max_new_tokens is None
+            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
+            if max_length is not None:
+                session_max_length = max_length
+            else:
+                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            context_manager = self.inference_session(max_length=session_max_length)
 
-        with context_manager as session:
-            return super().generate(*args, **kwargs)
+        with context_manager:
+            return super().generate(inputs, *args, max_length=max_length, max_new_tokens=max_new_tokens, **kwargs)
 
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
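For reviewers, a minimal sketch of how the new `generate()` contract is meant to be called. The checkpoint name and the `AutoDistributedModelForCausalLM` import are illustrative assumptions (any Petals-served model with this mixin should behave the same); they are not part of this diff.

```python
from transformers import AutoTokenizer
from petals import AutoDistributedModelForCausalLM  # assumption: Petals 2.x-style client

model_name = "bigscience/bloom-560m"  # assumption: any checkpoint served on the swarm
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoDistributedModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

# One-shot call: exactly one of max_length / max_new_tokens is now required,
# so the client reserves a server-side attention cache of
# inputs.shape[1] + 3 slots instead of the old hardcoded 2048.
outputs = model.generate(inputs, max_new_tokens=3)
print(tokenizer.decode(outputs[0]))

# Multi-step call: an explicitly opened session keeps the attention cache
# alive across generate() calls; max_length caps the total sequence length.
with model.inference_session(max_length=512) as session:
    for prompt in ["A cat sat on", " and then"]:
        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
        outputs = model.generate(inputs, max_new_tokens=16, session=session)
```

The upshot of the patch: the session cache size is derived from the actual request (`max_length`, or `inputs.shape[1] + max_new_tokens`), resolving the old `FIXME`, and the assert fails fast when neither or both bounds are given.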