Fix optimizations, test LLaMA properly

pull/513/head
Max Ryabinin 8 months ago
parent 78c7f58200
commit 94b052db21

@@ -144,22 +144,12 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             )
         return self.pre_attn_graph(hidden_states)
 
-    def _post_attn(self, residual, hidden_states):
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
-
-    def _optimized_post_attn(self, residual, hidden_states):
+    def _optimized_output_layernorm(self, hidden_states):
         if self.post_attn_graph is None:
             self.post_attn_graph = make_inference_graphed_callable(
-                self._post_attn, sample_args=(residual, hidden_states)
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
             )
-        return self.post_attn_graph(residual, hidden_states)
+        return self.post_attn_graph(hidden_states)
 
     def forward(
         self,
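Note on the hunk above: the graphed callable now wraps only post_attention_layernorm.forward instead of the whole post-attention block, so the residual adds and the MLP are no longer baked into the captured graph. For readers unfamiliar with make_inference_graphed_callable (a Petals helper defined elsewhere), the sketch below shows the general CUDA-graph capture-and-replay pattern such helpers rely on; the function name, warmup count, and single-output assumption are illustrative, not the actual implementation.

import torch

def graph_callable_for_inference(fn, sample_args, warmup_iters=3):
    # Hypothetical sketch of CUDA-graph capture and replay for a forward-only callable.
    # Warm up on a side stream so one-time allocations happen before capture.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream), torch.inference_mode():
        for _ in range(warmup_iters):
            fn(*sample_args)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Capture a single invocation; the inputs and output become static buffers.
    static_inputs = tuple(arg.clone() for arg in sample_args)
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode(), torch.cuda.graph(graph):
        static_output = fn(*static_inputs)

    def replay(*args):
        # Copy fresh values into the captured input buffers, replay the graph,
        # and return the (reused) static output buffer.
        for static_arg, arg in zip(static_inputs, args):
            static_arg.copy_(arg)
        graph.replay()
        return static_output

    return replay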
@@ -201,10 +191,18 @@ class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
             use_cache=use_cache,
         )
+        hidden_states = residual + hidden_states
 
+        # Fully Connected
+        residual = hidden_states
         if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
-            hidden_states = self._optimized_post_attn(residual, hidden_states)
+            hidden_states = self._optimized_output_layernorm(hidden_states)
         else:
-            hidden_states = self._post_attn(residual, hidden_states)
+            hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
 
         outputs = (hidden_states,)
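The forward pass above now performs the residual adds and the MLP eagerly and only dispatches the output layernorm through the graphed path; previously _optimized_post_attn captured the entire residual + layernorm + MLP sequence. The gating condition is unchanged: the captured graph is replayed only for single-token decode steps, under inference mode, on a CUDA device, since CUDA graphs require fixed shapes and are CUDA-only. Restated as a standalone predicate for clarity (the helper name is illustrative, not part of the codebase):

import torch

def should_use_graphed_path(hidden_states: torch.Tensor) -> bool:
    # Same check as in forward() above: a single new token (seq_len == 1),
    # inference mode enabled, and a CUDA tensor.
    return (
        hidden_states.size(1) == 1
        and torch.is_inference_mode_enabled()
        and hidden_states.device.type == "cuda"
    )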
@@ -239,6 +237,8 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
 
+        assert position_ids is None
+
         # embed positions
         if attention_mask is None:
             attention_mask = torch.ones(

@@ -178,12 +178,12 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
 @pytest.mark.skipif(
-    all(model_name not in MODEL_NAME for model_name in ("falcon", "llama")),
-    reason="This test is applicable only to Falcon and LLaMa models",
+    all(model_name not in MODEL_NAME.lower() for model_name in ("falcon", "llama")),
+    reason="This test is applicable only to Falcon and LLaMA models",
 )
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
-def test_falcon(device):
+def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
@@ -194,15 +194,15 @@ def test_falcon(device):
     quant_type = QuantType.NONE
 
     block = config.block_class(config).to(dtype)
-    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-    if "falcon" in MODEL_NAME:
+    if "falcon" in MODEL_NAME.lower():
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
-    elif "llama" in MODEL_NAME:
+    elif "llama" in MODEL_NAME.lower():
         unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
 
     unopt_block = convert_block(
-        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
     unopt_block.load_state_dict(block.state_dict())
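Beyond renaming the test and making the model-name checks case-insensitive, the hunk above passes block index 1 instead of 0 to convert_block for both the optimized and unoptimized blocks. A hedged sketch of how the equivalence check might proceed after this setup (the dummy input shape, the use of torch.allclose, and the tolerances are assumptions for illustration, not the test's actual values):

import torch

# `block`, `unopt_block`, `config`, `dtype`, and `device` come from the test setup above.
dummy_input = torch.randn(1, 8, config.hidden_size, device=device, dtype=dtype)

with torch.inference_mode():
    opt_output = block(dummy_input)[0]
    unopt_output = unopt_block(dummy_input)[0]

# After load_state_dict, both blocks share identical weights, so their outputs
# should match up to numerical noise.
assert torch.allclose(opt_output, unopt_output, atol=1e-4, rtol=0)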
