pull/570/head
Artem Chumachenko 2 months ago
parent 6a1596b633
commit e3539c0759

@@ -53,7 +53,7 @@ def get_block_size(
     return round(n_params * bytes_per_value * (1 + eps))
 
 
-def get_model_block(config, **kwargs):
+def get_model_block(config, layer_idx: int = 0):
     """
     The function to create a model block based on the block class
     kwargs argument **only** is necessary for specific classes, like Mixtral.
@@ -61,5 +61,5 @@ def get_model_block(config, **kwargs):
     """
     if config.block_class == WrappedMixtralBlock:
         config = PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, kwargs.get("layer_idx", 0))
+        return config.block_class(config, layer_idx)
     return config.block_class(config)
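Note: this hunk replaces the loosely typed **kwargs lookup with an explicit, type-annotated layer_idx argument. A minimal self-contained sketch of the refactor's effect (ToyBlock and the helper names are hypothetical stand-ins, not from this diff):

    # Toy illustration of the old vs. new call path (hypothetical class and helpers).
    class ToyBlock:
        def __init__(self, config, layer_idx: int = 0):
            self.layer_idx = layer_idx

    def get_block_old(config, **kwargs):
        # old style: the layer index hides inside kwargs
        return ToyBlock(config, kwargs.get("layer_idx", 0))

    def get_block_new(config, layer_idx: int = 0):
        # new style: explicit keyword argument, easier to discover and type-check
        return ToyBlock(config, layer_idx)

    assert get_block_old(None, layer_idx=5).layer_idx == get_block_new(None, layer_idx=5).layer_idx == 5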

@@ -210,12 +210,12 @@ def measure_compute_rps(
     elapsed = 0
     dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
 
-    # Skip the 1st step to exclude the initialization time
     def step(cache_):
         outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
         return outputs[1] if inference else None
 
     cache = step(cache)
+    # Skip the 1st step to exclude the initialization time
     synchronize(device)
     start_time = time.perf_counter()
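The moved comment now sits next to the call it describes: the first step(cache) is a warm-up run whose cost (initialization, allocations) should not count toward the measured rate. A minimal self-contained sketch of that measurement pattern, using a hypothetical stand-in workload:

    import time

    def measure_steps_per_second(fn, n_steps: int = 10) -> float:
        fn()  # warm-up call: excluded from the timing, as in measure_compute_rps
        start = time.perf_counter()
        for _ in range(n_steps):
            fn()
        elapsed = time.perf_counter() - start
        return n_steps / elapsed

    # Hypothetical workload used only for illustration.
    print(measure_steps_per_second(lambda: sum(i * i for i in range(50_000))))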

@@ -142,7 +142,8 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
 @pytest.mark.skipif(
-    MODEL_NAME.lower().find("mixtral"), reason="Mixtral uses DynamicCache, which can change based on beam search choices"
+    MODEL_NAME.lower().find("mixtral") > -1,
+    reason="Mixtral uses DynamicCache, which can change based on beam search choices",
 )
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
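The fix matters because str.find() returns an index rather than a boolean: 0 (falsy) when the substring starts at position zero and -1 (truthy) when it is absent, so the bare value skipped exactly the wrong cases. A quick illustration with made-up model names:

    # str.find() returns an index, not a boolean, so the bare value is a misleading truth test.
    name = "mixtral-8x7b"                # hypothetical model name for illustration
    print(name.find("mixtral"))          # 0 -> falsy: the old skipif would NOT skip the Mixtral test
    print(name.find("mixtral") > -1)     # True: the new skipif skips as intended

    other = "llama-2-7b"                 # hypothetical non-Mixtral name
    print(other.find("mixtral"))         # -1 -> truthy: the old skipif would wrongly skip
    print(other.find("mixtral") > -1)    # False: the new skipif lets the test run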
