feature: better depth maps

pull/421/head
Bryce 6 months ago committed by Bryce Drennan
parent 2eee741b20
commit 9a0e0cd1a7

@ -43,7 +43,7 @@ def create_depth_map(img: "Tensor") -> "Tensor":
orig_size = img.shape[2:]
depth_pt = _create_depth_map_raw(img)
depth_pt = _create_depth_map_raw(img, max_size=1024)
# copy the depth map to the other channels
depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0)
@ -61,14 +61,13 @@ def create_depth_map(img: "Tensor") -> "Tensor":
return depth_pt
def _create_depth_map_raw(img: "Tensor") -> "Tensor":
def _create_depth_map_raw(img: "Tensor", max_size=512) -> "Tensor":
import torch
from imaginairy.modules.midas.api import MiDaSInference, midas_device
model = MiDaSInference(model_type="dpt_hybrid").to(midas_device())
img = img.to(midas_device())
max_size = 512
# calculate new size such that image fits within 512x512 but keeps aspect ratio
if img.shape[2] > img.shape[3]:

Loading…
Cancel
Save