|
|
|
@@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend):
|
|
|
|
|
"""If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids"""
|
|
|
|
|
if not is_dummy(hypo_ids):
|
|
|
|
|
for cache_tensor in cache_tensors:
|
|
|
|
|
cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids
|
|
|
|
|
cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)] # in-place reorder cache by hypo ids
|
|
|
|
|
|
|
|
|
|
def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]:
|
|
|
|
|
"""Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""
|
|
|
|
|