From 3c523ab0d2e1f16381724f6e6f288cc9158ae086 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 3 Feb 2023 01:04:19 +0600 Subject: [PATCH] Fix TP crashing when hypo_ids are used (#249) --- src/petals/server/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py index 81f3a33..cd8dce4 100644 --- a/src/petals/server/backend.py +++ b/src/petals/server/backend.py @@ -99,7 +99,7 @@ class TransformerBackend(ModuleBackend): """If hypo_ids is specified, reorder elements of each cache tensor in-place by taking indices from hypo_ids""" if not is_dummy(hypo_ids): for cache_tensor in cache_tensors: - cache_tensor[...] = cache_tensor[hypo_ids] # in-place reorder cache by hypo ids + cache_tensor[...] = cache_tensor[hypo_ids.to(cache_tensor.device)] # in-place reorder cache by hypo ids def _select_layer_past(self, cache_tensors: Sequence[torch.Tensor], prefix_length: int) -> Sequence[torch.Tensor]: """Extract first {prefix_length} tokens and reshape them such that they can be used as layer_past"""