@@ -16,10 +16,8 @@ from transformers.models.falcon.modeling_falcon import (
     FalconDecoderLayer,
     FalconLinear,
     FalconMLP,
-    FalconModel,
     FalconRotaryEmbedding,
     LayerNorm,
-    build_alibi_tensor,
     dropout_add,
     rotate_half,
 )
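The two imports dropped here lose their only users in this file when the last hunk below deletes `UnoptimizedWrappedFalconBlock`.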
@@ -61,11 +59,26 @@ class OptimizedFalconRotaryEmbedding(FalconRotaryEmbedding):
         return tuple(o.detach() for o in self.static_outputs)
 
     def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+        total_length = seq_len + past_key_values_length
         if self.seq_len_cached == -1:
             # warm up the cache
-            super().cos_sin(1, INFERENCE_MAX_LENGTH - 1, device=device, dtype=dtype)
-        return super().cos_sin(
-            seq_len=seq_len, past_key_values_length=past_key_values_length, device=device, dtype=dtype
-        )
+            total_length = max(INFERENCE_MAX_LENGTH, total_length)
+
+        if total_length > self.seq_len_cached:
+            self.seq_len_cached = total_length
+            t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype))
+            self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype))
+
+        return (
+            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
+            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
+        )
 
     def forward(self, query, key, past_key_values_length=0):
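For context on the hunk above: `cos_sin` stops delegating to `FalconRotaryEmbedding.cos_sin` and instead manages its own `cos_cached`/`sin_cached` buffers, sizing them to at least `INFERENCE_MAX_LENGTH` on first use so that steady-state inference only slices precomputed tables, presumably to keep shapes stable for the CUDA-graph path hinted at by `static_outputs`. A minimal standalone sketch of the same caching pattern (hypothetical `RotaryCache` and `MAX_LEN` names, not this module's API):

```python
import torch

MAX_LEN = 8192  # stand-in for INFERENCE_MAX_LENGTH


class RotaryCache:
    """Precompute rotary cos/sin tables once, then serve slices of them."""

    def __init__(self, head_dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.seq_len_cached = -1
        self.cos_cached = self.sin_cached = None

    def cos_sin(self, seq_len: int, past_len: int, dtype=torch.bfloat16):
        total = seq_len + past_len
        if self.seq_len_cached == -1:
            total = max(MAX_LEN, total)  # first call: warm up to the maximum length
        if total > self.seq_len_cached:  # afterwards: grow only on overflow
            self.seq_len_cached = total
            t = torch.arange(total, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).float()  # cos/sin in fp32
            self.cos_cached = emb.cos()[None].to(dtype)
            self.sin_cached = emb.sin()[None].to(dtype)
        return (
            self.cos_cached[:, past_len : past_len + seq_len],
            self.sin_cached[:, past_len : past_len + seq_len],
        )


cache = RotaryCache(head_dim=64)
cos, sin = cache.cos_sin(seq_len=16, past_len=0)
assert cos.shape == (1, 16, 64) and cache.seq_len_cached == 8192
```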
@@ -372,86 +385,3 @@ class WrappedFalconBlock(OptimizedFalconDecoderLayer):
         state = state[:, :, 0]
         state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
         return state
-
-
-class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        attention_mask: Optional[torch.Tensor] = None,
-        alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[KVCache] = None,
-        use_cache: bool = False,
-        **kwargs,
-    ):
-        batch_size, seq_length = hidden_states.shape[:2]
-
-        if layer_past is not None:
-            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
-        past_length = 0 if layer_past is None else layer_past[0].shape[1]
-        seq_length_with_past = seq_length + past_length
-
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-        if alibi is None and self.config.alibi:
-            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
-
-        outputs = super().forward(
-            hidden_states,
-            *args,
-            attention_mask=attention_mask,
-            alibi=alibi,
-            layer_past=layer_past,
-            use_cache=use_cache,
-            **kwargs,
-        )
-
-        if use_cache:
-            present_key_value = outputs[-1]
-            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
-            outputs = outputs[:-1] + (present_key_value,)
-
-        return outputs
-
-    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
-        key_states, value_states = key_value
-
-        key_states = key_states.permute(0, 2, 1)
-        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
-
-        if self.config.new_decoder_architecture:
-            key_states = self._expand_states(key_states)
-            value_states = self._expand_states(value_states)
-
-        return (key_states, value_states)
-
-    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
-        key_states, value_states = key_value
-
-        if self.config.new_decoder_architecture:
-            key_states = self._collapse_states(key_states)
-            value_states = self._collapse_states(value_states)
-
-        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
-        key_states = key_states.permute(0, 2, 1)
-
-        return (key_states, value_states)
-
-    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
-        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
-        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
-
-        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
-        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
-        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
-        return state
-
-    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
-        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
-        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
-
-        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
-        state = state[:, :, 0]
-        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
-        return state
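Although the unoptimized block is deleted, the same cache plumbing lives on in `WrappedFalconBlock` above: Bloom-format caches store keys as `[batch * num_kv_heads, head_dim, seq_len]` (hence the `permute(0, 2, 1)` in each direction), and under the new decoder architecture `_expand_states` replicates every KV head across its query groups while `_collapse_states` keeps one replica per group. A self-contained sketch of that round trip (hypothetical constants standing in for the `config` values):

```python
import torch

# Hypothetical sizes: 2 KV heads shared by 4 query groups -> 8 attention heads.
NUM_KV, GROUPS = 2, 4
NUM_ATTN = NUM_KV * GROUPS


def expand_states(state: torch.Tensor) -> torch.Tensor:
    """[batch * num_kv_heads, seq, dim] -> [batch * num_attention_heads, seq, dim]."""
    bxh, seq, dim = state.shape
    batch = bxh // NUM_KV
    state = state.view(batch, NUM_KV, 1, seq, dim)
    state = state.expand(-1, -1, GROUPS, -1, -1)  # broadcast view, no copy yet
    return state.reshape(batch * NUM_ATTN, seq, dim)  # materializes the copy


def collapse_states(state: torch.Tensor) -> torch.Tensor:
    """Inverse of expand_states: keep one replica per KV head, drop the rest."""
    bxh, seq, dim = state.shape
    batch = bxh // NUM_ATTN
    state = state.view(batch, NUM_KV, GROUPS, seq, dim)[:, :, 0]
    return state.reshape(batch * NUM_KV, seq, dim)


kv = torch.randn(2 * NUM_KV, 5, 8)  # batch_size = 2
assert torch.equal(collapse_states(expand_states(kv)), kv)  # lossless round trip

# Keys additionally swap their last two axes between the two cache formats:
key_bloom = kv.permute(0, 2, 1)  # [batch * num_kv_heads, dim, seq]
assert torch.equal(key_bloom.permute(0, 2, 1), kv)  # back to the Falcon layout
```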