mirror of
https://github.com/brycedrennan/imaginAIry
synced 2024-10-31 03:20:40 +00:00
73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
"""Functions for colorizing black and white images"""
|
|
|
|
import logging
|
|
|
|
from PIL import Image, ImageEnhance, ImageStat
|
|
|
|
from imaginairy.api import imagine
|
|
from imaginairy.enhancers.describe_image_blip import generate_caption
|
|
from imaginairy.schema import ControlInput, ImaginePrompt
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def colorize_img(img, max_width=1024, max_height=1024, caption=None):
|
|
if not caption:
|
|
caption = generate_caption(img, min_length=10)
|
|
caption = caption.replace("black and white", "color")
|
|
caption = caption.replace("old picture", "professional color photo")
|
|
caption = caption.replace("vintage photograph", "professional color photo")
|
|
caption = caption.replace("old photo", "professional color photo")
|
|
caption = caption.replace("vintage photo", "professional color photo")
|
|
caption = caption.replace("old color", "color")
|
|
caption = caption.replace(" old fashioned ", " ")
|
|
caption = caption.replace(" old time ", " ")
|
|
caption = caption.replace(" old ", " ")
|
|
logger.info(caption)
|
|
control_inputs = [
|
|
ControlInput(mode="colorize", image=img, strength=2),
|
|
]
|
|
prompt_add = ". color photo, sharp-focus, highly detailed, intricate, Canon 5D"
|
|
prompt = ImaginePrompt(
|
|
prompt=f"{caption}{prompt_add}",
|
|
init_image=img,
|
|
init_image_strength=0.0,
|
|
control_inputs=control_inputs,
|
|
size=(min(img.width, max_width), min(img.height, max_height)),
|
|
steps=30,
|
|
prompt_strength=12,
|
|
)
|
|
result = next(iter(imagine(prompt)))
|
|
colorized_img = replace_color(img, result.images["generated"])
|
|
|
|
# allows the algorithm some leeway for the overall brightness of the image
|
|
# results look better with this
|
|
colorized_img = match_brightness(colorized_img, result.images["generated"])
|
|
|
|
return colorized_img
|
|
|
|
|
|
def replace_color(target_img, color_src_img):
|
|
color_src_img = color_src_img.resize(target_img.size)
|
|
|
|
_, _, value = target_img.convert("HSV").split()
|
|
hue, saturation, _ = color_src_img.convert("HSV").split()
|
|
|
|
return Image.merge("HSV", (hue, saturation, value)).convert("RGB")
|
|
|
|
|
|
def calculate_brightness(image):
|
|
greyscale_image = image.convert("L")
|
|
stat = ImageStat.Stat(greyscale_image)
|
|
return stat.mean[0]
|
|
|
|
|
|
def match_brightness(target_img, source_img):
|
|
target_brightness = calculate_brightness(target_img)
|
|
source_brightness = calculate_brightness(source_img)
|
|
|
|
brightness_factor = source_brightness / target_brightness
|
|
|
|
enhancer = ImageEnhance.Brightness(target_img)
|
|
return enhancer.enhance(brightness_factor)
|