From b4d822afb275deccd32b1b26bda46b80a3719467 Mon Sep 17 00:00:00 2001
From: Alexander Borzunov
Date: Sun, 3 Sep 2023 01:16:00 +0400
Subject: [PATCH] Force use_cache=True in config only (#497)

This reverts a part of #496 and instead overrides `use_cache` in
`LlamaConfig`s only (so the correct value is visible by HF `.generate()`
as well).
---
 src/petals/models/bloom/model.py  | 3 ++-
 src/petals/models/llama/config.py | 1 +
 src/petals/models/llama/model.py  | 3 ++-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/petals/models/bloom/model.py b/src/petals/models/bloom/model.py
index 53e4a98..cf83822 100644
--- a/src/petals/models/bloom/model.py
+++ b/src/petals/models/bloom/model.py
@@ -43,7 +43,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         attention_mask: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -63,6 +63,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
             attention_mask is None or (attention_mask == 1).all()
         ), f"Custom attention masks are not supported, {attention_mask=}"
         assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"
diff --git a/src/petals/models/llama/config.py b/src/petals/models/llama/config.py
index 43a5843..ae71a4c 100644
--- a/src/petals/models/llama/config.py
+++ b/src/petals/models/llama/config.py
@@ -43,4 +43,5 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfig):
         result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
         config = result[0] if isinstance(result, tuple) else result
         config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        config.use_cache = True  # use_cache=False leads to identical results but is slower and not supported by Petals
         return result
diff --git a/src/petals/models/llama/model.py b/src/petals/models/llama/model.py
index cf7d150..a9dfcc1 100644
--- a/src/petals/models/llama/model.py
+++ b/src/petals/models/llama/model.py
@@ -43,7 +43,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[RemotePastKeyValues] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,  # Not used here but needed for HF Transformers compatibility
+        use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -65,6 +65,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         assert (
             position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
         ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
         assert not output_attentions, f"{output_attentions=} is not supported"
         assert not output_hidden_states, f"{output_hidden_states=} is not supported"
         assert return_dict is None or return_dict, f"{return_dict=} is not supported"