Compare commits

...

7 Commits

Author SHA1 Message Date
Artem Chumachenko 5ca2a031f5 Comments 1 month ago
Artem Chumachenko 63a421bad9 Add llama to non-bs mix 1 month ago
Artem Chumachenko e3539c0759 Comments 1 month ago
Artem Chumachenko 6a1596b633 Update tests/test_full_model.py (Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>) 1 month ago
Artem Chumachenko 1b4bb1a743 Another fix 1 month ago
Artem Chumachenko 16d97fcbce Fix benchmarks 1 month ago
Artem Chumachenko a447cbe16d Add assert about BS 1 month ago

@@ -38,7 +38,7 @@ class RemotePastKeyValues(Cache):
         self.seen_tokens += new_seen

     def reorder_cache(self, beam_idx):
-        pass
+        raise NotImplementedError("Beam search reordering is not implemented yet")

 _skipped_tokens = ContextVar("skipped_tokens", default=0)
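Note (illustration, not part of the diff): raising instead of silently passing matters because beam search calls reorder_cache to permute the cached states by the chosen beam indices, so a no-op would leave stale per-beam states. A minimal sketch of what such a reorder has to do for a locally held cache; the helper name and shapes are hypothetical, not Petals code:

import torch

def reorder_local_cache(past_key_values, beam_idx):
    """Permute cached (key, value) tensors along the batch*beams dimension."""
    return [
        (k.index_select(0, beam_idx.to(k.device)), v.index_select(0, beam_idx.to(v.device)))
        for k, v in past_key_values
    ]

# Tiny usage example: 2 beams, 1 head, 3 cached tokens, hidden size 4;
# both beams continue from source beam 1.
k = v = torch.zeros(2, 1, 3, 4)
reordered = reorder_local_cache([(k, v)], beam_idx=torch.tensor([1, 1]))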

@@ -135,9 +135,7 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
         return self.norm

-class DistributedMixtralForCausalLM(
-    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
-):
+class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
@@ -159,9 +157,12 @@ class DistributedMixtralForCausalLM(
         return self.model

-class DistributedMixtralForSequenceClassification(
-    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
-):
+class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected

     config_class = DistributedMixtralConfig

     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels
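Note (illustration, not part of the diff): whichever mixins remain in the bases, these classes are still used like ordinary transformers models, with generation provided by RemoteGenerationMixin. A hedged usage sketch; the checkpoint name is only an example and the top-level import is an assumption about what petals exports:

from transformers import AutoTokenizer
from petals import DistributedMixtralForCausalLM  # assumes petals exports this class directly

model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # example checkpoint, adjust as needed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedMixtralForCausalLM.from_pretrained(model_name)  # connects to a public/private swarm

inputs = tokenizer("A quick test prompt", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=8)  # generate() comes from RemoteGenerationMixin
print(tokenizer.decode(outputs[0]))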

@@ -53,7 +53,7 @@ def get_block_size(
     return round(n_params * bytes_per_value * (1 + eps))


-def get_model_block(config, **kwargs):
+def get_model_block(config, layer_idx: int = 0):
     """
     The function to create a model block based on the block class
     kwargs argument **only** is necessary for specific classes, like Mixtral.
@@ -61,5 +61,5 @@ def get_model_block(config, **kwargs):
     """
     if config.block_class == WrappedMixtralBlock:
         config = PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, kwargs.get("layer_idx", 0))
+        return config.block_class(config, layer_idx)
     return config.block_class(config)
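Note (illustration, not part of the diff): replacing the **kwargs bag with an explicit layer_idx keyword makes call sites self-documenting and lets a misspelled argument fail loudly instead of silently falling back to layer 0. A hedged sketch of how a caller might use the new signature; the loop bounds and the blocks dict are made up for the example:

# Hypothetical server-side caller: build one block object per hosted layer index.
blocks = {}
for layer_idx in range(start_block, end_block):  # assumed bounds of the hosted span
    blocks[layer_idx] = get_model_block(config, layer_idx=layer_idx)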

@@ -210,12 +210,12 @@ def measure_compute_rps(
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)

-    # Skip the 1st step to exclude the initialization time
     def step(cache_):
         outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
         return outputs[1] if inference else None

     cache = step(cache)
+    # Skip the 1st step to exclude the initialization time
     synchronize(device)

     start_time = time.perf_counter()
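Note (illustration, not part of the diff): the relocated comment documents the warm-up pattern used here: the first step(cache) call is deliberately left outside the timed region so one-time initialization cost does not skew the throughput estimate. A minimal standalone sketch of the same pattern, with hypothetical names (run_step is any callable that performs one forward pass and returns its cache):

import time
import torch

def measure_rps(run_step, n_steps: int, n_tokens: int, device: torch.device) -> float:
    state = run_step(None)              # warm-up step, excluded from timing
    if device.type == "cuda":
        torch.cuda.synchronize(device)  # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(n_steps):
        state = run_step(state)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    return n_steps * n_tokens / elapsed  # tokens processed per second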

@@ -142,7 +142,8 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):

 @pytest.mark.skipif(
-    MODEL_NAME.lower().find("mixtral"), reason="Mixtral use DynamicCache, that can change based on BS choices"
+    "bloom" not in MODEL_NAME.lower(),
+    reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
 )
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
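Note (illustration, not part of the diff): the old skip condition relied on str.find(), which returns -1 (truthy) when the substring is absent and 0 (falsy) only when it is a prefix, so the test was skipped even for models it was meant to cover. The new condition is a plain membership test that skips everything except BLOOM. A short demonstration of the pitfall, using an example model name:

name = "bigscience/bloom-560m"
print(name.lower().find("mixtral"))  # -1, which is truthy -> old condition skipped the test for BLOOM too
print("mixtral" in name.lower())     # False -> correct way to test for a substring
print("bloom" not in name.lower())   # False here, so the beam search test now runs for BLOOM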
