diff --git a/imaginairy/img_processors/control_modes.py b/imaginairy/img_processors/control_modes.py index 0cf37b9..988d3bd 100644 --- a/imaginairy/img_processors/control_modes.py +++ b/imaginairy/img_processors/control_modes.py @@ -43,7 +43,7 @@ def create_depth_map(img: "Tensor") -> "Tensor": orig_size = img.shape[2:] - depth_pt = _create_depth_map_raw(img) + depth_pt = _create_depth_map_raw(img, max_size=1024) # copy the depth map to the other channels depth_pt = torch.cat([depth_pt, depth_pt, depth_pt], dim=0) @@ -61,14 +61,13 @@ def create_depth_map(img: "Tensor") -> "Tensor": return depth_pt -def _create_depth_map_raw(img: "Tensor") -> "Tensor": +def _create_depth_map_raw(img: "Tensor", max_size=512) -> "Tensor": import torch from imaginairy.modules.midas.api import MiDaSInference, midas_device model = MiDaSInference(model_type="dpt_hybrid").to(midas_device()) img = img.to(midas_device()) - max_size = 512 # calculate new size such that image fits within 512x512 but keeps aspect ratio if img.shape[2] > img.shape[3]: