Remove inputs_embeds support

pull/485/head
Aleksandr Borzunov 9 months ago
parent 7fe2635ee6
commit 06accbcf40

@@ -71,9 +71,6 @@ class RemoteGenerationMixin(_SkipTokensMixin):
self._fix_generate_kwargs(kwargs)
if inputs is None:
inputs = kwargs.pop("input_ids", None)
inputs_len = inputs.shape[1] if inputs is not None else 0
if "inputs_embeds" in kwargs:
inputs_len = kwargs["inputs_embeds"].shape[1]
if session is not None:
# If a session specified explicitly, use it
@@ -93,7 +90,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
if max_length is not None:
session_max_length = max_length
else:
session_max_length = inputs_len + max_new_tokens
session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
context_manager = self.inference_session(max_length=session_max_length)
with context_manager as session:

@@ -152,7 +152,7 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
@pytest.mark.forked
def test_input_ids_and_embeds(tokenizer, model, ref_model, max_new_tokens=4):
def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
assert inputs.keys() == {"input_ids", "attention_mask"}
@@ -168,17 +168,3 @@ def test_input_ids_and_embeds(tokenizer, model, ref_model, max_new_tokens=4):
]
)
assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"
inputs_embeds = model.transformer.word_embeddings(inputs["input_ids"])
outputs = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens)
ref_outputs = ref_model.generate(inputs_embeds=inputs_embeds, max_new_tokens=max_new_tokens)
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
outputs = torch.cat(
[
model.generate(inputs_embeds=inputs_embeds, max_new_tokens=2),
model.generate(None, max_new_tokens=max_new_tokens - 2),
]
)
assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"

Loading…
Cancel
Save