fix: use correct device for depth images on mps

pull/320/head
Bryce 1 year ago committed by Bryce Drennan
parent 17f3541d57
commit 3258af7e02

@ -15,6 +15,7 @@ from imaginairy.modules.midas.midas.transforms import (
PrepareForNet,
Resize,
)
from imaginairy.utils import get_device
ISL_PATHS = {
"dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
@ -181,7 +182,7 @@ def torch_image_to_depth_map(image_t: torch.Tensor, model_type="dpt_hybrid"):
depth_max = torch.amax(depth_t, dim=[1, 2, 3], keepdim=True)
depth_t = (depth_t - depth_min) / (depth_max - depth_min)
return depth_t
return depth_t.to(get_device())
class MiDaSInference(nn.Module):

Loading…
Cancel
Save