Compare commits


No commits in common. '5ca2a031f53b72910c7578d0da59083a533181ea' and '5f91793f8ce429af06c9e66ba261188b5aebe125' have entirely different histories.

@@ -38,7 +38,7 @@ class RemotePastKeyValues(Cache):
         self.seen_tokens += new_seen

     def reorder_cache(self, beam_idx):
-        raise NotImplementedError("Beam search reordering is not implemented yet")
+        pass


 _skipped_tokens = ContextVar("skipped_tokens", default=0)

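The `_skipped_tokens` line above is a module-level `contextvars.ContextVar` counter. A minimal standalone sketch of that pattern, with hypothetical helper functions that are not part of the diff:

```python
from contextvars import ContextVar, copy_context

# Module-level counter with a per-context default of 0, as in the hunk above.
_skipped_tokens: ContextVar[int] = ContextVar("skipped_tokens", default=0)

def add_skipped(n: int) -> None:
    # ContextVars have no +=; read the current value and set an updated one.
    _skipped_tokens.set(_skipped_tokens.get() + n)

def run_in_copied_context() -> int:
    # Values set inside copy_context().run() stay local to that context.
    ctx = copy_context()
    ctx.run(add_skipped, 5)
    return ctx[_skipped_tokens]

if __name__ == "__main__":
    print(_skipped_tokens.get())      # 0  (default)
    print(run_in_copied_context())    # 5  (only inside the copied context)
    print(_skipped_tokens.get())      # 0  (outer context unchanged)
```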
@@ -135,7 +135,9 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         return self.norm


-class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
+class DistributedMixtralForCausalLM(
+    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
+):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
@@ -157,12 +159,9 @@ class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin,
         return self.model


-class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
-    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
-    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
-
-    config_class = DistributedMixtralConfig
-
+class DistributedMixtralForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
+):
     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels

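Both hunks above add `DefaultRevisionMixin` ahead of the existing base classes. A toy sketch, with stand-in classes rather than the real petals/transformers ones, of why that ordering matters: Python resolves attributes left to right along the MRO, so a mixin listed first can wrap a classmethod defined by the bases that follow it.

```python
class ToyDefaultRevisionMixin:
    """Stand-in mixin: fills in a default revision before delegating."""
    @classmethod
    def from_pretrained(cls, name, revision=None, **kwargs):
        revision = "main" if revision is None else revision
        return super().from_pretrained(name, revision=revision, **kwargs)

class ToyBaseModel:
    """Stand-in for the upstream loader."""
    @classmethod
    def from_pretrained(cls, name, revision=None, **kwargs):
        return f"{cls.__name__} loaded {name}@{revision}"

class ToyDistributedModel(ToyDefaultRevisionMixin, ToyBaseModel):
    pass

print([c.__name__ for c in ToyDistributedModel.__mro__])
# ['ToyDistributedModel', 'ToyDefaultRevisionMixin', 'ToyBaseModel', 'object']
print(ToyDistributedModel.from_pretrained("some/model"))
# ToyDistributedModel loaded some/model@main
```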
@@ -53,7 +53,7 @@ def get_block_size(
     return round(n_params * bytes_per_value * (1 + eps))


-def get_model_block(config, layer_idx: int = 0):
+def get_model_block(config, **kwargs):
     """
     The function to create a model block based on the block class
     kwargs argument **only** is necessary for specific classes, like Mixtral.
@@ -61,5 +61,5 @@ def get_model_block(config, layer_idx: int = 0):
     """
     if config.block_class == WrappedMixtralBlock:
         config = PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, layer_idx)
+        return config.block_class(config, kwargs.get("layer_idx", 0))
     return config.block_class(config)

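The new `get_model_block(config, **kwargs)` signature forwards optional keyword arguments only to block classes that need them (here, a layer index defaulting to 0). A self-contained sketch of that dispatch pattern with toy classes; it is not the petals implementation:

```python
from dataclasses import dataclass

class PlainBlock:
    def __init__(self, config):
        self.config = config

class LayerIndexedBlock:
    """Toy stand-in for a block that must know its layer index (like Mixtral)."""
    def __init__(self, config, layer_idx):
        self.config = config
        self.layer_idx = layer_idx

@dataclass
class ToyConfig:
    block_class: type = PlainBlock

def get_model_block(config, **kwargs):
    # Only the layer-indexed class receives layer_idx; other block classes
    # keep their single-argument constructor.
    if config.block_class is LayerIndexedBlock:
        return config.block_class(config, kwargs.get("layer_idx", 0))
    return config.block_class(config)

print(type(get_model_block(ToyConfig())).__name__)                                       # PlainBlock
print(get_model_block(ToyConfig(block_class=LayerIndexedBlock), layer_idx=3).layer_idx)  # 3
```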
@@ -210,12 +210,12 @@ def measure_compute_rps(
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)

-    # Skip the 1st step to exclude the initialization time
     def step(cache_):
         outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
         return outputs[1] if inference else None

     cache = step(cache)
+    # Skip the 1st step to exclude the initialization time
     synchronize(device)
     start_time = time.perf_counter()

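The code above runs one untimed step (`cache = step(cache)`) to absorb initialization cost (CUDA context creation, lazy allocations) before timing begins. A generic sketch of that benchmarking pattern using a plain torch.nn module rather than the petals block API:

```python
import time
import torch

def measure_steps_per_second(module: torch.nn.Module, dummy_input: torch.Tensor, n_steps: int = 10) -> float:
    device = dummy_input.device

    def synchronize(device):
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    with torch.inference_mode():
        module(dummy_input)  # warm-up step, excluded from the measurement
        synchronize(device)

        start_time = time.perf_counter()
        for _ in range(n_steps):
            module(dummy_input)
        synchronize(device)
        elapsed = time.perf_counter() - start_time

    return n_steps / elapsed

if __name__ == "__main__":
    block = torch.nn.Linear(1024, 1024)
    dummy_input = torch.randn(1, 16, 1024)
    print(f"{measure_steps_per_second(block, dummy_input):.1f} steps/s")
```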
@@ -142,8 +142,7 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):


 @pytest.mark.skipif(
-    "bloom" not in MODEL_NAME.lower(),
-    reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
+    MODEL_NAME.lower().find("mixtral"), reason="Mixtral use DynamicCache, that can change based on BS choices"
 )
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):

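The two skip conditions above are not evaluated the same way: the membership test yields a boolean, while `str.find` returns an index (0 when the substring starts the string, -1 when it is absent) that pytest then interprets by truthiness. A small illustration with placeholder model names:

```python
for model_name in ("mixtral-8x7b", "bigscience/bloom-560m"):
    by_membership = "bloom" not in model_name.lower()   # True for Mixtral, False for BLOOM
    by_find = model_name.lower().find("mixtral")        # 0 for Mixtral, -1 for BLOOM
    print(f"{model_name}: membership={by_membership}, find={by_find}, skip={bool(by_find)}")
# mixtral-8x7b:          find returns 0  -> falsy  -> skipif does not skip
# bigscience/bloom-560m: find returns -1 -> truthy -> skipif skips
```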