|
|
|
@ -91,7 +91,7 @@ class TransformerBackend(ModuleBackend):
|
|
|
|
|
cache_tensors = []
|
|
|
|
|
for device, num_heads in zip(self.module.devices, self.shard_num_heads):
|
|
|
|
|
num_heads //= self.config.num_key_value_groups
|
|
|
|
|
if self.config.num_key_value_heads is not None:
|
|
|
|
|
if hasattr(self.config, "num_key_value_heads"):
|
|
|
|
|
num_heads = self.config.num_key_value_heads
|
|
|
|
|
keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
|
|
|
|
|
values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
|
|
|
|
|