Fix llama

pull/464/head
Aleksandr Borzunov 10 months ago
parent 813090e4fa
commit 56cfacdfb9

@ -49,11 +49,7 @@ class RemoteGenerationMixin:
return self.transformer.h.use_session(session)
def generate(
self,
inputs: Optional[torch.Tensor] = None,
*args,
session: Optional[InferenceSession] = None,
**kwargs
self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
):
if session is not None:
# If a session specified explicitly, use it

@ -62,10 +62,9 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
assert (
attention_mask is None or (attention_mask == 1).all()
), f"Custom attention masks are not supported, {attention_mask=}"
if position_ids is not None:
start_pos = position_ids[0].item()
expected = torch.arange(start_pos, start_pos + input_shape[1], dtype=torch.long, device=position_ids.device)
assert (position_ids == expected).all(), f"Custom position_ids are not supported, {position_ids=}"
assert (
position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
), f"Non-consecutive position_ids are not supported, {position_ids=}"
assert use_cache is None or use_cache, f"{use_cache=} is not supported"
assert not output_attentions, f"{output_attentions=} is not supported"
assert not output_hidden_states, f"{output_hidden_states=} is not supported"

Loading…
Cancel
Save