pull/570/head
Artem Chumachenko
parent 6a1596b633
commit e3539c0759

@@ -53,7 +53,7 @@ def get_block_size(
     return round(n_params * bytes_per_value * (1 + eps))


-def get_model_block(config, **kwargs):
+def get_model_block(config, layer_idx: int = 0):
     """
     The function to create a model block based on the block class
     kwargs argument **only** is necessary for specific classes, like Mixtral.
@@ -61,5 +61,5 @@ def get_model_block(config, **kwargs):
     """
     if config.block_class == WrappedMixtralBlock:
         config = PreTrainedModel._autoset_attn_implementation(config)
-        return config.block_class(config, kwargs.get("layer_idx", 0))
+        return config.block_class(config, layer_idx)
     return config.block_class(config)
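Side note on the signature change above: replacing **kwargs with an explicit layer_idx keyword means a misspelled argument now fails loudly instead of being silently dropped by kwargs.get. A minimal standalone sketch of that difference (illustrative names only, not the Petals code):

# Illustrative sketch, not Petals code: explicit keyword vs. **kwargs.
def make_block_old(config, **kwargs):
    # Any unexpected keyword is silently swallowed; a typo falls back to 0.
    return (config, kwargs.get("layer_idx", 0))

def make_block_new(config, layer_idx: int = 0):
    # Accepted arguments are visible in the signature and checked at call time.
    return (config, layer_idx)

print(make_block_old({"hidden": 8}, layer_id=3))   # typo ignored -> layer 0
print(make_block_new({"hidden": 8}, layer_idx=3))  # ({'hidden': 8}, 3)
# make_block_new({"hidden": 8}, layer_id=3)        # would raise TypeError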

@@ -210,12 +210,12 @@ def measure_compute_rps(
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)

+        # Skip the 1st step to exclude the initialization time
         def step(cache_):
             outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
             return outputs[1] if inference else None

         cache = step(cache)
-        # Skip the 1st step to exclude the initialization time
         synchronize(device)
         start_time = time.perf_counter()
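The hunk above only moves the comment; the logic stays the same: one warm-up call runs outside the timed region so one-off initialization cost does not inflate the measured throughput. A self-contained sketch of that warm-up-then-time pattern (heavy_step is a stand-in for block.forward, not Petals code):

# Illustrative sketch of the warm-up-then-time pattern used in measure_compute_rps.
import time

def heavy_step(cache):
    # Pretend the first call pays an extra initialization cost.
    time.sleep(0.05 if cache is None else 0.01)
    return (cache or 0) + 1

cache = None
cache = heavy_step(cache)  # skip the 1st step to exclude the initialization time

n_steps, n_tokens = 10, 16
start_time = time.perf_counter()
for _ in range(n_steps):
    cache = heavy_step(cache)
elapsed = time.perf_counter() - start_time
print(f"~{n_steps * n_tokens / elapsed:.0f} tokens/s")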

@@ -142,7 +142,8 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
 @pytest.mark.skipif(
-    MODEL_NAME.lower().find("mixtral"), reason="Mixtral uses DynamicCache, which can change based on beam search choices"
+    MODEL_NAME.lower().find("mixtral") > -1,
+    reason="Mixtral uses DynamicCache, which can change based on beam search choices",
 )
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
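Why the skip condition needed the > -1 comparison: str.find() returns -1 when the substring is absent, and -1 is truthy, so the old condition skipped the test for every model (and would not have skipped a name starting with "mixtral", where find() returns 0). A quick illustration with example model names:

# Illustration of the bug being fixed; the model names are just examples.
for name in ("mistralai/Mixtral-8x7B-Instruct-v0.1", "bigscience/bloom-560m"):
    old_skip = bool(name.lower().find("mixtral"))   # True for both names (-1 is truthy)
    new_skip = name.lower().find("mixtral") > -1    # True only when "mixtral" is present
    print(f"{name}: old={old_skip}, new={new_skip}")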
