fix bloom

pull/553/head
Artem Chumachenko 4 months ago
parent 08bbbd38f0
commit 81a5e70c89

@@ -91,7 +91,7 @@ class TransformerBackend(ModuleBackend):
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
             num_heads //= self.config.num_key_value_groups
-            if self.config.num_key_value_heads is not None:
+            if hasattr(self.config, "num_key_value_heads"):
                 num_heads = self.config.num_key_value_heads
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
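Note on the one-line change: the old condition assumed every model config defines num_key_value_heads and only checked whether it is None, while the new hasattr() guard also handles configs that never define the attribute at all. The commit message says only "fix bloom", so the following is a minimal sketch of the presumed motivation, using stock transformers config classes purely for illustration (Bloom's config does not define num_key_value_heads, whereas Llama-style configs do; the server itself uses Petals' own config wrappers rather than these classes directly):

# Illustrative sketch only, not part of the commit. Assumes the fix targets
# configs that lack the `num_key_value_heads` attribute entirely (e.g. BloomConfig).
from transformers import BloomConfig, LlamaConfig

for config in (BloomConfig(), LlamaConfig()):
    # Old check: `config.num_key_value_heads is not None` raises AttributeError
    # for BloomConfig, because that class never defines the attribute.
    # New check: hasattr() simply returns False instead of crashing.
    if hasattr(config, "num_key_value_heads"):
        print(type(config).__name__, "->", config.num_key_value_heads)
    else:
        print(type(config).__name__, "-> attribute missing, keep sharded num_heads")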
