Update generate(..., session=session) logic

Branch: pull/464/head
Author: Aleksandr Borzunov
parent 6914902c06
commit fb52114446

@@ -41,16 +41,37 @@ class RemoteGenerationMixin:
         return self.transformer.h.inference_session(**kwargs)

-    def use_session(self, session: InferenceSession) -> ContextManager[InferenceSession]:
+    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
         return self.transformer.h.use_session(session)

-    def generate(self, *args, session: Optional[InferenceSession] = None, **kwargs):
-        if session is None:
-            context_manager = self.inference_session(max_length=2048)  # FIXME: Provide actual length
-        else:
-            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
-        with context_manager as session:
-            return super().generate(*args, **kwargs)
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        *args,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
+        session: Optional[InferenceSession] = None,
+        **kwargs
+    ):
+        if session is not None:
+            # If a session is specified explicitly, use it
+            context_manager = self.use_session(session)
+        elif self.active_session is not None:
+            # If there's an active session, don't do anything
+            context_manager = contextlib.nullcontext()
+        else:
+            # If there's no active session, create a new one
+            assert (max_length is None) != (
+                max_new_tokens is None
+            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
+
+            if max_length is not None:
+                session_max_length = max_length
+            else:
+                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            context_manager = self.inference_session(max_length=session_max_length)
+
+        with context_manager:
+            return super().generate(inputs, *args, max_length=max_length, max_new_tokens=max_new_tokens, **kwargs)

     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
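
In plain terms: an explicit `session` wins, an already-active session is reused as-is, and only otherwise does `generate()` open a fresh session sized from `max_length` or `inputs.shape[1] + max_new_tokens`, replacing the old hardcoded `max_length=2048`. A minimal usage sketch of the three paths (the `model`/`tokenizer` objects and the concrete numbers are illustrative, not part of this commit; it also assumes `inference_session()` activates the session for the current thread, as the `active_session` check suggests):

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# (a) Explicit session: generate() enters it via use_session().
# (b) Active session: inside `with model.inference_session(...)`, a call
#     without `session` reuses the active one (the nullcontext branch).
with model.inference_session(max_length=32) as session:
    outputs = model.generate(inputs, max_new_tokens=4, session=session)

# (c) No session at all: a one-off session is created, sized
#     inputs.shape[1] + max_new_tokens instead of the old hardcoded 2048.
outputs = model.generate(inputs, max_new_tokens=4)

# In case (c), setting both limits now trips the new assert:
# model.generate(inputs, max_length=8, max_new_tokens=4)  # AssertionError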

@@ -64,7 +64,7 @@ class RemoteSequential(nn.Module):
         return self._thread_local.active_session

     @contextmanager
-    def use_session(self, session: InferenceSession) -> InferenceSession:
+    def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
         """Inside this context, forward() will use the specified InferenceSession."""

         try:
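
The body of `use_session` is unchanged and mostly hidden by this hunk; only the annotation is widened so that `None` is a legal argument. A reconstructed sketch of the method (the thread-local swap is an assumption inferred from the visible `self._thread_local.active_session` context, not shown by the diff):

from contextlib import contextmanager
from typing import Optional

@contextmanager
def use_session(self, session: Optional["InferenceSession"]) -> "InferenceSession":
    """Inside this context, forward() will use the specified InferenceSession."""
    try:
        prev_session = self._thread_local.active_session
        self._thread_local.active_session = session  # passing None clears it
        yield session
    finally:
        self._thread_local.active_session = prev_session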

@@ -126,11 +126,11 @@ def test_sampling(tokenizer, models, sampling_options, max_new_tokens=4):
     ), f"Sampling is not identical to HF with {inputs.shape=}, {sampling_options=}"


-def test_beam_search_generation(tokenizer, models, max_new_tokens=4, num_beams=2):
+def test_beam_search_generation(tokenizer, models, max_new_tokens=4, num_beams=6):
     model, ref_model = models

     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    outputs = model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams)
-    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams)
+    outputs = model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)

     assert torch.allclose(outputs, ref_outputs), "Beam search results are not identical to HF"
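
Raising `num_beams` to 6 makes the beam-search path harder to pass by accident, and `do_sample=False` pins both models to deterministic beam search even if a checkpoint's generation config enables sampling by default. A standalone illustration of why determinism matters for a token-for-token comparison (the checkpoint name is illustrative, not taken from this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
ref_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
# With do_sample=False, beam search is deterministic, so two runs (and two
# implementations) must agree on every generated token.
a = ref_model.generate(inputs, max_new_tokens=4, num_beams=6, do_sample=False)
b = ref_model.generate(inputs, max_new_tokens=4, num_beams=6, do_sample=False)
assert torch.equal(a, b)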
