From a26559ff654a6b6cb450016aacdca37348cc6d3d Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Wed, 30 Aug 2023 06:59:33 +0400
Subject: [PATCH] Fix `.generate(input_ids=...)` (#485)

---
 src/petals/client/remote_generation.py |  6 +++---
 tests/test_full_model.py               | 20 ++++++++++++++++++++
 tests/test_remote_sequential.py        |  2 +-
 3 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index 793b573..e392b4f 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -69,6 +69,8 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
     ):
         self._fix_generate_kwargs(kwargs)
+        if inputs is None:
+            inputs = kwargs.pop("input_ids", None)
 
         if session is not None:
             # If a session specified explicitly, use it
@@ -125,7 +127,7 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         return result
 
     @staticmethod
-    def _fix_generate_kwargs(kwargs: dict) -> dict:
+    def _fix_generate_kwargs(kwargs: dict):
         # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
         if "max_length" in kwargs and kwargs["max_length"] is None:
             del kwargs["max_length"]
@@ -135,8 +137,6 @@ class RemoteGenerationMixin(_SkipTokensMixin):
         if isinstance(do_sample, int):
             kwargs["do_sample"] = bool(do_sample)
 
-        return kwargs
-
     @staticmethod
     def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
         return dataclasses.replace(past_key_values, hypo_ids=beam_idx)
diff --git a/tests/test_full_model.py b/tests/test_full_model.py
index fafbd62..bbe6f08 100644
--- a/tests/test_full_model.py
+++ b/tests/test_full_model.py
@@ -149,3 +149,23 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
     outputs = make_generate_calls(model, inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
+
+
+@pytest.mark.forked
+def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
+    assert inputs.keys() == {"input_ids", "attention_mask"}
+
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
+    assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
+
+    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
+        outputs = torch.cat(
+            [
+                model.generate(**inputs, max_new_tokens=2),
+                model.generate(None, max_new_tokens=max_new_tokens - 2),
+            ],
+            dim=1,
+        )
+        assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"
diff --git a/tests/test_remote_sequential.py b/tests/test_remote_sequential.py
index 533ba73..20c6011 100644
--- a/tests/test_remote_sequential.py
+++ b/tests/test_remote_sequential.py
@@ -126,6 +126,6 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
 
     (outputs_ref * output_proj).sum().backward()
     assert input_prompts_ref.grad is not None
-    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=1e-2)
+    assert torch.allclose(input_prompts_ref.grad, input_prompts.grad, atol=3e-2)
     assert intermediate_prompts_ref.grad is not None
    assert torch.allclose(intermediate_prompts_ref.grad, intermediate_prompts.grad, atol=1e-2)
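
For context, a minimal usage sketch of the call pattern this patch fixes, mirroring the new
test_input_ids test: passing input_ids as a keyword now works because generate() pops it from
**kwargs when the positional inputs argument is None. The model name below is a placeholder for
any Petals-served checkpoint, not something prescribed by this patch.

    import torch
    from transformers import AutoTokenizer
    from petals import AutoDistributedModelForCausalLM

    MODEL_NAME = "petals-team/StableBeluga2"  # placeholder: any model served on the swarm
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME)

    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")

    # Previously, keyword-style input_ids left inputs=None inside
    # RemoteGenerationMixin.generate(); it is now taken from **kwargs instead.
    outputs = model.generate(**inputs, max_new_tokens=4)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

    # The same keyword form also works inside an inference session, where a
    # follow-up call with inputs=None continues from the cached prefix.
    max_new_tokens = 4
    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
        outputs = torch.cat(
            [
                model.generate(**inputs, max_new_tokens=2),
                model.generate(None, max_new_tokens=max_new_tokens - 2),
            ],
            dim=1,
        )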