|
|
|
@ -239,6 +239,64 @@ def noop(img: "Tensor") -> "Tensor":
|
|
|
|
|
|
|
|
|
|
# Signature accepted for a control-mode function: a callable taking either
# two tensors or a single tensor and returning a tensor.
FunctionType = Union[
    "Callable[[Tensor, Tensor], Tensor]",
    "Callable[[Tensor], Tensor]",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def adaptive_threshold_binarize(img: "Tensor") -> "Tensor":
    """
    Use adaptive thresholding to binarize the image.

    Using OpenCV for adaptive thresholding as it provides robust and efficient implementation.

    Args:
        img: 4d tensor whose second dimension has size 3 and whose values
            all lie in [-1, 1].

    Returns:
        A 3-channel tensor with values between 0 and 1. When the three input
        channels are already (approximately) equal, the input is only
        rescaled to [0, 1] and returned as-is, without thresholding.
        Otherwise every value is 0 or 1, the three channels are identical,
        and the tensor lives on the device returned by ``get_device()``.

    Raises:
        ValueError: if ``img`` is not 4d, does not have 3 channels, or
            contains values outside [-1, 1].
    """
    import torch

    if img.dim() != 4:
        raise ValueError("Input should be a 4d tensor")
    if img.size(1) != 3:
        raise ValueError("Input should have 3 channels")
    if not torch.all((img >= -1) & (img <= 1)):
        raise ValueError("All tensor values must be between -1 and 1")

    normalized = (img + 1) / 2

    # returns the rescaled img if it is already grayscale
    red, green, blue = normalized[:, 0], normalized[:, 1], normalized[:, 2]
    if torch.allclose(red, green) and torch.allclose(green, blue):
        return normalized

    # Only the thresholding path needs these, so import them lazily here.
    import cv2
    import numpy as np

    from imaginairy.utils import get_device

    grayscale = to_grayscale(img)
    # NOTE(review): the `* 255` below assumes to_grayscale() yields values in
    # [0, 1] -- confirm against its implementation.
    # Tensor.numpy() only works for CPU tensors that do not require grad, so
    # detach and move to the CPU before converting.
    gray_np = grayscale[:, 0, :, :].detach().cpu().numpy()

    block_size = 129  # neighbourhood size passed to OpenCV
    c = 2  # constant passed to OpenCV alongside the block size

    # Write into a separate array: gray_np may share memory with `grayscale`.
    binary_np = np.empty(gray_np.shape, dtype=np.float32)
    for index, plane in enumerate(gray_np):
        binary_np[index] = cv2.adaptiveThreshold(
            (plane * 255).astype(np.uint8),
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            c,
        )

    # adaptiveThreshold was asked for a max value of 255; rescale to 0/1.
    binary_np /= 255

    binary = torch.from_numpy(binary_np).unsqueeze(1).to(get_device()).float()

    return binary.repeat(1, 3, 1, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
CONTROL_MODES: Dict[str, FunctionType] = {
|
|
|
|
|
"canny": create_canny_edges,
|
|
|
|
|
"depth": create_depth_map,
|
|
|
|
@ -252,4 +310,5 @@ CONTROL_MODES: Dict[str, FunctionType] = {
|
|
|
|
|
"inpaint": inpaint_prep,
|
|
|
|
|
# "details": noop,
|
|
|
|
|
"colorize": to_grayscale,
|
|
|
|
|
"qrcode": adaptive_threshold_binarize,
|
|
|
|
|
}
|
|
|
|
|