Fix TP crashing when hypo_ids are used (#249)

pull/235/head^2
Alexander Borzunov 1 year ago committed by GitHub
parent b8a6788490
commit 3c523ab0d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend):
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
if not is_dummy(hypo_ids):
for cache_tensor in cache_tensors:
-cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids
+cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)] # in-place reorder cache by hypo ids
def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
"""Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""

Loading…
Cancel
Save