|
|
|
@ -7,7 +7,7 @@ import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mask_tile(tile, overlap, std_overlap, side="bottom"):
|
|
|
|
|
h, w = tile.size(2), tile.size(3)
|
|
|
|
|
b, c, h, w = tile.shape
|
|
|
|
|
top_overlap, bottom_overlap, right_overlap, left_overlap = overlap
|
|
|
|
|
(
|
|
|
|
|
std_top_overlap,
|
|
|
|
@ -25,7 +25,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
|
|
|
|
|
lin_mask_left = (
|
|
|
|
|
torch.cat([zeros_mask, lin_mask_left], 0)
|
|
|
|
|
.repeat(h, 1)
|
|
|
|
|
.repeat(3, 1, 1)
|
|
|
|
|
.repeat(c, 1, 1)
|
|
|
|
|
.unsqueeze(0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -33,7 +33,7 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
|
|
|
|
|
lin_mask_right = (
|
|
|
|
|
torch.linspace(1, 0, right_overlap, device=tile.device)
|
|
|
|
|
.repeat(h, 1)
|
|
|
|
|
.repeat(3, 1, 1)
|
|
|
|
|
.repeat(c, 1, 1)
|
|
|
|
|
.unsqueeze(0)
|
|
|
|
|
)
|
|
|
|
|
if "top" in side:
|
|
|
|
@ -41,14 +41,14 @@ def mask_tile(tile, overlap, std_overlap, side="bottom"):
|
|
|
|
|
if top_overlap > std_top_overlap:
|
|
|
|
|
zeros_mask = torch.zeros(top_overlap - std_top_overlap, device=tile.device)
|
|
|
|
|
lin_mask_top = torch.cat([zeros_mask, lin_mask_top], 0)
|
|
|
|
|
lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
|
|
|
|
|
lin_mask_top = lin_mask_top.repeat(w, 1).rot90(3).repeat(c, 1, 1).unsqueeze(0)
|
|
|
|
|
|
|
|
|
|
if "bottom" in side:
|
|
|
|
|
lin_mask_bottom = (
|
|
|
|
|
torch.linspace(1, 0, std_bottom_overlap, device=tile.device)
|
|
|
|
|
.repeat(w, 1)
|
|
|
|
|
.rot90(3)
|
|
|
|
|
.repeat(3, 1, 1)
|
|
|
|
|
.repeat(c, 1, 1)
|
|
|
|
|
.unsqueeze(0)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|