feature: better depth maps

This commit is contained in:
Bryce 2023-12-17 22:40:27 -08:00 committed by Bryce Drennan
parent 2eee741b20
commit 9a0e0cd1a7

View File

@ -43,7 +43,7 @@ def create_depth_map(img: "Tensor") -> "Tensor":
orig_size = img.shape[2:]
depth_pt = _create_depth_map_raw(img)
depth_pt = _create_depth_map_raw(img, max_size=1024)
# copy the depth map to the other channels
depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0)
@ -61,14 +61,13 @@ def create_depth_map(img: "Tensor") -> "Tensor":
return depth_pt
def _create_depth_map_raw(img: "Tensor") -> "Tensor":
def _create_depth_map_raw(img: "Tensor", max_size=512) -> "Tensor":
import torch
from imaginairy.modules.midas.api import MiDaSInference, midas_device
model = MiDaSInference(model_type="dpt_hybrid").to(midas_device())
img = img.to(midas_device())
max_size = 512
# calculate new size such that image fits within 512x512 but keeps aspect ratio
if img.shape[2] > img.shape[3]: