mirror of https://github.com/kritiksoman/GIMP-ML
code update
parent 88ec58d101
commit e170b42172

@ -1,2 +1,3 @@
include gimpml/plugins/colorpalette/color_palette.png
include gimpml/tools/model_info.csv
@ -1,17 +1,15 @@
# from .kmeans import get_kmeans as kmeans
# from .deblur import getdeblur as deblur
# from .deepcolor import get_deepcolor as deepcolor
# from .deepdehaze import get_dehaze as dehaze
# from .deepdenoise import get_denoise as denoise
# from .deepmatting import get_newalpha as matting
# from .enlighten import get_enlighten as enlighten
from .tools.kmeans import get_kmeans as kmeans
from .tools.deblur import get_deblur as deblur
from .tools.coloring import get_deepcolor as deepcolor
from .tools.dehaze import get_dehaze as dehaze
from .tools.denoise import get_denoise as denoise
from .tools.matting import get_matting as matting
from .tools.enlighten import get_enlighten as enlighten
# from .facegen import get_newface as newface
# from .faceparse import get_face as parseface
# from .interpolateframes import get_inter as interpolateframe
from .tools.faceparse import get_face as parseface
from .tools.interpolation import get_inter as interpolateframe
from .tools.monodepth import get_mono_depth as depth
from .tools.complete_install import setup_python_weights
# from .semseg import get_sem_seg as semseg
# from .super_resolution import get_super as super
# from .inpainting import get_inpaint as inpaint
# from .syncWeights import sync as sync

from .tools.semseg import get_seg as semseg
from .tools.superresolution import get_super as super
from .tools.inpainting import get_inpaint as inpaint
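The imports above form the package's scripting surface: once gimpml is installed, each tool can be called as a plain Python function without launching GIMP. A minimal usage sketch follows, assuming the re-exported functions take and return OpenCV/NumPy image arrays; the signatures themselves are not shown in this hunk.

# Hypothetical use of the API re-exported above; argument and return types are assumptions.
import cv2
import gimpml

gimpml.setup_python_weights()            # one-time weight/environment setup
img = cv2.imread("photo.png")            # uint8 BGR array
out = gimpml.deblur(img)                 # deblurred array, same shape (assumed)
cv2.imwrite("photo_deblurred.png", out)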
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Colorizes the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
def coloring(procedure, image, n_drawables, drawables, force_cpu, progress_bar):
|
||||
# layers = Gimp.Image.get_selected_layers(image)
|
||||
# Gimp.get_pdb().run_procedure('gimp-message', [GObject.Value(GObject.TYPE_STRING, "Error")])
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'coloring.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
for index, drawable in enumerate(drawables):
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, [drawable], 1)),
|
||||
GObject.Value(Gio.File,
|
||||
Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache' + str(index) + '.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Coloring")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
# config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("coloring.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Coloring..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = coloring(procedure, image, n_drawables, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Coloring(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "coloring.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['coloring']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'coloring':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_sensitivity_mask(
|
||||
Gimp.ProcedureSensitivityMask.DRAWABLE | Gimp.ProcedureSensitivityMask.DRAWABLES)
|
||||
procedure.set_documentation(
|
||||
N_("Extracts the monocular depth of the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Coloring..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Coloring.__gtype__, sys.argv)
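Every plugin added in this commit follows the same round trip: the selected drawable(s) are exported to cache PNGs beside the weights directory, the chosen options are pickled to gimp_ml_run.pkl, the Python interpreter recorded in gimp_ml_config.pkl is launched on the matching script under gimpml/tools, and the resulting cache.png is loaded back and inserted as a new layer. A rough sketch of the tool-side counterpart is shown below; that script is not part of this hunk, so run_model and the exact I/O are placeholders, not the project's actual code.

# Sketch of a tools/coloring.py counterpart to the subprocess.call above (assumed layout).
import os
import pickle
import cv2

tools_dir = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(tools_dir, "gimp_ml_config.pkl"), "rb") as f:
    weight_path = pickle.load(f)["weight_path"]
cache_dir = os.path.join(weight_path, "..")

with open(os.path.join(cache_dir, "gimp_ml_run.pkl"), "rb") as f:
    force_cpu = pickle.load(f)["force_cpu"]

layer0 = cv2.imread(os.path.join(cache_dir, "cache0.png"))  # layer exported by the plugin
result = run_model(layer0, cpu=force_cpu)                   # placeholder for the model call
cv2.imwrite(os.path.join(cache_dir, "cache.png"), result)   # read back by the plugin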
Binary file not shown (new image, 34 KiB).
@ -0,0 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Opens the color palette as a new image file in GIMP.
|
||||
"""
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
import time
|
||||
import sys
|
||||
import os
|
||||
import gettext
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
def colorpalette(procedure, run_mode, image, n_drawables, drawable, args, data):
|
||||
image_new = Gimp.Image.new(1200, 675, 0) # 0 for RGB
|
||||
display = Gimp.Display.new(image_new)
|
||||
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(
|
||||
os.path.join(os.path.dirname(os.path.realpath(__file__)), 'color_palette.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image_new)
|
||||
copy.set_name("Color Palette")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)# DIFFERENCE_LEGACY
|
||||
image_new.insert_layer(copy, None, -1)
|
||||
Gimp.displays_flush()
|
||||
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
class ColorPalette(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "name": (str,
|
||||
# _("Layer name"),
|
||||
# _("Layer name"),
|
||||
# _("Clouds"),
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "color": (Gimp.RGB,
|
||||
# _("Fog color"),
|
||||
# _("Fog color"),
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "turbulence": (float,
|
||||
# _("Turbulence"),
|
||||
# _("Turbulence"),
|
||||
# 0.0, 10.0, 1.0,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "opacity": (float,
|
||||
# _("Opacity"),
|
||||
# _("Opacity"),
|
||||
# 0.0, 100.0, 100.0,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
|
||||
return ['colorpalette']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = Gimp.ImageProcedure.new(self, name,
|
||||
Gimp.PDBProcType.PLUGIN,
|
||||
colorpalette, None)
|
||||
procedure.set_image_types("RGB*, GRAY*");
|
||||
procedure.set_documentation(N_("Add a layer of fog"),
|
||||
"Adds a layer of fog to the image.",
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Color Palette..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "name")
|
||||
# TODO: add support for GBoxed values.
|
||||
# procedure.add_argument_from_property(self, "color")
|
||||
# procedure.add_argument_from_property(self, "turbulence")
|
||||
# procedure.add_argument_from_property(self, "opacity")
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(ColorPalette.__gtype__, sys.argv)
|
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Deblur the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def deblur(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'deblur.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Deblur")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)#DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("deblur.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Deblur..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = deblur(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Deblur(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "deblur.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['deblur']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'deblur':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Deblur the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("Deblur..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Deblur.__gtype__, sys.argv)
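All of these plugins share one argument-plumbing pattern: each PDB argument is first declared as a GObject property, where the tuple reads (Python type, nick, description, default, flags), and do_create_procedure registers it with add_argument_from_property so the same property drives both the PDB signature and the auto-built dialog widget (GimpUi.prop_check_button_new above). A minimal sketch of adding a second option follows; keep_cache is a hypothetical name used only for illustration and is not part of this commit.

# Hypothetical extension of the shared property pattern (sketch only).
__gproperties__ = {
    "force_cpu": (bool, _("Force _CPU"), "Force CPU", False,
                  GObject.ParamFlags.READWRITE),
    "keep_cache": (bool, _("_Keep cache files"), "Keep intermediate PNG files", False,
                   GObject.ParamFlags.READWRITE),
}
# In do_create_procedure():
#     procedure.add_argument_from_property(self, "keep_cache")
# and in run(), read it from args at the next index, following the force_cpu pattern.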
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Dehazes the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def dehaze(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'dehaze.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Dehaze")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)#DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("dehaze.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Dehaze..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = dehaze(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Dehaze(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "dehaze.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['dehaze']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'dehaze':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Dehazes the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Dehaze..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Dehaze.__gtype__, sys.argv)
|
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Denoises the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def denoise(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'denoise.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Denoise")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)#DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("denoise.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Denoise..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = denoise(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Denoise(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "denoise.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['denoise']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'denoise':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Denoises the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("Denoise..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Denoise.__gtype__, sys.argv)
|
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Enlightens the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def enlighten(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'enlighten.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Enlightened")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)#DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("enlighten.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Mono Depth..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = enlighten(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Enlighten(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "monodepth.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['enlighten']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'enlighten':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Enlightens the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("Enlighten..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Enlighten.__gtype__, sys.argv)
|
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Performs semantic segmentation for a portrait in the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def faceparse(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'faceparse.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Face Parse")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)#DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("faceparse.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Face Parse..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = faceparse(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class FaceParse(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "monodepth.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['faceparse']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'faceparse':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Performs semantic segmentation for a portrait in the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Face Parse..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(FaceParse.__gtype__, sys.argv)
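# Gimp.main() hands control to GIMP, which instantiates FaceParse and calls
# do_query_procedures() / do_create_procedure() above to register 'faceparse'
# in the PDB and under Layer > GIMP-ML in the menus.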
|
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Performs image inpainting on the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
def inpainting(procedure, image, n_drawables, drawables, force_cpu, progress_bar):
|
||||
# layers = Gimp.Image.get_selected_layers(image)
|
||||
# Gimp.get_pdb().run_procedure('gimp-message', [GObject.Value(GObject.TYPE_STRING, "Error")])
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'inpainting.py')
|
||||
|
||||
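# Flow of this helper, as implemented below: export the selected drawables to
# cache<i>.png next to the weights folder, pickle the options into
# gimp_ml_run.pkl, run tools/inpainting.py with the Python interpreter recorded
# in gimp_ml_config.pkl, then load the resulting cache.png back and insert it
# as a new "In Painting" layer.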
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
for index, drawable in enumerate(drawables):
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, [drawable], 1)),
|
||||
GObject.Value(Gio.File,
|
||||
Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache' + str(index) + '.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("In Painting")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
# config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("inpainting.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("In Painting..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = inpainting(procedure, image, n_drawables, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class InPainting(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "inpainting.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['inpainting']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'inpainting':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_sensitivity_mask(
|
||||
Gimp.ProcedureSensitivityMask.DRAWABLE | Gimp.ProcedureSensitivityMask.DRAWABLES)
|
||||
procedure.set_documentation(
|
||||
N_("Extracts the monocular depth of the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_In Painting..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(InPainting.__gtype__, sys.argv)
|
@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (C) 1997 James Henstridge <james@daa.com.au>
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
import time
|
||||
import sys
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
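# Note: this plug-in follows the structure of GIMP's bundled "foggify" Python
# example (hence the attribution to James Henstridge and the
# 'python-fu-foggify' procedure name below); the code adds a plasma-based fog
# layer rather than invoking an ML model.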
def inpainting(procedure, run_mode, image, n_drawables, drawables, args, data):
|
||||
config = procedure.create_config()
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
GimpUi.init('python-fu-inpainting')
|
||||
dialog = GimpUi.ProcedureDialog.new(procedure, config)
|
||||
dialog.get_color_widget('color', True, GimpUi.ColorAreaType.FLAT)
|
||||
dialog.fill(None)
|
||||
if not dialog.run():
|
||||
dialog.destroy()
|
||||
config.end_run(Gimp.PDBStatusType.CANCEL)
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL, GLib.Error())
|
||||
else:
|
||||
dialog.destroy()
|
||||
|
||||
color = config.get_property('color')
|
||||
name = config.get_property('name')
|
||||
turbulence = config.get_property('turbulence')
|
||||
opacity = config.get_property('opacity')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
if image.get_base_type() is Gimp.ImageBaseType.RGB:
|
||||
type = Gimp.ImageType.RGBA_IMAGE
|
||||
else:
|
||||
type = Gimp.ImageType.GRAYA_IMAGE
|
||||
for drawable in drawables:
|
||||
fog = Gimp.Layer.new(image, name,
|
||||
drawable.get_width(), drawable.get_height(),
|
||||
type, opacity,
|
||||
Gimp.LayerMode.NORMAL)
|
||||
fog.fill(Gimp.FillType.TRANSPARENT)
|
||||
image.insert_layer(fog, drawable.get_parent(),
|
||||
image.get_item_position(drawable))
|
||||
|
||||
Gimp.context_set_background(color)
|
||||
fog.edit_fill(Gimp.FillType.BACKGROUND)
|
||||
|
||||
# create a layer mask for the new layer
|
||||
mask = fog.create_mask(0)
|
||||
fog.add_mask(mask)
|
||||
|
||||
# add some clouds to the layer
|
||||
Gimp.get_pdb().run_procedure('plug-in-plasma', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(Gimp.Drawable, mask),
|
||||
GObject.Value(GObject.TYPE_INT, int(time.time())),
|
||||
GObject.Value(GObject.TYPE_DOUBLE, turbulence),
|
||||
])
|
||||
|
||||
# apply the clouds to the layer
|
||||
fog.remove_mask(Gimp.MaskApplyMode.APPLY)
|
||||
fog.set_visible(True)
|
||||
|
||||
Gimp.displays_flush()
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
_color = Gimp.RGB()
|
||||
_color.set(240.0, 0, 0)
|
||||
|
||||
class InPainting (Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
"name": (str,
|
||||
_("Layer _name"),
|
||||
_("Layer name"),
|
||||
_("Clouds"),
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"turbulence": (float,
|
||||
_("_Turbulence"),
|
||||
_("Turbulence"),
|
||||
0.0, 10.0, 1.0,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"opacity": (float,
|
||||
_("O_pacity"),
|
||||
_("Opacity"),
|
||||
0.0, 100.0, 100.0,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
# I use a different syntax for this property because I think it is
|
||||
# supposed to allow setting a default, except it doesn't seem to
|
||||
# work. I still leave it this way for now until we figure this out
|
||||
# as it should be the better syntax.
|
||||
color = GObject.Property(type=Gimp.RGB, default=_color,
|
||||
nick =_("Fog _color"),
|
||||
blurb=_("Fog color"))
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
|
||||
return [ 'python-fu-foggify' ]
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = Gimp.ImageProcedure.new(self, name,
|
||||
Gimp.PDBProcType.PLUGIN,
|
||||
inpainting, None)
|
||||
procedure.set_image_types("RGB*, GRAY*");
|
||||
procedure.set_sensitivity_mask (Gimp.ProcedureSensitivityMask.DRAWABLES)
|
||||
procedure.set_documentation (N_("Add a layer of fog"),
|
||||
"Adds a layer of fog to the image.",
|
||||
name)
|
||||
procedure.set_menu_label(N_("In painting..."))
|
||||
procedure.set_attribution("James Henstridge",
|
||||
"James Henstridge",
|
||||
"1999,2007")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
procedure.add_argument_from_property(self, "name")
|
||||
procedure.add_argument_from_property(self, "color")
|
||||
procedure.add_argument_from_property(self, "turbulence")
|
||||
procedure.add_argument_from_property(self, "opacity")
|
||||
return procedure
|
||||
|
||||
Gimp.main(InPainting.__gtype__, sys.argv)
|
@ -0,0 +1,273 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Performs frame interpolation using the layers of the current image.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
def interpolation(procedure, image, n_drawables, drawables, force_cpu, progress_bar, gio_file):
|
||||
# layers = Gimp.Image.get_selected_layers(image)
|
||||
# Gimp.get_pdb().run_procedure('gimp-message', [GObject.Value(GObject.TYPE_STRING, "Error")])
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'interpolation.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
for index, drawable in enumerate(drawables):
|
||||
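# Each selected layer is exported as cache<index>.png for the external
# tools/interpolation.py script; the export folder chosen in the dialog is
# passed along via gimp_ml_run.pkl below. How the script consumes these inputs
# is not shown here; its result is read back from cache.png further down.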
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, [drawable], 1)),
|
||||
GObject.Value(Gio.File,
|
||||
Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache' + str(index) + '.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu), "gio_file": str(gio_file)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("interpolation")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
# config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("interpolation.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("interpolation..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
|
||||
# UI for the file parameter
|
||||
|
||||
def choose_file(widget):
|
||||
if file_chooser_dialog.run() == Gtk.ResponseType.OK:
|
||||
if file_chooser_dialog.get_file() is not None:
|
||||
config.set_property("file", file_chooser_dialog.get_file())
|
||||
file_entry.set_text(file_chooser_dialog.get_file().get_path())
|
||||
file_chooser_dialog.hide()
|
||||
|
||||
file_chooser_button = Gtk.Button.new_with_mnemonic(label=_("_Folder..."))
|
||||
grid.attach(file_chooser_button, 0, 0, 1, 1)
|
||||
file_chooser_button.show()
|
||||
file_chooser_button.connect("clicked", choose_file)
|
||||
|
||||
file_entry = Gtk.Entry.new()
|
||||
grid.attach(file_entry, 1, 0, 1, 1)
|
||||
file_entry.set_width_chars(40)
|
||||
file_entry.set_placeholder_text(_("Choose export folder..."))
|
||||
if gio_file is not None:
|
||||
file_entry.set_text(gio_file.get_path())
|
||||
file_entry.show()
|
||||
|
||||
file_chooser_dialog = Gtk.FileChooserDialog(use_header_bar=use_header_bar,
|
||||
title=_("Frame Export folder..."),
|
||||
action=Gtk.FileChooserAction.SELECT_FOLDER)
|
||||
file_chooser_dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
file_chooser_dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
gio_file = file_entry.get_text()
|
||||
|
||||
result = interpolation(procedure, image, n_drawables, layer, force_cpu, progress_bar, gio_file)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Interpolation(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "interpolation.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"file": (Gio.File,
|
||||
_("Histogram _File"),
|
||||
"Histogram export file",
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['interpolation']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'interpolation':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_sensitivity_mask(
|
||||
Gimp.ProcedureSensitivityMask.DRAWABLE | Gimp.ProcedureSensitivityMask.DRAWABLES)
|
||||
procedure.set_documentation(
|
||||
N_("Extracts the monocular depth of the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Interpolation..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Interpolation.__gtype__, sys.argv)
|
@ -0,0 +1,339 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
Performs k-means clustering on the current layer.
|
||||
"""
|
||||
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
class StringEnum:
|
||||
"""
|
||||
Helper class for when you want to use strings as keys of an enum. The values are
|
||||
user-facing strings that might undergo translation.
|
||||
|
||||
The constructor accepts an even number of arguments. Each pair of arguments
|
||||
is a key/value pair.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for i in range(len(args) // 2):
|
||||
self.keys.append(args[i * 2])
|
||||
self.values.append(args[i * 2 + 1])
|
||||
|
||||
def get_tree_model(self):
|
||||
""" Get a tree model that can be used in GTK widgets. """
|
||||
tree_model = Gtk.ListStore(GObject.TYPE_STRING, GObject.TYPE_STRING)
|
||||
for i in range(len(self.keys)):
|
||||
tree_model.append([self.keys[i], self.values[i]])
|
||||
return tree_model
|
||||
|
||||
def __getattr__(self, name):
|
||||
""" Implements access to the key. For example, if you provided a key "red", then you could access it by
|
||||
referring to
|
||||
my_enum.red
|
||||
It may seem silly as "my_enum.red" is longer to write than just "red",
|
||||
but this provides verification that the key is indeed inside enum. """
|
||||
key = name.replace("_", " ")
|
||||
if key in self.keys:
|
||||
return key
|
||||
raise AttributeError("No such key string " + key)
|
||||
|
||||
|
||||
# output_format_enum = StringEnum(
|
||||
# "pixel count", _("Pixel count"),
|
||||
# "normalized", _("Normalized"),
|
||||
# "percent", _("Percent")
|
||||
# )
|
||||
|
||||
|
||||
def k_means(procedure, image, drawable, n_cluster, position, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'kmeans.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"n_cluster": int(n_cluster), "position": bool(position)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("K Means")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)  # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
n_cluster = args.index(0)
|
||||
position = args.index(1)
|
||||
# output_format = args.index(3)
|
||||
# force_cpu = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("n_cluster", n_cluster)
|
||||
# config.set_property("sample_average", sample_average)
|
||||
# config.set_property("output_format", output_format)
|
||||
# config.set_property("force_cpu", force_cpu)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("kmeans.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("K Means..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# UI for the file parameter
|
||||
|
||||
# def choose_file(widget):
|
||||
# if file_chooser_dialog.run() == Gtk.ResponseType.OK:
|
||||
# if file_chooser_dialog.get_file() is not None:
|
||||
# config.set_property("file", file_chooser_dialog.get_file())
|
||||
# file_entry.set_text(file_chooser_dialog.get_file().get_path())
|
||||
# file_chooser_dialog.hide()
|
||||
#
|
||||
# file_chooser_button = Gtk.Button.new_with_mnemonic(label=_("_File..."))
|
||||
# grid.attach(file_chooser_button, 0, 0, 1, 1)
|
||||
# file_chooser_button.show()
|
||||
# file_chooser_button.connect("clicked", choose_file)
|
||||
#
|
||||
# file_entry = Gtk.Entry.new()
|
||||
# grid.attach(file_entry, 1, 0, 1, 1)
|
||||
# file_entry.set_width_chars(40)
|
||||
# file_entry.set_placeholder_text(_("Choose export file..."))
|
||||
# # if gio_file is not None:
|
||||
# # file_entry.set_text(gio_file.get_path())
|
||||
# file_entry.show()
|
||||
#
|
||||
# file_chooser_dialog = Gtk.FileChooserDialog(use_header_bar=use_header_bar,
|
||||
# title=_("Histogram Export file..."),
|
||||
# action=Gtk.FileChooserAction.SAVE)
|
||||
# file_chooser_dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
# file_chooser_dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
# n_cluster parameter
|
||||
label = Gtk.Label.new_with_mnemonic(_("_Clusters"))
|
||||
grid.attach(label, 0, 1, 1, 1)
|
||||
label.show()
|
||||
spin = GimpUi.prop_spin_button_new(config, "n_cluster", step_increment=1, page_increment=10, digits=0)
|
||||
grid.attach(spin, 1, 1, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# Sample average parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "position", _("Use _Position"))
|
||||
spin.set_tooltip_text(_("If checked, x, y coordinates will be used as features for k means clustering"))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
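# Illustrative sketch only (an assumption, not the actual code in
# tools/kmeans.py) of what "Use Position" implies for the feature matrix:
#   feats = img.reshape(-1, 3).astype(float)               # R, G, B per pixel
#   if position:
#       yy, xx = np.mgrid[0:h, 0:w]                        # pixel coordinates
#       feats = np.column_stack([feats, xx.ravel(), yy.ravel()])
#   labels = KMeans(n_clusters=int(n_cluster)).fit_predict(feats)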
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
# # Force CPU parameter
|
||||
# spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
# spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
# " Otherwise, GPU will be used if available."))
|
||||
# grid.attach(spin, 1, 3, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
# Extract values from UI
|
||||
# gio_file = Gio.file_new_for_path(file_entry.get_text()) # config.get_property("file")
|
||||
n_cluster = config.get_property("n_cluster")
|
||||
position = config.get_property("position")
|
||||
# force_cpu = config.get_property("force_cpu")
|
||||
|
||||
# if gio_file is None:
|
||||
# error = 'No file given'
|
||||
# return procedure.new_return_values(Gimp.PDBStatusType.CALLING_ERROR,
|
||||
# GLib.Error(error))
|
||||
|
||||
result = k_means(procedure, image, layer,
|
||||
n_cluster, position, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Kmeans(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "k_means.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
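# GObject property tuples: numeric properties are declared as
# (type, nick, blurb, minimum, maximum, default, flags); boolean properties
# omit the minimum/maximum fields.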
"n_cluster": (float,
|
||||
_("_Clusters"),
|
||||
"Number of clusters",
|
||||
3, 64, 5,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"position": (bool,
|
||||
_("Use _Position"),
|
||||
"Use as position of pixels",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "force_cpu": (bool,
|
||||
# _("Force _CPU"),
|
||||
# "Force CPU",
|
||||
# False,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['kmeans']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'kmeans':
|
||||
procedure = Gimp.ImageProcedure.new(self, name,
|
||||
Gimp.PDBProcType.PLUGIN,
|
||||
run, None)
|
||||
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation(
|
||||
N_("Performs k means clustering for current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_K Means..."))
|
||||
procedure.set_attribution("João S. O. Bueno",
|
||||
"(c) GPL V3.0 or later",
|
||||
"2014")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
procedure.add_argument_from_property(self, "n_cluster")
|
||||
procedure.add_argument_from_property(self, "position")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
# procedure.add_argument_from_property(self, "force_cpu")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Kmeans.__gtype__, sys.argv)
|
@ -0,0 +1,240 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Performs alpha matting on the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
|
||||
def matting(procedure, image, n_drawables, drawables, force_cpu, progress_bar):
|
||||
# layers = Gimp.Image.get_selected_layers(image)
|
||||
# Gimp.get_pdb().run_procedure('gimp-message', [GObject.Value(GObject.TYPE_STRING, "Error")])
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'matting.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
for index, drawable in enumerate(drawables):
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, [drawable], 1)),
|
||||
GObject.Value(Gio.File,
|
||||
Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache' + str(index) + '.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Matting")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
# config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("matting.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("matting..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = matting(procedure, image, n_drawables, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Matting(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "matting.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['matting']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'matting':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_sensitivity_mask(
|
||||
Gimp.ProcedureSensitivityMask.DRAWABLE | Gimp.ProcedureSensitivityMask.DRAWABLES)
|
||||
procedure.set_documentation(
|
||||
N_("Extracts the monocular depth of the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Matting..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(Matting.__gtype__, sys.argv)
|
@ -1,44 +0,0 @@
|
||||
import pickle
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
|
||||
plugin_loc = os.path.dirname(os.path.realpath(__file__)) + '/'
|
||||
base_loc = os.path.expanduser("~") + '/GIMP-ML/'
|
||||
# base_loc = "D:/PycharmProjects/"
|
||||
sys.path.extend([plugin_loc + 'MiDaS'])
|
||||
# data_path = "D:/PycharmProjects/GIMP3-ML-pip/gimpml/"
|
||||
|
||||
from mono_run import run_depth
|
||||
from monodepth_net import MonoDepthNet
|
||||
import MiDaS_utils as MiDaS_utils
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
|
||||
def get_mono_depth(input_image, cFlag = False):
|
||||
image = input_image / 255.0
|
||||
out = run_depth(image, base_loc + 'weights/MiDaS/model.pt', MonoDepthNet, MiDaS_utils, target_w=640, f=cFlag)
|
||||
out = np.repeat(out[:, :, np.newaxis], 3, axis=2)
|
||||
d1, d2 = input_image.shape[:2]
|
||||
out = cv2.resize(out, (d2, d1))
|
||||
# cv2.imwrite("/Users/kritiksoman/PycharmProjects/new/out.png", out)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# # This will run when script is run as sub-process
|
||||
# dbfile = open(data_path + "data_input", 'rb')
|
||||
# data_input = pickle.load(dbfile)
|
||||
# dbfile.close()
|
||||
# # print(data)
|
||||
# data_output = {'args_input': {'processed': 1}, 'image_output': get_mono_depth(data_input['image'])}
|
||||
#
|
||||
# dbfile = open(data_path + "data_output", 'ab')
|
||||
# pickle.dump(data_output, dbfile) # source, destination
|
||||
# dbfile.close()
|
||||
|
||||
image = cv2.imread(os.path.join(base_loc, "cache.png"))[:, :, ::-1]
|
||||
output = get_mono_depth(image)
|
||||
cv2.imwrite(os.path.join(base_loc, 'cache.png'), output[:, :, ::-1])
|
@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
"""
|
||||
.d8888b. 8888888 888b d888 8888888b. 888b d888 888
|
||||
d88P Y88b 888 8888b d8888 888 Y88b 8888b d8888 888
|
||||
888 888 888 88888b.d88888 888 888 88888b.d88888 888
|
||||
888 888 888Y88888P888 888 d88P 888Y88888P888 888
|
||||
888 88888 888 888 Y888P 888 8888888P" 888 Y888P 888 888
|
||||
888 888 888 888 Y8P 888 888 888 Y8P 888 888
|
||||
Y88b d88P 888 888 " 888 888 888 " 888 888
|
||||
"Y8888P88 8888888 888 888 888 888 888 88888888
|
||||
|
||||
|
||||
Performs semantic segmentation of the current layer.
|
||||
"""
|
||||
import sys
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
import subprocess
|
||||
import pickle
|
||||
import os
|
||||
|
||||
def semseg(procedure, image, drawable, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'semseg.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE, Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Semantic Segmentation")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY)  # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
# bucket_size = args.index(0)
|
||||
force_cpu = args.index(1)
|
||||
# output_format = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("semseg.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Mono Depth..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# # Bucket size parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
# grid.attach(label, 0, 1, 1, 1)
|
||||
# label.show()
|
||||
# spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
# grid.attach(spin, 1, 1, 1, 1)
|
||||
# spin.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
result = semseg(procedure, image, layer, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class SemSeg(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "semseg.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "bucket_size": (float,
|
||||
# _("_Bucket Size"),
|
||||
# "Bucket Size",
|
||||
# 0.001, 1.0, 0.01,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['semseg']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'semseg':
|
||||
procedure = Gimp.ImageProcedure.new(self, name, Gimp.PDBProcType.PLUGIN, run, None)
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Performs semantic segmentation of the current layer."),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("Semantic Segmentation..."))
|
||||
procedure.set_attribution("Kritik Soman",
|
||||
"GIMP-ML",
|
||||
"2021")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
# procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(SemSeg.__gtype__, sys.argv)
|
@ -0,0 +1,368 @@
|
||||
#!/usr/bin/env python3
|
||||
#coding: utf-8
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
Exports the image histogram to a text file,
|
||||
so that it can be used by other programs
|
||||
and loaded into spreadsheets.
|
||||
|
||||
The resulting file is a CSV file (Comma Separated
|
||||
Values), which can be imported
|
||||
directly in most spreadsheet programs.
|
||||
|
||||
The first two columns are the bucket boundaries,
|
||||
followed by the selected columns. The histogram
|
||||
refers to the selected image area, and
|
||||
can use either Sample Average data or data
|
||||
from the current drawable only.;
|
||||
|
||||
The output is in "weighted pixels" - meaning
|
||||
all fully transparent pixels are not counted.
|
||||
|
||||
Check the gimp-histogram call
|
||||
"""
|
||||
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
|
||||
import gi
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
_ = gettext.gettext
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
class StringEnum:
|
||||
"""
|
||||
Helper class for when you want to use strings as keys of an enum. The values would be
|
||||
user facing strings that might undergo translation.
|
||||
|
||||
The constructor accepts an even amount of arguments. Each pair of arguments
|
||||
is a key/value pair.
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for i in range(len(args)//2):
|
||||
self.keys.append(args[i*2])
|
||||
self.values.append(args[i*2+1])
|
||||
|
||||
def get_tree_model(self):
|
||||
""" Get a tree model that can be used in GTK widgets. """
|
||||
tree_model = Gtk.ListStore(GObject.TYPE_STRING, GObject.TYPE_STRING)
|
||||
for i in range(len(self.keys)):
|
||||
tree_model.append([self.keys[i], self.values[i]])
|
||||
return tree_model
|
||||
|
||||
def __getattr__(self, name):
|
||||
""" Implements access to the key. For example, if you provided a key "red", then you could access it by
|
||||
referring to
|
||||
my_enum.red
|
||||
It may seem silly as "my_enum.red" is longer to write than just "red",
|
||||
but this provides verification that the key is indeed inside enum. """
|
||||
key = name.replace("_", " ")
|
||||
if key in self.keys:
|
||||
return key
|
||||
raise AttributeError("No such key string " + key)
|
||||
|
||||
|
||||
output_format_enum = StringEnum(
|
||||
"pixel count", _("Pixel count"),
|
||||
"normalized", _("Normalized"),
|
||||
"percent", _("Percent")
|
||||
)
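# A minimal usage sketch of the StringEnum defined above (not called by the
# plugin); the helper name and the compared string are illustrative assumptions.
def _output_format_enum_example(selected="pixel count"):
    # Attribute access goes through __getattr__ and validates that the key exists.
    is_pixel_count = (selected == output_format_enum.pixel_count)
    # get_tree_model() produces the two-column Gtk.ListStore (key column, label
    # column) that widgets such as GimpUi.prop_string_combo_box_new expect.
    model = output_format_enum.get_tree_model()
    return is_pixel_count, model.iter_n_children(None)   # the 3 formats defined above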
|
||||
|
||||
|
||||
def super_resolution(procedure, img, drw, gio_file,
|
||||
bucket_size, sample_average, output_format,
|
||||
progress_bar):
|
||||
if sample_average:
|
||||
new_img = img.duplicate()
|
||||
drw = new_img.merge_visible_layers(Gimp.MergeType.CLIP_TO_IMAGE)
|
||||
|
||||
channels_txt = ["Value"]
|
||||
channels_gimp = [Gimp.HistogramChannel.VALUE]
|
||||
if drw.is_rgb():
|
||||
channels_txt += ["Red", "Green", "Blue", "Luminance"]
|
||||
channels_gimp += [Gimp.HistogramChannel.RED, Gimp.HistogramChannel.GREEN, Gimp.HistogramChannel.BLUE,
|
||||
Gimp.HistogramChannel.LUMINANCE]
|
||||
if drw.has_alpha():
|
||||
channels_txt += ["Alpha"]
|
||||
channels_gimp += [Gimp.HistogramChannel.ALPHA]
|
||||
|
||||
try:
|
||||
with open(gio_file.get_path(), "wt") as hfile:
|
||||
writer = csv.writer(hfile)
|
||||
|
||||
# Write headers:
|
||||
writer.writerow(["Range start"] + channels_txt)
|
||||
|
||||
max_index = 1.0/bucket_size if bucket_size > 0 else 1
|
||||
i = 0
|
||||
progress_bar_int_percent = 0
|
||||
while True:
|
||||
start_range = i * bucket_size
|
||||
i += 1
|
||||
if start_range >= 1.0:
|
||||
break
|
||||
|
||||
row = [start_range]
|
||||
for channel in channels_gimp:
|
||||
result = Gimp.get_pdb().run_procedure('gimp-drawable-histogram',
|
||||
[ GObject.Value(Gimp.Drawable, drw),
|
||||
GObject.Value(Gimp.HistogramChannel, channel),
|
||||
GObject.Value(GObject.TYPE_DOUBLE,
|
||||
float(start_range)),
|
||||
GObject.Value(GObject.TYPE_DOUBLE,
|
||||
float(min(start_range + bucket_size, 1.0))) ])
|
||||
|
||||
if output_format == output_format_enum.pixel_count:
|
||||
count = int(result.index(5))
|
||||
else:
|
||||
pixels = result.index(4)
|
||||
count = (result.index(5) / pixels) if pixels else 0
|
||||
if output_format == output_format_enum.percent:
|
||||
count = "%.2f%%" % (count * 100)
|
||||
row.append(str(count))
|
||||
writer.writerow(row)
|
||||
|
||||
# Update progress bar
|
||||
if progress_bar:
|
||||
fraction = i / max_index
|
||||
# Only update the progress bar if it changed at least 1% .
|
||||
new_percent = math.floor(fraction * 100)
|
||||
if new_percent != progress_bar_int_percent:
|
||||
progress_bar_int_percent = new_percent
|
||||
progress_bar.set_fraction(fraction)
|
||||
# Make sure the progress bar gets drawn on screen.
|
||||
while Gtk.events_pending():
|
||||
Gtk.main_iteration()
|
||||
except IsADirectoryError:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.EXECUTION_ERROR,
|
||||
GLib.Error(_("File is either a directory or file name is empty.")))
|
||||
except FileNotFoundError:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.EXECUTION_ERROR,
|
||||
GLib.Error(_("Directory not found.")))
|
||||
except PermissionError:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.EXECUTION_ERROR,
|
||||
GLib.Error("You do not have permissions to write that file."))
|
||||
|
||||
if sample_average:
|
||||
new_img.delete()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
gio_file = args.index(0)
|
||||
bucket_size = args.index(1)
|
||||
sample_average = args.index(2)
|
||||
output_format = args.index(3)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
#config.set_property("file", gio_file)
|
||||
#config.set_property("bucket_size", bucket_size)
|
||||
#config.set_property("sample_average", sample_average)
|
||||
#config.set_property("output_format", output_format)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("superresolution.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Super Resolution..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# UI for the file parameter
|
||||
|
||||
def choose_file(widget):
|
||||
if file_chooser_dialog.run() == Gtk.ResponseType.OK:
|
||||
if file_chooser_dialog.get_file() is not None:
|
||||
config.set_property("file", file_chooser_dialog.get_file())
|
||||
file_entry.set_text(file_chooser_dialog.get_file().get_path())
|
||||
file_chooser_dialog.hide()
|
||||
|
||||
file_chooser_button = Gtk.Button.new_with_mnemonic(label=_("_File..."))
|
||||
grid.attach(file_chooser_button, 0, 0, 1, 1)
|
||||
file_chooser_button.show()
|
||||
file_chooser_button.connect("clicked", choose_file)
|
||||
|
||||
file_entry = Gtk.Entry.new()
|
||||
grid.attach(file_entry, 1, 0, 1, 1)
|
||||
file_entry.set_width_chars(40)
|
||||
file_entry.set_placeholder_text(_("Choose export file..."))
|
||||
if gio_file is not None:
|
||||
file_entry.set_text(gio_file.get_path())
|
||||
file_entry.show()
|
||||
|
||||
file_chooser_dialog = Gtk.FileChooserDialog(use_header_bar=use_header_bar,
|
||||
title=_("Histogram Export file..."),
|
||||
action=Gtk.FileChooserAction.SAVE)
|
||||
file_chooser_dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
file_chooser_dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
# Bucket size parameter
|
||||
label = Gtk.Label.new_with_mnemonic(_("_Bucket Size"))
|
||||
grid.attach(label, 0, 1, 1, 1)
|
||||
label.show()
|
||||
spin = GimpUi.prop_spin_button_new(config, "bucket_size", step_increment=0.001, page_increment=0.1, digits=3)
|
||||
grid.attach(spin, 1, 1, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# Sample average parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "sample_average", _("Sample _Average"))
|
||||
spin.set_tooltip_text(_("If checked, the histogram is generated from merging all visible layers."
|
||||
" Otherwise, the histogram is only for the current layer."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# Output format parameter
|
||||
label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
grid.attach(label, 0, 3, 1, 1)
|
||||
label.show()
|
||||
combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
grid.attach(combo, 1, 3, 1, 1)
|
||||
combo.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
# Extract values from UI
|
||||
gio_file = Gio.file_new_for_path(file_entry.get_text()) # config.get_property("file")
|
||||
bucket_size = config.get_property("bucket_size")
|
||||
sample_average = config.get_property("sample_average")
|
||||
output_format = config.get_property("output_format")
|
||||
|
||||
if gio_file is None:
|
||||
error = 'No file given'
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CALLING_ERROR,
|
||||
GLib.Error(error))
|
||||
|
||||
result = super_resolution(procedure, image, layer, gio_file,
|
||||
bucket_size, sample_average, output_format, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class SuperResolution(Gimp.PlugIn):
|
||||
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "super_resolution.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"file": (Gio.File,
|
||||
_("Histogram _File"),
|
||||
"Histogram export file",
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"bucket_size": (float,
|
||||
_("_Bucket Size"),
|
||||
"Bucket Size",
|
||||
0.1, 4.0, 1.1,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"sample_average": (bool,
|
||||
_("Sample _Average"),
|
||||
"Sample Average",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"output_format": (str,
|
||||
_("Output format"),
|
||||
"Output format: 'pixel count', 'normalized', 'percent'",
|
||||
"pixel count",
|
||||
GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['superresolution']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'superresolution':
|
||||
procedure = Gimp.ImageProcedure.new(self, name,
|
||||
Gimp.PDBProcType.PLUGIN,
|
||||
run, None)
|
||||
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation (
|
||||
N_("Exports the image histogram to a text file (CSV)"),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Super Resolution..."))
|
||||
procedure.set_attribution("João S. O. Bueno",
|
||||
"(c) GPL V3.0 or later",
|
||||
"2014")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
procedure.add_argument_from_property(self, "file")
|
||||
procedure.add_argument_from_property(self, "bucket_size")
|
||||
procedure.add_argument_from_property(self, "sample_average")
|
||||
procedure.add_argument_from_property(self, "output_format")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(SuperResolution.__gtype__, sys.argv)
|
@ -0,0 +1,371 @@
|
||||
#!/usr/bin/env python3
|
||||
# coding: utf-8
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation; either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
"""
|
||||
Exports the image histogram to a text file,
|
||||
so that it can be used by other programs
|
||||
and loaded into spreadsheets.
|
||||
|
||||
The resulting file is a CSV file (Comma Separated
|
||||
Values), which can be imported
|
||||
directly in most spreadsheet programs.
|
||||
|
||||
The first two columns are the bucket boundaries,
|
||||
followed by the selected columns. The histogram
|
||||
refers to the selected image area, and
|
||||
can use either Sample Average data or data
|
||||
from the current drawable only.;
|
||||
|
||||
The output is in "weighted pixels" - meaning
|
||||
all fully transparent pixels are not counted.
|
||||
|
||||
Check the gimp-histogram call
|
||||
"""
|
||||
|
||||
import csv
|
||||
import math
|
||||
import sys
|
||||
|
||||
import gi
|
||||
|
||||
gi.require_version('Gimp', '3.0')
|
||||
from gi.repository import Gimp
|
||||
|
||||
gi.require_version('GimpUi', '3.0')
|
||||
from gi.repository import GimpUi
|
||||
from gi.repository import GObject
|
||||
from gi.repository import GLib
|
||||
from gi.repository import Gio
|
||||
|
||||
gi.require_version('Gtk', '3.0')
|
||||
from gi.repository import Gtk
|
||||
|
||||
import gettext
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
|
||||
_ = gettext.gettext
|
||||
|
||||
|
||||
def N_(message): return message
|
||||
|
||||
|
||||
class StringEnum:
|
||||
"""
|
||||
Helper class for when you want to use strings as keys of an enum. The values would be
|
||||
user facing strings that might undergo translation.
|
||||
|
||||
The constructor accepts an even amount of arguments. Each pair of arguments
|
||||
is a key/value pair.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self.keys = []
|
||||
self.values = []
|
||||
|
||||
for i in range(len(args) // 2):
|
||||
self.keys.append(args[i * 2])
|
||||
self.values.append(args[i * 2 + 1])
|
||||
|
||||
def get_tree_model(self):
|
||||
""" Get a tree model that can be used in GTK widgets. """
|
||||
tree_model = Gtk.ListStore(GObject.TYPE_STRING, GObject.TYPE_STRING)
|
||||
for i in range(len(self.keys)):
|
||||
tree_model.append([self.keys[i], self.values[i]])
|
||||
return tree_model
|
||||
|
||||
def __getattr__(self, name):
|
||||
""" Implements access to the key. For example, if you provided a key "red", then you could access it by
|
||||
referring to
|
||||
my_enum.red
|
||||
It may seem silly as "my_enum.red" is longer to write than just "red",
|
||||
but this provides verification that the key is indeed inside enum. """
|
||||
key = name.replace("_", " ")
|
||||
if key in self.keys:
|
||||
return key
|
||||
raise AttributeError("No such key string " + key)
|
||||
|
||||
|
||||
# output_format_enum = StringEnum(
|
||||
# "pixel count", _("Pixel count"),
|
||||
# "normalized", _("Normalized"),
|
||||
# "percent", _("Percent")
|
||||
# )
|
||||
|
||||
|
||||
def super_resolution(procedure, image, drawable, scale, filter, force_cpu, progress_bar):
|
||||
config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "tools")
|
||||
with open(os.path.join(config_path, 'gimp_ml_config.pkl'), 'rb') as file:
|
||||
data_output = pickle.load(file)
|
||||
weight_path = data_output["weight_path"]
|
||||
python_path = data_output["python_path"]
|
||||
plugin_path = os.path.join(config_path, 'superresolution.py')
|
||||
|
||||
Gimp.context_push()
|
||||
image.undo_group_start()
|
||||
|
||||
interlace, compression = 0, 2
|
||||
Gimp.get_pdb().run_procedure('file-png-save', [
|
||||
GObject.Value(Gimp.RunMode, Gimp.RunMode.NONINTERACTIVE),
|
||||
GObject.Value(Gimp.Image, image),
|
||||
GObject.Value(GObject.TYPE_INT, 1),
|
||||
GObject.Value(Gimp.ObjectArray, Gimp.ObjectArray.new(Gimp.Drawable, drawable, 1)),
|
||||
GObject.Value(Gio.File, Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png'))),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, interlace),
|
||||
GObject.Value(GObject.TYPE_INT, compression),
|
||||
# write all PNG chunks except oFFs(ets)
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, False),
|
||||
GObject.Value(GObject.TYPE_BOOLEAN, True),
|
||||
])
|
||||
|
||||
with open(os.path.join(weight_path, '..', 'gimp_ml_run.pkl'), 'wb') as file:
|
||||
pickle.dump({"force_cpu": bool(force_cpu), "filter": bool(filter), "scale": float(scale)}, file)
|
||||
|
||||
subprocess.call([python_path, plugin_path])
|
||||
|
||||
if scale == 1:
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.file_new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image)
|
||||
copy.set_name("Super-resolution")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image.insert_layer(copy, None, -1)
|
||||
else:
|
||||
image_new = Gimp.Image.new(drawable[0].get_width() * scale, drawable[0].get_height() * scale, 0) # 0 for RGB
|
||||
display = Gimp.Display.new(image_new)
|
||||
result = Gimp.file_load(Gimp.RunMode.NONINTERACTIVE,
|
||||
Gio.File.new_for_path(os.path.join(weight_path, '..', 'cache.png')))
|
||||
result_layer = result.get_active_layer()
|
||||
copy = Gimp.Layer.new_from_drawable(result_layer, image_new)
|
||||
copy.set_name("Super-resolution")
|
||||
copy.set_mode(Gimp.LayerMode.NORMAL_LEGACY) # DIFFERENCE_LEGACY
|
||||
image_new.insert_layer(copy, None, -1)
|
||||
|
||||
Gimp.displays_flush()
|
||||
|
||||
image.undo_group_end()
|
||||
Gimp.context_pop()
|
||||
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.SUCCESS, GLib.Error())
|
||||
|
||||
|
||||
def run(procedure, run_mode, image, n_drawables, layer, args, data):
|
||||
# gio_file = args.index(0)
|
||||
scale = args.index(0)
|
||||
filter = args.index(1)
|
||||
# output_format = args.index(3)
|
||||
force_cpu = args.index(2)
|
||||
|
||||
progress_bar = None
|
||||
config = None
|
||||
|
||||
if run_mode == Gimp.RunMode.INTERACTIVE:
|
||||
|
||||
config = procedure.create_config()
|
||||
|
||||
# Set properties from arguments. These properties will be changed by the UI.
|
||||
# config.set_property("file", gio_file)
|
||||
# config.set_property("scale", scale)
|
||||
# config.set_property("sample_average", sample_average)
|
||||
# config.set_property("output_format", output_format)
|
||||
config.set_property("force_cpu", force_cpu)
|
||||
config.begin_run(image, run_mode, args)
|
||||
|
||||
GimpUi.init("superresolution.py")
|
||||
use_header_bar = Gtk.Settings.get_default().get_property("gtk-dialogs-use-header")
|
||||
dialog = GimpUi.Dialog(use_header_bar=use_header_bar,
|
||||
title=_("Super Resolution..."))
|
||||
dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
vbox = Gtk.Box(orientation=Gtk.Orientation.VERTICAL,
|
||||
homogeneous=False, spacing=10)
|
||||
dialog.get_content_area().add(vbox)
|
||||
vbox.show()
|
||||
|
||||
# Create grid to set all the properties inside.
|
||||
grid = Gtk.Grid()
|
||||
grid.set_column_homogeneous(False)
|
||||
grid.set_border_width(10)
|
||||
grid.set_column_spacing(10)
|
||||
grid.set_row_spacing(10)
|
||||
vbox.add(grid)
|
||||
grid.show()
|
||||
|
||||
# UI for the file parameter
|
||||
|
||||
# def choose_file(widget):
|
||||
# if file_chooser_dialog.run() == Gtk.ResponseType.OK:
|
||||
# if file_chooser_dialog.get_file() is not None:
|
||||
# config.set_property("file", file_chooser_dialog.get_file())
|
||||
# file_entry.set_text(file_chooser_dialog.get_file().get_path())
|
||||
# file_chooser_dialog.hide()
|
||||
#
|
||||
# file_chooser_button = Gtk.Button.new_with_mnemonic(label=_("_File..."))
|
||||
# grid.attach(file_chooser_button, 0, 0, 1, 1)
|
||||
# file_chooser_button.show()
|
||||
# file_chooser_button.connect("clicked", choose_file)
|
||||
#
|
||||
# file_entry = Gtk.Entry.new()
|
||||
# grid.attach(file_entry, 1, 0, 1, 1)
|
||||
# file_entry.set_width_chars(40)
|
||||
# file_entry.set_placeholder_text(_("Choose export file..."))
|
||||
# # if gio_file is not None:
|
||||
# # file_entry.set_text(gio_file.get_path())
|
||||
# file_entry.show()
|
||||
#
|
||||
# file_chooser_dialog = Gtk.FileChooserDialog(use_header_bar=use_header_bar,
|
||||
# title=_("Histogram Export file..."),
|
||||
# action=Gtk.FileChooserAction.SAVE)
|
||||
# file_chooser_dialog.add_button("_Cancel", Gtk.ResponseType.CANCEL)
|
||||
# file_chooser_dialog.add_button("_OK", Gtk.ResponseType.OK)
|
||||
|
||||
# Scale parameter
|
||||
label = Gtk.Label.new_with_mnemonic(_("_Scale"))
|
||||
grid.attach(label, 0, 1, 1, 1)
|
||||
label.show()
|
||||
spin = GimpUi.prop_spin_button_new(config, "scale", step_increment=0.01, page_increment=0.1, digits=2)
|
||||
grid.attach(spin, 1, 1, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# Use Filter parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "filter", _("Use _Filter"))
|
||||
spin.set_tooltip_text(_("If checked, super-resolution will be used as a filter."
|
||||
" Otherwise, it will run on whole image at once."))
|
||||
grid.attach(spin, 1, 2, 1, 1)
|
||||
spin.show()
|
||||
|
||||
# # Output format parameter
|
||||
# label = Gtk.Label.new_with_mnemonic(_("_Output Format"))
|
||||
# grid.attach(label, 0, 3, 1, 1)
|
||||
# label.show()
|
||||
# combo = GimpUi.prop_string_combo_box_new(config, "output_format", output_format_enum.get_tree_model(), 0, 1)
|
||||
# grid.attach(combo, 1, 3, 1, 1)
|
||||
# combo.show()
|
||||
|
||||
# Force CPU parameter
|
||||
spin = GimpUi.prop_check_button_new(config, "force_cpu", _("Force _CPU"))
|
||||
spin.set_tooltip_text(_("If checked, CPU is used for model inference."
|
||||
" Otherwise, GPU will be used if available."))
|
||||
grid.attach(spin, 1, 3, 1, 1)
|
||||
spin.show()
|
||||
|
||||
progress_bar = Gtk.ProgressBar()
|
||||
vbox.add(progress_bar)
|
||||
progress_bar.show()
|
||||
|
||||
dialog.show()
|
||||
if dialog.run() != Gtk.ResponseType.OK:
|
||||
return procedure.new_return_values(Gimp.PDBStatusType.CANCEL,
|
||||
GLib.Error())
|
||||
|
||||
# Extract values from UI
|
||||
# gio_file = Gio.file_new_for_path(file_entry.get_text()) # config.get_property("file")
|
||||
scale = config.get_property("scale")
|
||||
filter = config.get_property("filter")
|
||||
force_cpu = config.get_property("force_cpu")
|
||||
|
||||
# if gio_file is None:
|
||||
# error = 'No file given'
|
||||
# return procedure.new_return_values(Gimp.PDBStatusType.CALLING_ERROR,
|
||||
# GLib.Error(error))
|
||||
|
||||
result = super_resolution(procedure, image, layer,
|
||||
scale, filter, force_cpu, progress_bar)
|
||||
|
||||
# If the execution was successful, save parameters so they will be restored next time we show dialog.
|
||||
if result.index(0) == Gimp.PDBStatusType.SUCCESS and config is not None:
|
||||
config.end_run(Gimp.PDBStatusType.SUCCESS)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class SuperResolution(Gimp.PlugIn):
|
||||
## Parameters ##
|
||||
__gproperties__ = {
|
||||
# "filename": (str,
|
||||
# # TODO: I wanted this property to be a path (and not just str) , so I could use
|
||||
# # prop_file_chooser_button_new to open a file dialog. However, it fails without an error message.
|
||||
# # Gimp.ConfigPath,
|
||||
# _("Histogram _File"),
|
||||
# _("Histogram _File"),
|
||||
# "super_resolution.csv",
|
||||
# # Gimp.ConfigPathType.FILE,
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
# "file": (Gio.File,
|
||||
# _("Histogram _File"),
|
||||
# "Histogram export file",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"scale": (float,
|
||||
_("_Scale"),
|
||||
"Scale",
|
||||
1, 4, 2,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
"filter": (bool,
|
||||
_("Use _Filter"),
|
||||
"Use as Filter",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
# "output_format": (str,
|
||||
# _("Output format"),
|
||||
# "Output format: 'pixel count', 'normalized', 'percent'",
|
||||
# "pixel count",
|
||||
# GObject.ParamFlags.READWRITE),
|
||||
"force_cpu": (bool,
|
||||
_("Force _CPU"),
|
||||
"Force CPU",
|
||||
False,
|
||||
GObject.ParamFlags.READWRITE),
|
||||
}
|
||||
|
||||
## GimpPlugIn virtual methods ##
|
||||
def do_query_procedures(self):
|
||||
self.set_translation_domain("gimp30-python",
|
||||
Gio.file_new_for_path(Gimp.locale_directory()))
|
||||
return ['superresolution']
|
||||
|
||||
def do_create_procedure(self, name):
|
||||
procedure = None
|
||||
if name == 'superresolution':
|
||||
procedure = Gimp.ImageProcedure.new(self, name,
|
||||
Gimp.PDBProcType.PLUGIN,
|
||||
run, None)
|
||||
|
||||
procedure.set_image_types("*")
|
||||
procedure.set_documentation(
|
||||
N_("Exports the image histogram to a text file (CSV)"),
|
||||
globals()["__doc__"], # This includes the docstring, on the top of the file
|
||||
name)
|
||||
procedure.set_menu_label(N_("_Super Resolution..."))
|
||||
procedure.set_attribution("João S. O. Bueno",
|
||||
"(c) GPL V3.0 or later",
|
||||
"2014")
|
||||
procedure.add_menu_path("<Image>/Layer/GIMP-ML/")
|
||||
|
||||
# procedure.add_argument_from_property(self, "file")
|
||||
procedure.add_argument_from_property(self, "scale")
|
||||
procedure.add_argument_from_property(self, "filter")
|
||||
# procedure.add_argument_from_property(self, "output_format")
|
||||
procedure.add_argument_from_property(self, "force_cpu")
|
||||
|
||||
return procedure
|
||||
|
||||
|
||||
Gimp.main(SuperResolution.__gtype__, sys.argv)
|
@ -0,0 +1,99 @@
|
||||
import torch
|
||||
import copy
|
||||
|
||||
|
||||
class GANFactory:
|
||||
factories = {}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def add_factory(gan_id, model_factory):
|
||||
GANFactory.factories[gan_id] = model_factory
|
||||
|
||||
add_factory = staticmethod(add_factory)
|
||||
|
||||
# A Template Method:
|
||||
|
||||
def create_model(gan_id, net_d=None, criterion=None):
|
||||
if gan_id not in GANFactory.factories:
|
||||
GANFactory.factories[gan_id] = \
|
||||
eval(gan_id + '.Factory()')
|
||||
return GANFactory.factories[gan_id].create(net_d, criterion)
|
||||
|
||||
create_model = staticmethod(create_model)
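# A minimal usage sketch of the factory above (the trainer name and arguments
# are illustrative assumptions; None only satisfies the "NoGAN" trainer, the
# other trainers expect a real discriminator, a criterion and a CUDA device).
def _gan_factory_example():
    trainer = GANFactory.create_model('NoGAN', net_d=None, criterion=None)
    d_loss = trainer.loss_d(pred=None, gt=None)     # [0] -- NoGAN has no discriminator
    params = trainer.get_params()                   # a single dummy parameter
    return d_loss, params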
|
||||
|
||||
|
||||
class GANTrainer(object):
|
||||
def __init__(self, net_d, criterion):
|
||||
self.net_d = net_d
|
||||
self.criterion = criterion
|
||||
|
||||
def loss_d(self, pred, gt):
|
||||
pass
|
||||
|
||||
def loss_g(self, pred, gt):
|
||||
pass
|
||||
|
||||
def get_params(self):
|
||||
pass
|
||||
|
||||
|
||||
class NoGAN(GANTrainer):
|
||||
def __init__(self, net_d, criterion):
|
||||
GANTrainer.__init__(self, net_d, criterion)
|
||||
|
||||
def loss_d(self, pred, gt):
|
||||
return [0]
|
||||
|
||||
def loss_g(self, pred, gt):
|
||||
return 0
|
||||
|
||||
def get_params(self):
|
||||
return [torch.nn.Parameter(torch.Tensor(1))]
|
||||
|
||||
class Factory:
|
||||
@staticmethod
|
||||
def create(net_d, criterion): return NoGAN(net_d, criterion)
|
||||
|
||||
|
||||
class SingleGAN(GANTrainer):
|
||||
def __init__(self, net_d, criterion):
|
||||
GANTrainer.__init__(self, net_d, criterion)
|
||||
self.net_d = self.net_d.cuda()
|
||||
|
||||
def loss_d(self, pred, gt):
|
||||
return self.criterion(self.net_d, pred, gt)
|
||||
|
||||
def loss_g(self, pred, gt):
|
||||
return self.criterion.get_g_loss(self.net_d, pred, gt)
|
||||
|
||||
def get_params(self):
|
||||
return self.net_d.parameters()
|
||||
|
||||
class Factory:
|
||||
@staticmethod
|
||||
def create(net_d, criterion): return SingleGAN(net_d, criterion)
|
||||
|
||||
|
||||
class DoubleGAN(GANTrainer):
|
||||
def __init__(self, net_d, criterion):
|
||||
GANTrainer.__init__(self, net_d, criterion)
|
||||
self.patch_d = net_d['patch'].cuda()
|
||||
self.full_d = net_d['full'].cuda()
|
||||
self.full_criterion = copy.deepcopy(criterion)
|
||||
|
||||
def loss_d(self, pred, gt):
|
||||
return (self.criterion(self.patch_d, pred, gt) + self.full_criterion(self.full_d, pred, gt)) / 2
|
||||
|
||||
def loss_g(self, pred, gt):
|
||||
return (self.criterion.get_g_loss(self.patch_d, pred, gt) + self.full_criterion.get_g_loss(self.full_d, pred,
|
||||
gt)) / 2
|
||||
|
||||
def get_params(self):
|
||||
return list(self.patch_d.parameters()) + list(self.full_d.parameters())
|
||||
|
||||
class Factory:
|
||||
@staticmethod
|
||||
def create(net_d, criterion): return DoubleGAN(net_d, criterion)
|
||||
|
@ -0,0 +1,93 @@
|
||||
from typing import List
|
||||
|
||||
import albumentations as albu
|
||||
|
||||
|
||||
def get_transforms(size, scope='geometric', crop='random'):
|
||||
augs = {'strong': albu.Compose([albu.HorizontalFlip(),
|
||||
albu.ShiftScaleRotate(shift_limit=0.0, scale_limit=0.2, rotate_limit=20, p=.4),
|
||||
albu.ElasticTransform(),
|
||||
albu.OpticalDistortion(),
|
||||
albu.OneOf([
|
||||
albu.CLAHE(clip_limit=2),
|
||||
albu.IAASharpen(),
|
||||
albu.IAAEmboss(),
|
||||
albu.RandomBrightnessContrast(),
|
||||
albu.RandomGamma()
|
||||
], p=0.5),
|
||||
albu.OneOf([
|
||||
albu.RGBShift(),
|
||||
albu.HueSaturationValue(),
|
||||
], p=0.5),
|
||||
]),
|
||||
'weak': albu.Compose([albu.HorizontalFlip(),
|
||||
]),
|
||||
'geometric': albu.OneOf([albu.HorizontalFlip(always_apply=True),
|
||||
albu.ShiftScaleRotate(always_apply=True),
|
||||
albu.Transpose(always_apply=True),
|
||||
albu.OpticalDistortion(always_apply=True),
|
||||
albu.ElasticTransform(always_apply=True),
|
||||
])
|
||||
}
|
||||
|
||||
aug_fn = augs[scope]
|
||||
crop_fn = {'random': albu.RandomCrop(size, size, always_apply=True),
|
||||
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
|
||||
pad = albu.PadIfNeeded(size, size)
|
||||
|
||||
pipeline = albu.Compose([aug_fn, crop_fn, pad], additional_targets={'target': 'image'})
|
||||
|
||||
def process(a, b):
|
||||
r = pipeline(image=a, target=b)
|
||||
return r['image'], r['target']
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def get_normalize():
|
||||
normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
|
||||
|
||||
def process(a, b):
|
||||
r = normalize(image=a, target=b)
|
||||
return r['image'], r['target']
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def _resolve_aug_fn(name):
|
||||
d = {
|
||||
'cutout': albu.Cutout,
|
||||
'rgb_shift': albu.RGBShift,
|
||||
'hsv_shift': albu.HueSaturationValue,
|
||||
'motion_blur': albu.MotionBlur,
|
||||
'median_blur': albu.MedianBlur,
|
||||
'snow': albu.RandomSnow,
|
||||
'shadow': albu.RandomShadow,
|
||||
'fog': albu.RandomFog,
|
||||
'brightness_contrast': albu.RandomBrightnessContrast,
|
||||
'gamma': albu.RandomGamma,
|
||||
'sun_flare': albu.RandomSunFlare,
|
||||
'sharpen': albu.IAASharpen,
|
||||
'jpeg': albu.JpegCompression,
|
||||
'gray': albu.ToGray,
|
||||
# ToDo: pixelize
|
||||
# ToDo: partial gray
|
||||
}
|
||||
return d[name]
|
||||
|
||||
|
||||
def get_corrupt_function(config):
|
||||
augs = []
|
||||
for aug_params in config:
|
||||
name = aug_params.pop('name')
|
||||
cls = _resolve_aug_fn(name)
|
||||
prob = aug_params.pop('prob') if 'prob' in aug_params else .5
|
||||
augs.append(cls(p=prob, **aug_params))
|
||||
|
||||
augs = albu.OneOf(augs)
|
||||
|
||||
def process(x):
|
||||
return augs(image=x)['image']
|
||||
|
||||
return process
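# A minimal sketch of composing the helpers above into a train-time pipeline
# (not used directly by the modules in this repository); the crop size, the
# corruption config and the dummy images are illustrative assumptions only.
def _aug_pipeline_example():
    import numpy as np
    transform = get_transforms(size=256, scope='weak', crop='random')
    normalize = get_normalize()
    corrupt = get_corrupt_function([{'name': 'gamma', 'prob': 0.5}])

    sharp = np.zeros((300, 300, 3), dtype=np.uint8)   # stand-in ground truth
    blurred = sharp.copy()                            # stand-in degraded input
    a, b = transform(blurred, sharp)                  # same geometric aug for both
    a = corrupt(a)                                    # degrade only the input image
    a, b = normalize(a, b)
    return a.shape, b.shape                           # (256, 256, 3) each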
|
@ -0,0 +1,142 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from glob import glob
|
||||
from hashlib import sha1
|
||||
from typing import Callable, Iterable, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from glog import logger
|
||||
from joblib import Parallel, cpu_count, delayed
|
||||
from skimage.io import imread
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import aug
|
||||
|
||||
|
||||
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
|
||||
data = list(data)
|
||||
buckets = split_into_buckets(data, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
|
||||
|
||||
lower_bound, upper_bound = [x * n_buckets for x in bounds]
|
||||
msg = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
|
||||
if salt:
|
||||
msg += f'; salt is {salt}'
|
||||
if verbose:
|
||||
logger.info(msg)
|
||||
return np.array([sample for bucket, sample in zip(buckets, data) if lower_bound <= bucket < upper_bound])
|
||||
|
||||
|
||||
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
|
||||
path_a, path_b = x
|
||||
names = ''.join(map(os.path.basename, (path_a, path_b)))
|
||||
return sha1(f'{names}_{salt}'.encode()).hexdigest()
|
||||
|
||||
|
||||
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
|
||||
hashes = map(partial(hash_fn, salt=salt), data)
|
||||
return np.array([int(x, 16) % n_buckets for x in hashes])
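# A minimal sketch of the deterministic bucketing above; the file names and the
# 70/30 bounds are illustrative assumptions only.  Because the bucket is derived
# from a hash of the file names, the same pair always lands in the same split.
def _bucket_split_example():
    pairs = [('a/img_%03d.png' % i, 'b/img_%03d.png' % i) for i in range(10)]
    train = subsample(pairs, bounds=(0.0, 0.7), hash_fn=hash_from_paths, verbose=False)
    val = subsample(pairs, bounds=(0.7, 1.0), hash_fn=hash_from_paths, verbose=False)
    return len(train), len(val)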
|
||||
|
||||
|
||||
def _read_img(x: str):
|
||||
img = cv2.imread(x)
|
||||
if img is None:
|
||||
logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
|
||||
img = imread(x)
|
||||
return img
|
||||
|
||||
|
||||
class PairedDataset(Dataset):
|
||||
def __init__(self,
|
||||
files_a: Tuple[str],
|
||||
files_b: Tuple[str],
|
||||
transform_fn: Callable,
|
||||
normalize_fn: Callable,
|
||||
corrupt_fn: Optional[Callable] = None,
|
||||
preload: bool = True,
|
||||
preload_size: Optional[int] = 0,
|
||||
verbose=True):
|
||||
|
||||
assert len(files_a) == len(files_b)
|
||||
|
||||
self.preload = preload
|
||||
self.data_a = files_a
|
||||
self.data_b = files_b
|
||||
self.verbose = verbose
|
||||
self.corrupt_fn = corrupt_fn
|
||||
self.transform_fn = transform_fn
|
||||
self.normalize_fn = normalize_fn
|
||||
logger.info(f'Dataset has been created with {len(self.data_a)} samples')
|
||||
|
||||
if preload:
|
||||
preload_fn = partial(self._bulk_preload, preload_size=preload_size)
|
||||
if files_a == files_b:
|
||||
self.data_a = self.data_b = preload_fn(self.data_a)
|
||||
else:
|
||||
self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))
|
||||
self.preload = True
|
||||
|
||||
def _bulk_preload(self, data: Iterable[str], preload_size: int):
|
||||
jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
|
||||
jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
|
||||
return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)
|
||||
|
||||
@staticmethod
|
||||
def _preload(x: str, preload_size: int):
|
||||
img = _read_img(x)
|
||||
if preload_size:
|
||||
h, w, *_ = img.shape
|
||||
h_scale = preload_size / h
|
||||
w_scale = preload_size / w
|
||||
scale = max(h_scale, w_scale)
|
||||
img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
|
||||
assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
|
||||
return img
|
||||
|
||||
def _preprocess(self, img, res):
|
||||
def transpose(x):
|
||||
return np.transpose(x, (2, 0, 1))
|
||||
|
||||
return map(transpose, self.normalize_fn(img, res))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_a)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a, b = self.data_a[idx], self.data_b[idx]
|
||||
if not self.preload:
|
||||
a, b = map(_read_img, (a, b))
|
||||
a, b = self.transform_fn(a, b)
|
||||
if self.corrupt_fn is not None:
|
||||
a = self.corrupt_fn(a)
|
||||
a, b = self._preprocess(a, b)
|
||||
return {'a': a, 'b': b}
|
||||
|
||||
@staticmethod
|
||||
def from_config(config):
|
||||
config = deepcopy(config)
|
||||
files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))
|
||||
transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
|
||||
normalize_fn = aug.get_normalize()
|
||||
corrupt_fn = aug.get_corrupt_function(config['corrupt'])
|
||||
|
||||
hash_fn = hash_from_paths
|
||||
# ToDo: add more hash functions
|
||||
verbose = config.get('verbose', True)
|
||||
data = subsample(data=zip(files_a, files_b),
|
||||
bounds=config.get('bounds', (0, 1)),
|
||||
hash_fn=hash_fn,
|
||||
verbose=verbose)
|
||||
|
||||
files_a, files_b = map(list, zip(*data))
|
||||
|
||||
return PairedDataset(files_a=files_a,
|
||||
files_b=files_b,
|
||||
preload=config['preload'],
|
||||
preload_size=config['preload_size'],
|
||||
corrupt_fn=corrupt_fn,
|
||||
normalize_fn=normalize_fn,
|
||||
transform_fn=transform_fn,
|
||||
verbose=verbose)
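# A minimal sketch of a config accepted by from_config above; every path and
# value here is an illustrative assumption, not a configuration shipped with
# this repository.
_EXAMPLE_DATASET_CONFIG = {
    'files_a': 'data/blurred/**/*.png',    # degraded inputs (glob pattern)
    'files_b': 'data/sharp/**/*.png',      # ground-truth targets
    'size': 256,                           # crop size fed to aug.get_transforms
    'scope': 'weak',
    'crop': 'random',
    'preload': False,
    'preload_size': 0,
    'bounds': (0.0, 0.9),                  # keep 90% of the hash buckets
    'corrupt': [{'name': 'gamma', 'prob': 0.5}],
    'verbose': False,
}
# dataset = PairedDataset.from_config(_EXAMPLE_DATASET_CONFIG)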
|
@ -0,0 +1,56 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
WINDOW_SIZE = 100
|
||||
|
||||
|
||||
class MetricCounter:
|
||||
def __init__(self, exp_name):
|
||||
self.writer = SummaryWriter(exp_name)
|
||||
logging.basicConfig(filename='{}.log'.format(exp_name), level=logging.DEBUG)
|
||||
self.metrics = defaultdict(list)
|
||||
self.images = defaultdict(list)
|
||||
self.best_metric = 0
|
||||
|
||||
def add_image(self, x: np.ndarray, tag: str):
|
||||
self.images[tag].append(x)
|
||||
|
||||
def clear(self):
|
||||
self.metrics = defaultdict(list)
|
||||
self.images = defaultdict(list)
|
||||
|
||||
def add_losses(self, l_G, l_content, l_D=0):
|
||||
for name, value in zip(('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
|
||||
(l_G, l_content, l_G - l_content, l_D)):
|
||||
self.metrics[name].append(value)
|
||||
|
||||
def add_metrics(self, psnr, ssim):
|
||||
for name, value in zip(('PSNR', 'SSIM'),
|
||||
(psnr, ssim)):
|
||||
self.metrics[name].append(value)
|
||||
|
||||
def loss_message(self):
|
||||
metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:])) for k in ('G_loss', 'PSNR', 'SSIM'))
|
||||
return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))
|
||||
|
||||
def write_to_tensorboard(self, epoch_num, validation=False):
|
||||
scalar_prefix = 'Validation' if validation else 'Train'
|
||||
for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM', 'PSNR'):
|
||||
self.writer.add_scalar(f'{scalar_prefix}_{tag}', np.mean(self.metrics[tag]), global_step=epoch_num)
|
||||
for tag in self.images:
|
||||
imgs = self.images[tag]
|
||||
if imgs:
|
||||
imgs = np.array(imgs)
|
||||
self.writer.add_images(tag, imgs[:, :, :, ::-1].astype('float32') / 255, dataformats='NHWC',
|
||||
global_step=epoch_num)
|
||||
self.images[tag] = []
|
||||
|
||||
def update_best_model(self):
|
||||
cur_metric = np.mean(self.metrics['PSNR'])
|
||||
if self.best_metric < cur_metric:
|
||||
self.best_metric = cur_metric
|
||||
return True
|
||||
return False
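# A minimal sketch of driving the MetricCounter above from a training loop; the
# experiment name, loss values and metric values are illustrative assumptions.
def _metric_counter_example(epochs=2, batches=10):
    mc = MetricCounter('example_experiment')
    for epoch in range(epochs):
        mc.clear()
        for _ in range(batches):
            mc.add_losses(l_G=1.0, l_content=0.7, l_D=0.3)
            mc.add_metrics(psnr=28.5, ssim=0.91)
        print(mc.loss_message())              # windowed means of G_loss, PSNR, SSIM
        mc.write_to_tensorboard(epoch)
        if mc.update_best_model():
            pass                               # a real loop would checkpoint here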
|
@ -0,0 +1,135 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torchvision.models import resnet50, densenet121, densenet201
|
||||
|
||||
|
||||
class FPNSegHead(nn.Module):
|
||||
def __init__(self, num_in, num_mid, num_out):
|
||||
super().__init__()
|
||||
|
||||
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
|
||||
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.functional.relu(self.block0(x), inplace=True)
|
||||
x = nn.functional.relu(self.block1(x), inplace=True)
|
||||
return x
|
||||
|
||||
|
||||
class FPNDense(nn.Module):
|
||||
|
||||
def __init__(self, output_ch=3, num_filters=128, num_filters_fpn=256, pretrained=True):
|
||||
super().__init__()
|
||||
|
||||
# Feature Pyramid Network (FPN) with four feature maps of resolutions
|
||||
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
|
||||
|
||||
self.fpn = FPN(num_filters=num_filters_fpn, pretrained=pretrained)
|
||||
|
||||
# The segmentation heads on top of the FPN
|
||||
|
||||
self.head1 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head2 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head3 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head4 = FPNSegHead(num_filters_fpn, num_filters, num_filters)
|
||||
|
||||
self.smooth = nn.Sequential(
|
||||
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(num_filters),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.smooth2 = nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(num_filters // 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
map0, map1, map2, map3, map4 = self.fpn(x)
|
||||
|
||||
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
|
||||
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
|
||||
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
|
||||
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
|
||||
|
||||
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
smoothed = self.smooth2(smoothed + map0)
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
|
||||
final = self.final(smoothed)
|
||||
|
||||
return torch.tanh(final)
|
||||
|
||||
|
||||
class FPN(nn.Module):
|
||||
|
||||
def __init__(self, num_filters=256, pretrained=True):
|
||||
"""Creates an `FPN` instance for feature extraction.
|
||||
Args:
|
||||
num_filters: the number of filters in each output pyramid level
|
||||
pretrained: use ImageNet pre-trained backbone feature extractor
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.features = densenet121(pretrained=pretrained).features
|
||||
|
||||
self.enc0 = nn.Sequential(self.features.conv0,
|
||||
self.features.norm0,
|
||||
self.features.relu0)
|
||||
self.pool0 = self.features.pool0
|
||||
self.enc1 = self.features.denseblock1 # 256
|
||||
self.enc2 = self.features.denseblock2 # 512
|
||||
self.enc3 = self.features.denseblock3 # 1024
|
||||
self.enc4 = self.features.denseblock4 # 1024
|
||||
self.norm = self.features.norm5 # 1024
|
||||
|
||||
self.tr1 = self.features.transition1 # 256
|
||||
self.tr2 = self.features.transition2 # 512
|
||||
self.tr3 = self.features.transition3 # 1024
|
||||
|
||||
self.lateral4 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral3 = nn.Conv2d(1024, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral2 = nn.Conv2d(512, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral1 = nn.Conv2d(256, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral0 = nn.Conv2d(64, num_filters // 2, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
# Bottom-up pathway, from ResNet
|
||||
enc0 = self.enc0(x)
|
||||
|
||||
pooled = self.pool0(enc0)
|
||||
|
||||
enc1 = self.enc1(pooled) # 256
|
||||
tr1 = self.tr1(enc1)
|
||||
|
||||
enc2 = self.enc2(tr1) # 512
|
||||
tr2 = self.tr2(enc2)
|
||||
|
||||
enc3 = self.enc3(tr2) # 1024
|
||||
tr3 = self.tr3(enc3)
|
||||
|
||||
enc4 = self.enc4(tr3) # 2048
|
||||
enc4 = self.norm(enc4)
|
||||
|
||||
# Lateral connections
|
||||
|
||||
lateral4 = self.lateral4(enc4)
|
||||
lateral3 = self.lateral3(enc3)
|
||||
lateral2 = self.lateral2(enc2)
|
||||
lateral1 = self.lateral1(enc1)
|
||||
lateral0 = self.lateral0(enc0)
|
||||
|
||||
# Top-down pathway
|
||||
|
||||
map4 = lateral4
|
||||
map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
|
||||
map2 = lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
|
||||
map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
|
||||
|
||||
return lateral0, map1, map2, map3, map4
|
@ -0,0 +1,167 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
# from pretrainedmodels import inceptionresnetv2
|
||||
# from torchsummary import summary
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FPNHead(nn.Module):
|
||||
def __init__(self, num_in, num_mid, num_out):
|
||||
super(FPNHead,self).__init__()
|
||||
|
||||
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
|
||||
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.functional.relu(self.block0(x), inplace=True)
|
||||
x = nn.functional.relu(self.block1(x), inplace=True)
|
||||
return x
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, num_in, num_out, norm_layer):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
|
||||
norm_layer(num_out),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.block(x)
|
||||
return x
|
||||
|
||||
|
||||
class FPNInception(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
|
||||
super(FPNInception,self).__init__()
|
||||
|
||||
# Feature Pyramid Network (FPN) with four feature maps of resolutions
|
||||
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
|
||||
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
|
||||
|
||||
# The segmentation heads on top of the FPN
|
||||
|
||||
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
|
||||
self.smooth = nn.Sequential(
|
||||
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.smooth2 = nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters // 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
|
||||
|
||||
def unfreeze(self):
|
||||
self.fpn.unfreeze()
|
||||
|
||||
def forward(self, x):
|
||||
map0, map1, map2, map3, map4 = self.fpn(x)
|
||||
|
||||
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
|
||||
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
|
||||
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
|
||||
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
|
||||
|
||||
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
smoothed = self.smooth2(smoothed + map0)
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
|
||||
final = self.final(smoothed)
|
||||
res = torch.tanh(final) + x
|
||||
|
||||
return torch.clamp(res, min=-1, max=1)
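# A minimal smoke-test sketch for the FPNInception generator above.  It assumes
# the `pretrainedmodels` package is available and that the commented-out
# `inceptionresnetv2` import at the top of this file is restored; the 256x256
# input size matches the training crop size and is an illustrative assumption.
def _fpn_inception_forward_example():
    net = FPNInception(norm_layer=nn.InstanceNorm2d).eval()
    x = torch.randn(1, 3, 256, 256)            # dummy batch, values roughly in [-1, 1]
    with torch.no_grad():
        y = net(x)
    return y.shape                              # torch.Size([1, 3, 256, 256])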
|
||||
|
||||
|
||||
class FPN(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, num_filters=256):
|
||||
"""Creates an `FPN` instance for feature extraction.
|
||||
Args:
|
||||
num_filters: the number of filters in each output pyramid level
|
||||
pretrained: use ImageNet pre-trained backbone feature extractor
|
||||
"""
|
||||
|
||||
super(FPN,self).__init__()
|
||||
self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
|
||||
# self.inception = torch.load('inceptionresnetv2-520b38e4.pth')
|
||||
self.enc0 = self.inception.conv2d_1a
|
||||
self.enc1 = nn.Sequential(
|
||||
self.inception.conv2d_2a,
|
||||
self.inception.conv2d_2b,
|
||||
self.inception.maxpool_3a,
|
||||
) # 64
|
||||
self.enc2 = nn.Sequential(
|
||||
self.inception.conv2d_3b,
|
||||
self.inception.conv2d_4a,
|
||||
self.inception.maxpool_5a,
|
||||
) # 192
|
||||
self.enc3 = nn.Sequential(
|
||||
self.inception.mixed_5b,
|
||||
self.inception.repeat,
|
||||
self.inception.mixed_6a,
|
||||
) # 1088
|
||||
self.enc4 = nn.Sequential(
|
||||
self.inception.repeat_1,
|
||||
self.inception.mixed_7a,
|
||||
) #2080
|
||||
self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
self.pad = nn.ReflectionPad2d(1)
|
||||
self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
|
||||
|
||||
for param in self.inception.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def unfreeze(self):
|
||||
for param in self.inception.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# Bottom-up pathway, from ResNet
|
||||
enc0 = self.enc0(x)
|
||||
|
||||
enc1 = self.enc1(enc0) # 256
|
||||
|
||||
enc2 = self.enc2(enc1) # 512
|
||||
|
||||
enc3 = self.enc3(enc2) # 1024
|
||||
|
||||
enc4 = self.enc4(enc3) # 2048
|
||||
|
||||
# Lateral connections
|
||||
|
||||
lateral4 = self.pad(self.lateral4(enc4))
|
||||
lateral3 = self.pad(self.lateral3(enc3))
|
||||
lateral2 = self.lateral2(enc2)
|
||||
lateral1 = self.pad(self.lateral1(enc1))
|
||||
lateral0 = self.lateral0(enc0)
|
||||
|
||||
# Top-down pathway
|
||||
pad = (1, 2, 1, 2) # pad last dim by 1 on each side
|
||||
pad1 = (0, 1, 0, 1)
|
||||
map4 = lateral4
|
||||
map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
|
||||
map2 = self.td2(F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
|
||||
map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
|
||||
return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
|
@ -0,0 +1,160 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from pretrainedmodels import inceptionresnetv2
|
||||
from torchsummary import summary
|
||||
import torch.nn.functional as F
|
||||
|
||||
class FPNHead(nn.Module):
|
||||
def __init__(self, num_in, num_mid, num_out):
|
||||
super().__init__()
|
||||
|
||||
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
|
||||
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.functional.relu(self.block0(x), inplace=True)
|
||||
x = nn.functional.relu(self.block1(x), inplace=True)
|
||||
return x
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, num_in, num_out, norm_layer):
|
||||
super().__init__()
|
||||
|
||||
self.block = nn.Sequential(nn.Conv2d(num_in, num_out, kernel_size=3, padding=1),
|
||||
norm_layer(num_out),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.block(x)
|
||||
return x
|
||||
|
||||
|
||||
class FPNInceptionSimple(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, output_ch=3, num_filters=128, num_filters_fpn=256):
|
||||
super().__init__()
|
||||
|
||||
# Feature Pyramid Network (FPN) with four feature maps of resolutions
|
||||
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
|
||||
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer=norm_layer)
|
||||
|
||||
# The segmentation heads on top of the FPN
|
||||
|
||||
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
|
||||
self.smooth = nn.Sequential(
|
||||
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.smooth2 = nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters // 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
|
||||
|
||||
def unfreeze(self):
|
||||
self.fpn.unfreeze()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
map0, map1, map2, map3, map4 = self.fpn(x)
|
||||
|
||||
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
|
||||
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
|
||||
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
|
||||
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
|
||||
|
||||
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
smoothed = self.smooth2(smoothed + map0)
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
|
||||
final = self.final(smoothed)
|
||||
res = torch.tanh(final) + x
|
||||
|
||||
return torch.clamp(res, min = -1,max = 1)
|
||||
|
||||
|
||||
class FPN(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, num_filters=256):
|
||||
"""Creates an `FPN` instance for feature extraction.
|
||||
Args:
|
||||
num_filters: the number of filters in each output pyramid level
|
||||
pretrained: use ImageNet pre-trained backbone feature extractor
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.inception = inceptionresnetv2(num_classes=1000, pretrained='imagenet')
|
||||
|
||||
self.enc0 = self.inception.conv2d_1a
|
||||
self.enc1 = nn.Sequential(
|
||||
self.inception.conv2d_2a,
|
||||
self.inception.conv2d_2b,
|
||||
self.inception.maxpool_3a,
|
||||
) # 64
|
||||
self.enc2 = nn.Sequential(
|
||||
self.inception.conv2d_3b,
|
||||
self.inception.conv2d_4a,
|
||||
self.inception.maxpool_5a,
|
||||
) # 192
|
||||
self.enc3 = nn.Sequential(
|
||||
self.inception.mixed_5b,
|
||||
self.inception.repeat,
|
||||
self.inception.mixed_6a,
|
||||
) # 1088
|
||||
self.enc4 = nn.Sequential(
|
||||
self.inception.repeat_1,
|
||||
self.inception.mixed_7a,
|
||||
) #2080
|
||||
|
||||
self.pad = nn.ReflectionPad2d(1)
|
||||
self.lateral4 = nn.Conv2d(2080, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral3 = nn.Conv2d(1088, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral2 = nn.Conv2d(192, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral1 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral0 = nn.Conv2d(32, num_filters // 2, kernel_size=1, bias=False)
|
||||
|
||||
for param in self.inception.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def unfreeze(self):
|
||||
for param in self.inception.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# Bottom-up pathway, from ResNet
|
||||
enc0 = self.enc0(x)
|
||||
|
||||
enc1 = self.enc1(enc0) # 256
|
||||
|
||||
enc2 = self.enc2(enc1) # 512
|
||||
|
||||
enc3 = self.enc3(enc2) # 1024
|
||||
|
||||
enc4 = self.enc4(enc3) # 2048
|
||||
|
||||
# Lateral connections
|
||||
|
||||
lateral4 = self.pad(self.lateral4(enc4))
|
||||
lateral3 = self.pad(self.lateral3(enc3))
|
||||
lateral2 = self.lateral2(enc2)
|
||||
lateral1 = self.pad(self.lateral1(enc1))
|
||||
lateral0 = self.lateral0(enc0)
|
||||
|
||||
# Top-down pathway
|
||||
pad = (1, 2, 1, 2) # pad last dim by 1 on each side
|
||||
pad1 = (0, 1, 0, 1)
|
||||
map4 = lateral4
|
||||
map3 = lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest")
|
||||
map2 = F.pad(lateral2, pad, "reflect") + nn.functional.upsample(map3, scale_factor=2, mode="nearest")
|
||||
map1 = lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest")
|
||||
return F.pad(lateral0, pad1, "reflect"), map1, map2, map3, map4
|
@ -0,0 +1,150 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
if sys.version_info.major>2:
|
||||
from models.mobilenet_v2 import MobileNetV2
|
||||
else:
|
||||
from mobilenet_v2 import MobileNetV2
|
||||
class FPNHead(nn.Module):
|
||||
def __init__(self, num_in, num_mid, num_out):
|
||||
super().__init__()
|
||||
|
||||
self.block0 = nn.Conv2d(num_in, num_mid, kernel_size=3, padding=1, bias=False)
|
||||
self.block1 = nn.Conv2d(num_mid, num_out, kernel_size=3, padding=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = nn.functional.relu(self.block0(x), inplace=True)
|
||||
x = nn.functional.relu(self.block1(x), inplace=True)
|
||||
return x
|
||||
|
||||
|
||||
class FPNMobileNet(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, output_ch=3, num_filters=64, num_filters_fpn=128, pretrained=True):
|
||||
super().__init__()
|
||||
|
||||
# Feature Pyramid Network (FPN) with four feature maps of resolutions
|
||||
# 1/4, 1/8, 1/16, 1/32 and `num_filters` filters for all feature maps.
|
||||
|
||||
self.fpn = FPN(num_filters=num_filters_fpn, norm_layer = norm_layer, pretrained=pretrained)
|
||||
|
||||
# The segmentation heads on top of the FPN
|
||||
|
||||
self.head1 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head2 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head3 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
self.head4 = FPNHead(num_filters_fpn, num_filters, num_filters)
|
||||
|
||||
self.smooth = nn.Sequential(
|
||||
nn.Conv2d(4 * num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.smooth2 = nn.Sequential(
|
||||
nn.Conv2d(num_filters, num_filters // 2, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters // 2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.final = nn.Conv2d(num_filters // 2, output_ch, kernel_size=3, padding=1)
|
||||
|
||||
def unfreeze(self):
|
||||
self.fpn.unfreeze()
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
map0, map1, map2, map3, map4 = self.fpn(x)
|
||||
|
||||
map4 = nn.functional.upsample(self.head4(map4), scale_factor=8, mode="nearest")
|
||||
map3 = nn.functional.upsample(self.head3(map3), scale_factor=4, mode="nearest")
|
||||
map2 = nn.functional.upsample(self.head2(map2), scale_factor=2, mode="nearest")
|
||||
map1 = nn.functional.upsample(self.head1(map1), scale_factor=1, mode="nearest")
|
||||
|
||||
smoothed = self.smooth(torch.cat([map4, map3, map2, map1], dim=1))
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
smoothed = self.smooth2(smoothed + map0)
|
||||
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")
|
||||
|
||||
final = self.final(smoothed)
|
||||
res = torch.tanh(final) + x
|
||||
|
||||
return torch.clamp(res, min=-1, max=1)
|
||||
|
||||
|
||||
class FPN(nn.Module):
|
||||
|
||||
def __init__(self, norm_layer, num_filters=128, pretrained=True):
|
||||
"""Creates an `FPN` instance for feature extraction.
|
||||
Args:
|
||||
num_filters: the number of filters in each output pyramid level
|
||||
pretrained: use ImageNet pre-trained backbone feature extractor
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
net = MobileNetV2(n_class=1000)
|
||||
|
||||
if pretrained:
|
||||
#Load weights into the project directory
|
||||
state_dict = torch.load('mobilenetv2.pth.tar') # add map_location='cpu' if no gpu
|
||||
net.load_state_dict(state_dict)
|
||||
self.features = net.features
|
||||
|
||||
self.enc0 = nn.Sequential(*self.features[0:2])
|
||||
self.enc1 = nn.Sequential(*self.features[2:4])
|
||||
self.enc2 = nn.Sequential(*self.features[4:7])
|
||||
self.enc3 = nn.Sequential(*self.features[7:11])
|
||||
self.enc4 = nn.Sequential(*self.features[11:16])
|
||||
|
||||
self.td1 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
self.td2 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
self.td3 = nn.Sequential(nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
|
||||
norm_layer(num_filters),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
self.lateral4 = nn.Conv2d(160, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral3 = nn.Conv2d(64, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral2 = nn.Conv2d(32, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral1 = nn.Conv2d(24, num_filters, kernel_size=1, bias=False)
|
||||
self.lateral0 = nn.Conv2d(16, num_filters // 2, kernel_size=1, bias=False)
|
||||
|
||||
for param in self.features.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def unfreeze(self):
|
||||
for param in self.features.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# Bottom-up pathway, from ResNet
|
||||
enc0 = self.enc0(x)
|
||||
|
||||
enc1 = self.enc1(enc0) # 256
|
||||
|
||||
enc2 = self.enc2(enc1) # 512
|
||||
|
||||
enc3 = self.enc3(enc2) # 1024
|
||||
|
||||
enc4 = self.enc4(enc3) # 2048
|
||||
|
||||
# Lateral connections
|
||||
|
||||
lateral4 = self.lateral4(enc4)
|
||||
lateral3 = self.lateral3(enc3)
|
||||
lateral2 = self.lateral2(enc2)
|
||||
lateral1 = self.lateral1(enc1)
|
||||
lateral0 = self.lateral0(enc0)
|
||||
|
||||
# Top-down pathway
|
||||
map4 = lateral4
|
||||
map3 = self.td1(lateral3 + nn.functional.upsample(map4, scale_factor=2, mode="nearest"))
|
||||
map2 = self.td2(lateral2 + nn.functional.upsample(map3, scale_factor=2, mode="nearest"))
|
||||
map1 = self.td3(lateral1 + nn.functional.upsample(map2, scale_factor=2, mode="nearest"))
|
||||
return lateral0, map1, map2, map3, map4
|
||||
|
@ -0,0 +1,300 @@
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn as nn
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from torch.autograd import Variable
|
||||
|
||||
from util.image_pool import ImagePool
|
||||
|
||||
|
||||
###############################################################################
|
||||
# Functions
|
||||
###############################################################################
|
||||
|
||||
class ContentLoss():
|
||||
def initialize(self, loss):
|
||||
self.criterion = loss
|
||||
|
||||
def get_loss(self, fakeIm, realIm):
|
||||
return self.criterion(fakeIm, realIm)
|
||||
|
||||
def __call__(self, fakeIm, realIm):
|
||||
return self.get_loss(fakeIm, realIm)
|
||||
|
||||
|
||||
class PerceptualLoss():
|
||||
|
||||
def contentFunc(self):
|
||||
conv_3_3_layer = 14
|
||||
cnn = models.vgg19(pretrained=True).features
|
||||
cnn = cnn.cuda()
|
||||
model = nn.Sequential()
|
||||
model = model.cuda()
|
||||
model = model.eval()
|
||||
for i, layer in enumerate(list(cnn)):
|
||||
model.add_module(str(i), layer)
|
||||
if i == conv_3_3_layer:
|
||||
break
|
||||
return model
|
||||
|
||||
def initialize(self, loss):
|
||||
with torch.no_grad():
|
||||
self.criterion = loss
|
||||
self.contentFunc = self.contentFunc()
|
||||
self.transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
def get_loss(self, fakeIm, realIm):
|
||||
fakeIm = (fakeIm + 1) / 2.0
|
||||
realIm = (realIm + 1) / 2.0
|
||||
fakeIm[0, :, :, :] = self.transform(fakeIm[0, :, :, :])
|
||||
realIm[0, :, :, :] = self.transform(realIm[0, :, :, :])
|
||||
f_fake = self.contentFunc.forward(fakeIm)
|
||||
f_real = self.contentFunc.forward(realIm)
|
||||
f_real_no_grad = f_real.detach()
|
||||
loss = self.criterion(f_fake, f_real_no_grad)
|
||||
return 0.006 * torch.mean(loss) + 0.5 * nn.MSELoss()(fakeIm, realIm)
|
||||
|
||||
def __call__(self, fakeIm, realIm):
|
||||
return self.get_loss(fakeIm, realIm)
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, use_l1=True, target_real_label=1.0, target_fake_label=0.0,
|
||||
tensor=torch.FloatTensor):
|
||||
super(GANLoss, self).__init__()
|
||||
self.real_label = target_real_label
|
||||
self.fake_label = target_fake_label
|
||||
self.real_label_var = None
|
||||
self.fake_label_var = None
|
||||
self.Tensor = tensor
|
||||
if use_l1:
|
||||
self.loss = nn.L1Loss()
|
||||
else:
|
||||
self.loss = nn.BCEWithLogitsLoss()
|
||||
|
||||
def get_target_tensor(self, input, target_is_real):
|
||||
if target_is_real:
|
||||
create_label = ((self.real_label_var is None) or
|
||||
(self.real_label_var.numel() != input.numel()))
|
||||
if create_label:
|
||||
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
|
||||
self.real_label_var = Variable(real_tensor, requires_grad=False)
|
||||
target_tensor = self.real_label_var
|
||||
else:
|
||||
create_label = ((self.fake_label_var is None) or
|
||||
(self.fake_label_var.numel() != input.numel()))
|
||||
if create_label:
|
||||
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
|
||||
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
|
||||
target_tensor = self.fake_label_var
|
||||
return target_tensor.cuda()
|
||||
|
||||
def __call__(self, input, target_is_real):
|
||||
target_tensor = self.get_target_tensor(input, target_is_real)
|
||||
return self.loss(input, target_tensor)
|
||||
|
||||
|
||||
class DiscLoss(nn.Module):
|
||||
def name(self):
|
||||
return 'DiscLoss'
|
||||
|
||||
def __init__(self):
|
||||
super(DiscLoss, self).__init__()
|
||||
|
||||
self.criterionGAN = GANLoss(use_l1=False)
|
||||
self.fake_AB_pool = ImagePool(50)
|
||||
|
||||
def get_g_loss(self, net, fakeB, realB):
|
||||
# First, G(A) should fake the discriminator
|
||||
pred_fake = net.forward(fakeB)
|
||||
return self.criterionGAN(pred_fake, 1)
|
||||
|
||||
def get_loss(self, net, fakeB, realB):
|
||||
# Fake
|
||||
# stop backprop to the generator by detaching fake_B
|
||||
# Generated Image Disc Output should be close to zero
|
||||
self.pred_fake = net.forward(fakeB.detach())
|
||||
self.loss_D_fake = self.criterionGAN(self.pred_fake, 0)
|
||||
|
||||
# Real
|
||||
self.pred_real = net.forward(realB)
|
||||
self.loss_D_real = self.criterionGAN(self.pred_real, 1)
|
||||
|
||||
# Combined loss
|
||||
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
|
||||
return self.loss_D
|
||||
|
||||
def __call__(self, net, fakeB, realB):
|
||||
return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class RelativisticDiscLoss(nn.Module):
|
||||
def name(self):
|
||||
return 'RelativisticDiscLoss'
|
||||
|
||||
def __init__(self):
|
||||
super(RelativisticDiscLoss, self).__init__()
|
||||
|
||||
self.criterionGAN = GANLoss(use_l1=False)
|
||||
self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
|
||||
self.real_pool = ImagePool(50)
|
||||
|
||||
def get_g_loss(self, net, fakeB, realB):
|
||||
# First, G(A) should fake the discriminator
|
||||
self.pred_fake = net.forward(fakeB)
|
||||
|
||||
# Real
|
||||
self.pred_real = net.forward(realB)
|
||||
errG = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 0) +
|
||||
self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 1)) / 2
|
||||
return errG
|
||||
|
||||
def get_loss(self, net, fakeB, realB):
|
||||
# Fake
|
||||
# stop backprop to the generator by detaching fake_B
|
||||
# Generated Image Disc Output should be close to zero
|
||||
self.fake_B = fakeB.detach()
|
||||
self.real_B = realB
|
||||
self.pred_fake = net.forward(fakeB.detach())
|
||||
self.fake_pool.add(self.pred_fake)
|
||||
|
||||
# Real
|
||||
self.pred_real = net.forward(realB)
|
||||
self.real_pool.add(self.pred_real)
|
||||
|
||||
# Combined loss
|
||||
self.loss_D = (self.criterionGAN(self.pred_real - torch.mean(self.fake_pool.query()), 1) +
|
||||
self.criterionGAN(self.pred_fake - torch.mean(self.real_pool.query()), 0)) / 2
|
||||
return self.loss_D
|
||||
|
||||
def __call__(self, net, fakeB, realB):
|
||||
return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class RelativisticDiscLossLS(nn.Module):
|
||||
def name(self):
|
||||
return 'RelativisticDiscLossLS'
|
||||
|
||||
def __init__(self):
|
||||
super(RelativisticDiscLossLS, self).__init__()
|
||||
|
||||
self.criterionGAN = GANLoss(use_l1=True)
|
||||
self.fake_pool = ImagePool(50) # create image buffer to store previously generated images
|
||||
self.real_pool = ImagePool(50)
|
||||
|
||||
def get_g_loss(self, net, fakeB, realB):
|
||||
# First, G(A) should fake the discriminator
|
||||
self.pred_fake = net.forward(fakeB)
|
||||
|
||||
# Real
|
||||
self.pred_real = net.forward(realB)
|
||||
errG = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) + 1) ** 2) +
|
||||
torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) - 1) ** 2)) / 2
|
||||
return errG
|
||||
|
||||
def get_loss(self, net, fakeB, realB):
|
||||
# Fake
|
||||
# stop backprop to the generator by detaching fake_B
|
||||
# Generated Image Disc Output should be close to zero
|
||||
self.fake_B = fakeB.detach()
|
||||
self.real_B = realB
|
||||
self.pred_fake = net.forward(fakeB.detach())
|
||||
self.fake_pool.add(self.pred_fake)
|
||||
|
||||
# Real
|
||||
self.pred_real = net.forward(realB)
|
||||
self.real_pool.add(self.pred_real)
|
||||
|
||||
# Combined loss
|
||||
self.loss_D = (torch.mean((self.pred_real - torch.mean(self.fake_pool.query()) - 1) ** 2) +
|
||||
torch.mean((self.pred_fake - torch.mean(self.real_pool.query()) + 1) ** 2)) / 2
|
||||
return self.loss_D
|
||||
|
||||
def __call__(self, net, fakeB, realB):
|
||||
return self.get_loss(net, fakeB, realB)
|
||||
|
||||
|
||||
class DiscLossLS(DiscLoss):
|
||||
def name(self):
|
||||
return 'DiscLossLS'
|
||||
|
||||
def __init__(self):
|
||||
super(DiscLossLS, self).__init__()
|
||||
self.criterionGAN = GANLoss(use_l1=True)
|
||||
|
||||
def get_g_loss(self, net, fakeB, realB):
|
||||
return DiscLoss.get_g_loss(self, net, fakeB, realB)
|
||||
|
||||
def get_loss(self, net, fakeB, realB):
|
||||
return DiscLoss.get_loss(self, net, fakeB, realB)
|
||||
|
||||
|
||||
class DiscLossWGANGP(DiscLossLS):
|
||||
def name(self):
|
||||
return 'DiscLossWGAN-GP'
|
||||
|
||||
def __init__(self):
|
||||
super(DiscLossWGANGP, self).__init__()
|
||||
self.LAMBDA = 10
|
||||
|
||||
def get_g_loss(self, net, fakeB, realB):
|
||||
# First, G(A) should fake the discriminator
|
||||
self.D_fake = net.forward(fakeB)
|
||||
return -self.D_fake.mean()
|
||||
|
||||
def calc_gradient_penalty(self, netD, real_data, fake_data):
|
||||
alpha = torch.rand(1, 1)
|
||||
alpha = alpha.expand(real_data.size())
|
||||
alpha = alpha.cuda()
|
||||
|
||||
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
|
||||
|
||||
interpolates = interpolates.cuda()
|
||||
interpolates = Variable(interpolates, requires_grad=True)
|
||||
|
||||
disc_interpolates = netD.forward(interpolates)
|
||||
|
||||
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
|
||||
grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
|
||||
create_graph=True, retain_graph=True, only_inputs=True)[0]
|
||||
|
||||
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.LAMBDA
|
||||
return gradient_penalty
|
||||
|
||||
def get_loss(self, net, fakeB, realB):
|
||||
self.D_fake = net.forward(fakeB.detach())
|
||||
self.D_fake = self.D_fake.mean()
|
||||
|
||||
# Real
|
||||
self.D_real = net.forward(realB)
|
||||
self.D_real = self.D_real.mean()
|
||||
# Combined loss
|
||||
self.loss_D = self.D_fake - self.D_real
|
||||
gradient_penalty = self.calc_gradient_penalty(net, realB.data, fakeB.data)
|
||||
return self.loss_D + gradient_penalty
|
||||
|
||||
|
||||
def get_loss(model):
|
||||
if model['content_loss'] == 'perceptual':
|
||||
content_loss = PerceptualLoss()
|
||||
content_loss.initialize(nn.MSELoss())
|
||||
elif model['content_loss'] == 'l1':
|
||||
content_loss = ContentLoss()
|
||||
content_loss.initialize(nn.L1Loss())
|
||||
else:
|
||||
raise ValueError("ContentLoss [%s] not recognized." % model['content_loss'])
|
||||
|
||||
if model['disc_loss'] == 'wgan-gp':
|
||||
disc_loss = DiscLossWGANGP()
|
||||
elif model['disc_loss'] == 'lsgan':
|
||||
disc_loss = DiscLossLS()
|
||||
elif model['disc_loss'] == 'gan':
|
||||
disc_loss = DiscLoss()
|
||||
elif model['disc_loss'] == 'ragan':
|
||||
disc_loss = RelativisticDiscLoss()
|
||||
elif model['disc_loss'] == 'ragan-ls':
|
||||
disc_loss = RelativisticDiscLossLS()
|
||||
else:
|
||||
raise ValueError("GAN Loss [%s] not recognized." % model['disc_loss'])
|
||||
return content_loss, disc_loss
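# Hedged usage sketch (not part of the original file): get_loss expects a
# mapping with 'content_loss' and 'disc_loss' keys, matching the branches
# above. The particular values below are assumptions chosen so the example
# builds without a GPU or pretrained VGG weights.
if __name__ == "__main__":
    _content_loss, _disc_loss = get_loss({'content_loss': 'l1', 'disc_loss': 'lsgan'})
    print(type(_content_loss).__name__, _disc_loss.name())  # ContentLoss DiscLossLS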
|
@ -0,0 +1,126 @@
|
||||
import torch.nn as nn
import math


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
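# Hedged shape-check sketch (not part of the original file): the FPN backbone in
# fpn_mobilenet.py slices self.features into stages [0:2], [2:4], [4:7], [7:11]
# and [11:16]; chaining those slices shows the per-stage output shapes (and
# channel counts) the lateral convolutions expect. Input size is an assumption.
if __name__ == "__main__":
    import torch
    net = MobileNetV2(n_class=1000)
    x = torch.randn(1, 3, 224, 224)
    for start, stop in [(0, 2), (2, 4), (4, 7), (7, 11), (11, 16)]:
        x = net.features[start:stop](x)
        print((start, stop), tuple(x.shape))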
|
||||
|
@ -0,0 +1,39 @@
|
||||
import numpy as np
import torch.nn as nn
# from skimage.measure import compare_ssim as SSIM
import sys


# if sys.version_info.major > 2:
#     from util.metrics import PSNR
# else:
#     from metrics import PSNR

class DeblurModel(nn.Module):
    def __init__(self):
        super(DeblurModel, self).__init__()

    def get_input(self, data):
        img = data['a']
        inputs = img
        targets = data['b']
        inputs, targets = inputs.cuda(), targets.cuda()
        return inputs, targets

    def tensor2im(self, image_tensor, imtype=np.uint8):
        # Map a [-1, 1] CHW tensor to an HWC uint8 image.
        image_numpy = image_tensor[0].cpu().float().numpy()
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
        return image_numpy.astype(imtype)

    def get_images_and_metrics(self, inp, output, target):
        # Note: PSNR and SSIM come from the imports commented out above;
        # this training-time helper fails unless those imports are restored.
        inp = self.tensor2im(inp)
        fake = self.tensor2im(output.data)
        real = self.tensor2im(target.data)
        psnr = PSNR(fake, real)
        ssim = SSIM(fake, real, multichannel=True)
        vis_img = np.hstack((inp, fake, real))
        return psnr, ssim, vis_img


def get_model(model_config):
    return DeblurModel()
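# Hedged usage sketch (not part of the original file): round-trip a dummy
# [-1, 1] network output through tensor2im. The tensor shape is an assumption
# chosen for the demo.
if __name__ == "__main__":
    import torch
    _model = DeblurModel()
    _fake = torch.tanh(torch.randn(1, 3, 256, 256))  # stand-in for a generator output
    _img = _model.tensor2im(_fake)
    print(_img.shape, _img.dtype)  # (256, 256, 3) uint8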
|
@ -0,0 +1,339 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
import functools
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
import sys
|
||||
if sys.version_info.major>2:
|
||||
from models.fpn_mobilenet import FPNMobileNet
|
||||
from models.fpn_inception import FPNInception
|
||||
# from fpn_inception_simple import FPNInceptionSimple
|
||||
from models.unet_seresnext import UNetSEResNext
|
||||
from models.fpn_densenet import FPNDense
|
||||
else:
|
||||
from fpn_mobilenet import FPNMobileNet
|
||||
from fpn_inception import FPNInception
|
||||
# from fpn_inception_simple import FPNInceptionSimple
|
||||
from unet_seresnext import UNetSEResNext
|
||||
from fpn_densenet import FPNDense
|
||||
|
||||
###############################################################################
|
||||
# Functions
|
||||
###############################################################################
|
||||
|
||||
|
||||
def get_norm_layer(norm_type='instance'):
|
||||
if norm_type == 'batch':
|
||||
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
||||
elif norm_type == 'instance':
|
||||
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
|
||||
else:
|
||||
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||
return norm_layer
|
||||
|
||||
##############################################################################
|
||||
# Classes
|
||||
##############################################################################
|
||||
|
||||
|
||||
# Defines the generator that consists of Resnet blocks between a few
|
||||
# downsampling/upsampling operations.
|
||||
# Code and idea originally from Justin Johnson's architecture.
|
||||
# https://github.com/jcjohnson/fast-neural-style/
|
||||
class ResnetGenerator(nn.Module):
|
||||
def __init__(self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, use_parallel=True, learn_residual=True, padding_type='reflect'):
|
||||
assert(n_blocks >= 0)
|
||||
super(ResnetGenerator, self).__init__()
|
||||
self.input_nc = input_nc
|
||||
self.output_nc = output_nc
|
||||
self.ngf = ngf
|
||||
self.use_parallel = use_parallel
|
||||
self.learn_residual = learn_residual
|
||||
if type(norm_layer) == functools.partial:
|
||||
use_bias = norm_layer.func == nn.InstanceNorm2d
|
||||
else:
|
||||
use_bias = norm_layer == nn.InstanceNorm2d
|
||||
|
||||
model = [nn.ReflectionPad2d(3),
|
||||
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
|
||||
bias=use_bias),
|
||||
norm_layer(ngf),
|
||||
nn.ReLU(True)]
|
||||
|
||||
n_downsampling = 2
|
||||
for i in range(n_downsampling):
|
||||
mult = 2**i
|
||||
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
|
||||
stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(ngf * mult * 2),
|
||||
nn.ReLU(True)]
|
||||
|
||||
mult = 2**n_downsampling
|
||||
for i in range(n_blocks):
|
||||
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
||||
|
||||
for i in range(n_downsampling):
|
||||
mult = 2**(n_downsampling - i)
|
||||
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
||||
kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1,
|
||||
bias=use_bias),
|
||||
norm_layer(int(ngf * mult / 2)),
|
||||
nn.ReLU(True)]
|
||||
model += [nn.ReflectionPad2d(3)]
|
||||
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
||||
model += [nn.Tanh()]
|
||||
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.model(input)
|
||||
if self.learn_residual:
|
||||
output = input + output
|
||||
output = torch.clamp(output,min = -1,max = 1)
|
||||
return output
|
||||
|
||||
|
||||
# Define a resnet block
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
||||
super(ResnetBlock, self).__init__()
|
||||
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
||||
|
||||
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
||||
conv_block = []
|
||||
p = 0
|
||||
if padding_type == 'reflect':
|
||||
conv_block += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv_block += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
|
||||
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
||||
norm_layer(dim),
|
||||
nn.ReLU(True)]
|
||||
if use_dropout:
|
||||
conv_block += [nn.Dropout(0.5)]
|
||||
|
||||
p = 0
|
||||
if padding_type == 'reflect':
|
||||
conv_block += [nn.ReflectionPad2d(1)]
|
||||
elif padding_type == 'replicate':
|
||||
conv_block += [nn.ReplicationPad2d(1)]
|
||||
elif padding_type == 'zero':
|
||||
p = 1
|
||||
else:
|
||||
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
||||
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
||||
norm_layer(dim)]
|
||||
|
||||
return nn.Sequential(*conv_block)
|
||||
|
||||
def forward(self, x):
|
||||
out = x + self.conv_block(x)
|
||||
return out
|
||||
|
||||
|
||||
class DicsriminatorTail(nn.Module):
|
||||
def __init__(self, nf_mult, n_layers, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
|
||||
super(DicsriminatorTail, self).__init__()
|
||||
self.use_parallel = use_parallel
|
||||
if type(norm_layer) == functools.partial:
|
||||
use_bias = norm_layer.func == nn.InstanceNorm2d
|
||||
else:
|
||||
use_bias = norm_layer == nn.InstanceNorm2d
|
||||
|
||||
kw = 4
|
||||
padw = int(np.ceil((kw-1)/2))
|
||||
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n_layers, 8)
|
||||
sequence = [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
||||
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
|
||||
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
def __init__(self, input_nc=3, ndf=64, norm_layer=nn.BatchNorm2d, use_parallel=True):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.use_parallel = use_parallel
|
||||
if type(norm_layer) == functools.partial:
|
||||
use_bias = norm_layer.func == nn.InstanceNorm2d
|
||||
else:
|
||||
use_bias = norm_layer == nn.InstanceNorm2d
|
||||
|
||||
kw = 4
|
||||
padw = int(np.ceil((kw-1)/2))
|
||||
sequence = [
|
||||
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
nf_mult = 1
|
||||
for n in range(1, 3):
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
||||
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
self.scale_one = nn.Sequential(*sequence)
|
||||
self.first_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=3)
|
||||
nf_mult_prev = 4
|
||||
nf_mult = 8
|
||||
|
||||
self.scale_two = nn.Sequential(
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
||||
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True))
|
||||
nf_mult_prev = nf_mult
|
||||
self.second_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=4)
|
||||
self.scale_three = nn.Sequential(
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True))
|
||||
self.third_tail = DicsriminatorTail(nf_mult=nf_mult, n_layers=5)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.scale_one(input)
|
||||
x_1 = self.first_tail(x)
|
||||
x = self.scale_two(x)
|
||||
x_2 = self.second_tail(x)
|
||||
x = self.scale_three(x)
|
||||
x = self.third_tail(x)
|
||||
return [x_1, x_2, x]
|
||||
|
||||
|
||||
# Defines the PatchGAN discriminator with the specified arguments.
|
||||
class NLayerDiscriminator(nn.Module):
|
||||
def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, use_parallel=True):
|
||||
super(NLayerDiscriminator, self).__init__()
|
||||
self.use_parallel = use_parallel
|
||||
if type(norm_layer) == functools.partial:
|
||||
use_bias = norm_layer.func == nn.InstanceNorm2d
|
||||
else:
|
||||
use_bias = norm_layer == nn.InstanceNorm2d
|
||||
|
||||
kw = 4
|
||||
padw = int(np.ceil((kw-1)/2))
|
||||
sequence = [
|
||||
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
nf_mult = 1
|
||||
for n in range(1, n_layers):
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
||||
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
nf_mult_prev = nf_mult
|
||||
nf_mult = min(2**n_layers, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
|
||||
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
||||
norm_layer(ndf * nf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
|
||||
|
||||
if use_sigmoid:
|
||||
sequence += [nn.Sigmoid()]
|
||||
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, input):
|
||||
return self.model(input)
|
||||
|
||||
|
||||
def get_fullD(model_config):
|
||||
model_d = NLayerDiscriminator(n_layers=5,
|
||||
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
|
||||
use_sigmoid=False)
|
||||
return model_d
|
||||
|
||||
|
||||
def get_generator(model_config):
|
||||
generator_name = model_config['g_name']
|
||||
if generator_name == 'resnet':
|
||||
model_g = ResnetGenerator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
|
||||
use_dropout=model_config['dropout'],
|
||||
n_blocks=model_config['blocks'],
|
||||
learn_residual=model_config['learn_residual'])
|
||||
elif generator_name == 'fpn_mobilenet':
|
||||
model_g = FPNMobileNet(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
|
||||
elif generator_name == 'fpn_inception':
|
||||
# model_g = FPNInception(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
|
||||
# torch.save(model_g, 'mymodel.pth')
|
||||
model_g = torch.load('mymodel.pth')
|
||||
elif generator_name == 'fpn_inception_simple':
|
||||
model_g = FPNInceptionSimple(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))  # requires restoring the FPNInceptionSimple import, which is commented out above
|
||||
elif generator_name == 'fpn_dense':
|
||||
model_g = FPNDense()
|
||||
elif generator_name == 'unet_seresnext':
|
||||
model_g = UNetSEResNext(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
|
||||
pretrained=model_config['pretrained'])
|
||||
else:
|
||||
raise ValueError("Generator Network [%s] not recognized." % generator_name)
|
||||
|
||||
return nn.DataParallel(model_g)
|
||||
|
||||
def get_generator_new(weights_path):
|
||||
|
||||
model_g = torch.load(weights_path+'mymodel.pth')
|
||||
|
||||
return nn.DataParallel(model_g)
|
||||
|
||||
def get_discriminator(model_config):
|
||||
discriminator_name = model_config['d_name']
|
||||
if discriminator_name == 'no_gan':
|
||||
model_d = None
|
||||
elif discriminator_name == 'patch_gan':
|
||||
model_d = NLayerDiscriminator(n_layers=model_config['d_layers'],
|
||||
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
|
||||
use_sigmoid=False)
|
||||
model_d = nn.DataParallel(model_d)
|
||||
elif discriminator_name == 'double_gan':
|
||||
patch_gan = NLayerDiscriminator(n_layers=model_config['d_layers'],
|
||||
norm_layer=get_norm_layer(norm_type=model_config['norm_layer']),
|
||||
use_sigmoid=False)
|
||||
patch_gan = nn.DataParallel(patch_gan)
|
||||
full_gan = get_fullD(model_config)
|
||||
full_gan = nn.DataParallel(full_gan)
|
||||
model_d = {'patch': patch_gan,
|
||||
'full': full_gan}
|
||||
elif discriminator_name == 'multi_scale':
|
||||
model_d = MultiScaleDiscriminator(norm_layer=get_norm_layer(norm_type=model_config['norm_layer']))
|
||||
model_d = nn.DataParallel(model_d)
|
||||
else:
|
||||
raise ValueError("Discriminator Network [%s] not recognized." % discriminator_name)
|
||||
|
||||
return model_d
|
||||
|
||||
|
||||
def get_nets(model_config):
|
||||
return get_generator(model_config), get_discriminator(model_config)
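# Hedged usage sketch (not part of the original file): a minimal config with the
# keys this factory reads. The values are assumptions picked so that no
# pretrained checkpoint file (e.g. the 'mymodel.pth' used by the fpn_inception
# branch) is required to build the networks.
if __name__ == "__main__":
    _cfg = {'g_name': 'resnet', 'norm_layer': 'instance', 'dropout': False,
            'blocks': 6, 'learn_residual': True, 'd_name': 'no_gan'}
    _netG, _netD = get_nets(_cfg)
    print(type(_netG).__name__, _netD)  # DataParallel None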
|
@ -0,0 +1,430 @@
|
||||
from __future__ import print_function, division, absolute_import
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.utils import model_zoo
|
||||
|
||||
__all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
|
||||
'se_resnext50_32x4d', 'se_resnext101_32x4d']
|
||||
|
||||
pretrained_settings = {
|
||||
'senet154': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet50': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet101': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet152': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnext50_32x4d': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnext101_32x4d': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
|
||||
padding=0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
|
||||
padding=0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""
|
||||
Base class for bottlenecks that implements `forward()` method.
|
||||
"""
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out = self.se_module(out) + residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SEBottleneck(Bottleneck):
|
||||
"""
|
||||
Bottleneck for SENet154.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None):
|
||||
super(SEBottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1)
|
||||
self.bn1 = nn.InstanceNorm2d(planes * 2, affine=False)
|
||||
self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
|
||||
stride=stride, padding=1, groups=groups)
|
||||
self.bn2 = nn.InstanceNorm2d(planes * 4, affine=False)
|
||||
self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1)
|
||||
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SEResNetBottleneck(Bottleneck):
|
||||
"""
|
||||
ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
|
||||
implementation and uses `stride=stride` in `conv1` and not in `conv2`
|
||||
(the latter is used in the torchvision implementation of ResNet).
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None):
|
||||
super(SEResNetBottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1,
|
||||
stride=stride)
|
||||
self.bn1 = nn.InstanceNorm2d(planes, affine=False)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
|
||||
groups=groups)
|
||||
self.bn2 = nn.InstanceNorm2d(planes, affine=False)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1)
|
||||
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SEResNeXtBottleneck(Bottleneck):
|
||||
"""
|
||||
ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None, base_width=4):
|
||||
super(SEResNeXtBottleneck, self).__init__()
|
||||
width = math.floor(planes * (base_width / 64)) * groups
|
||||
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1,
|
||||
stride=1)
|
||||
self.bn1 = nn.InstanceNorm2d(width, affine=False)
|
||||
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
|
||||
padding=1, groups=groups)
|
||||
self.bn2 = nn.InstanceNorm2d(width, affine=False)
|
||||
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1)
|
||||
self.bn3 = nn.InstanceNorm2d(planes * 4, affine=False)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SENet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
|
||||
inplanes=128, input_3x3=True, downsample_kernel_size=3,
|
||||
downsample_padding=1, num_classes=1000):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
block (nn.Module): Bottleneck class.
|
||||
- For SENet154: SEBottleneck
|
||||
- For SE-ResNet models: SEResNetBottleneck
|
||||
- For SE-ResNeXt models: SEResNeXtBottleneck
|
||||
layers (list of ints): Number of residual blocks for 4 layers of the
|
||||
network (layer1...layer4).
|
||||
groups (int): Number of groups for the 3x3 convolution in each
|
||||
bottleneck block.
|
||||
- For SENet154: 64
|
||||
- For SE-ResNet models: 1
|
||||
- For SE-ResNeXt models: 32
|
||||
reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
|
||||
- For all models: 16
|
||||
dropout_p (float or None): Drop probability for the Dropout layer.
|
||||
If `None` the Dropout layer is not used.
|
||||
- For SENet154: 0.2
|
||||
- For SE-ResNet models: None
|
||||
- For SE-ResNeXt models: None
|
||||
inplanes (int): Number of input channels for layer1.
|
||||
- For SENet154: 128
|
||||
- For SE-ResNet models: 64
|
||||
- For SE-ResNeXt models: 64
|
||||
input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
|
||||
a single 7x7 convolution in layer0.
|
||||
- For SENet154: True
|
||||
- For SE-ResNet models: False
|
||||
- For SE-ResNeXt models: False
|
||||
downsample_kernel_size (int): Kernel size for downsampling convolutions
|
||||
in layer2, layer3 and layer4.
|
||||
- For SENet154: 3
|
||||
- For SE-ResNet models: 1
|
||||
- For SE-ResNeXt models: 1
|
||||
downsample_padding (int): Padding for downsampling convolutions in
|
||||
layer2, layer3 and layer4.
|
||||
- For SENet154: 1
|
||||
- For SE-ResNet models: 0
|
||||
- For SE-ResNeXt models: 0
|
||||
num_classes (int): Number of outputs in `last_linear` layer.
|
||||
- For all models: 1000
|
||||
"""
|
||||
super(SENet, self).__init__()
|
||||
self.inplanes = inplanes
|
||||
if input_3x3:
|
||||
layer0_modules = [
|
||||
('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1)),
|
||||
('bn1', nn.InstanceNorm2d(64, affine=False)),
|
||||
('relu1', nn.ReLU(inplace=True)),
|
||||
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1)),
|
||||
('bn2', nn.InstanceNorm2d(64, affine=False)),
|
||||
('relu2', nn.ReLU(inplace=True)),
|
||||
('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1)),
|
||||
('bn3', nn.InstanceNorm2d(inplanes, affine=False)),
|
||||
('relu3', nn.ReLU(inplace=True)),
|
||||
]
|
||||
else:
|
||||
layer0_modules = [
|
||||
('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
|
||||
padding=3)),
|
||||
('bn1', nn.InstanceNorm2d(inplanes, affine=False)),
|
||||
('relu1', nn.ReLU(inplace=True)),
|
||||
]
|
||||
# To preserve compatibility with Caffe weights `ceil_mode=True`
|
||||
# is used instead of `padding=1`.
|
||||
layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
|
||||
ceil_mode=True)))
|
||||
self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
|
||||
self.layer1 = self._make_layer(
|
||||
block,
|
||||
planes=64,
|
||||
blocks=layers[0],
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=1,
|
||||
downsample_padding=0
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
block,
|
||||
planes=128,
|
||||
blocks=layers[1],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
block,
|
||||
planes=256,
|
||||
blocks=layers[2],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.layer4 = self._make_layer(
|
||||
block,
|
||||
planes=512,
|
||||
blocks=layers[3],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.avg_pool = nn.AvgPool2d(7, stride=1)
|
||||
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
||||
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
||||
downsample_kernel_size=1, downsample_padding=0):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=downsample_kernel_size, stride=stride,
|
||||
padding=downsample_padding),
|
||||
nn.InstanceNorm2d(planes * block.expansion, affine=False),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, groups, reduction, stride,
|
||||
downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups, reduction))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def features(self, x):
|
||||
x = self.layer0(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
return x
|
||||
|
||||
def logits(self, x):
|
||||
x = self.avg_pool(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.last_linear(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.logits(x)
|
||||
return x
|
||||
|
||||
|
||||
def initialize_pretrained_model(model, num_classes, settings):
|
||||
assert num_classes == settings['num_classes'], \
|
||||
'num_classes should be {}, but is {}'.format(
|
||||
settings['num_classes'], num_classes)
|
||||
model.load_state_dict(model_zoo.load_url(settings['url']))
|
||||
model.input_space = settings['input_space']
|
||||
model.input_size = settings['input_size']
|
||||
model.input_range = settings['input_range']
|
||||
model.mean = settings['mean']
|
||||
model.std = settings['std']
|
||||
|
||||
|
||||
def senet154(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
|
||||
dropout_p=0.2, num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['senet154'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet50(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet50'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet101(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet101'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet152(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet152'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):  # note: unlike the other constructors here, this one never loads the pretrained weights
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
@ -0,0 +1,158 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.parallel
|
||||
import torch.optim
|
||||
import torch.utils.data
|
||||
from torch.nn import Sequential
|
||||
from collections import OrderedDict
|
||||
import torchvision
|
||||
from torch.nn import functional as F
|
||||
import sys
|
||||
|
||||
if sys.version_info.major > 2:
|
||||
from models.senet import se_resnext50_32x4d
|
||||
else:
|
||||
from senet import se_resnext50_32x4d
|
||||
|
||||
|
||||
def conv3x3(in_, out):
|
||||
return nn.Conv2d(in_, out, 3, padding=1)
|
||||
|
||||
|
||||
class ConvRelu(nn.Module):
|
||||
def __init__(self, in_, out):
|
||||
super(ConvRelu, self).__init__()
|
||||
self.conv = conv3x3(in_, out)
|
||||
self.activation = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
class UNetSEResNext(nn.Module):
|
||||
|
||||
def __init__(self, num_classes=3, num_filters=32,
|
||||
pretrained=True, is_deconv=True):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
pretrain = 'imagenet' if pretrained is True else None
|
||||
self.encoder = se_resnext50_32x4d(num_classes=1000, pretrained=pretrain)
|
||||
bottom_channel_nr = 2048
|
||||
|
||||
self.conv1 = self.encoder.layer0
|
||||
# self.se_e1 = SCSEBlock(64)
|
||||
self.conv2 = self.encoder.layer1
|
||||
# self.se_e2 = SCSEBlock(64 * 4)
|
||||
self.conv3 = self.encoder.layer2
|
||||
# self.se_e3 = SCSEBlock(128 * 4)
|
||||
self.conv4 = self.encoder.layer3
|
||||
# self.se_e4 = SCSEBlock(256 * 4)
|
||||
self.conv5 = self.encoder.layer4
|
||||
# self.se_e5 = SCSEBlock(512 * 4)
|
||||
|
||||
self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
|
||||
|
||||
self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 2, is_deconv)
|
||||
# self.se_d5 = SCSEBlock(num_filters * 2)
|
||||
self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 2, num_filters * 8, num_filters * 2, is_deconv)
|
||||
# self.se_d4 = SCSEBlock(num_filters * 2)
|
||||
self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 2, num_filters * 4, num_filters * 2, is_deconv)
|
||||
# self.se_d3 = SCSEBlock(num_filters * 2)
|
||||
self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2, num_filters * 2, is_deconv)
|
||||
# self.se_d2 = SCSEBlock(num_filters * 2)
|
||||
self.dec1 = DecoderBlockV(num_filters * 2, num_filters, num_filters * 2, is_deconv)
|
||||
# self.se_d1 = SCSEBlock(num_filters * 2)
|
||||
self.dec0 = ConvRelu(num_filters * 10, num_filters * 2)
|
||||
self.final = nn.Conv2d(num_filters * 2, num_classes, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
conv1 = self.conv1(x)
|
||||
# conv1 = self.se_e1(conv1)
|
||||
conv2 = self.conv2(conv1)
|
||||
# conv2 = self.se_e2(conv2)
|
||||
conv3 = self.conv3(conv2)
|
||||
# conv3 = self.se_e3(conv3)
|
||||
conv4 = self.conv4(conv3)
|
||||
# conv4 = self.se_e4(conv4)
|
||||
conv5 = self.conv5(conv4)
|
||||
# conv5 = self.se_e5(conv5)
|
||||
|
||||
center = self.center(conv5)
|
||||
dec5 = self.dec5(torch.cat([center, conv5], 1))
|
||||
# dec5 = self.se_d5(dec5)
|
||||
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
|
||||
# dec4 = self.se_d4(dec4)
|
||||
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
|
||||
# dec3 = self.se_d3(dec3)
|
||||
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
|
||||
# dec2 = self.se_d2(dec2)
|
||||
dec1 = self.dec1(dec2)
|
||||
# dec1 = self.se_d1(dec1)
|
||||
|
||||
f = torch.cat((
|
||||
dec1,
|
||||
F.upsample(dec2, scale_factor=2, mode='bilinear', align_corners=False),
|
||||
F.upsample(dec3, scale_factor=4, mode='bilinear', align_corners=False),
|
||||
F.upsample(dec4, scale_factor=8, mode='bilinear', align_corners=False),
|
||||
F.upsample(dec5, scale_factor=16, mode='bilinear', align_corners=False),
|
||||
), 1)
|
||||
|
||||
dec0 = self.dec0(f)
|
||||
|
||||
return self.final(dec0)
|
||||
|
||||
|
||||
class DecoderBlockV(nn.Module):
|
||||
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
|
||||
super(DecoderBlockV, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if is_deconv:
|
||||
self.block = nn.Sequential(
|
||||
ConvRelu(in_channels, middle_channels),
|
||||
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
|
||||
padding=1),
|
||||
nn.InstanceNorm2d(out_channels, affine=False),
|
||||
nn.ReLU(inplace=True)
|
||||
|
||||
)
|
||||
else:
|
||||
self.block = nn.Sequential(
|
||||
nn.Upsample(scale_factor=2, mode='bilinear'),
|
||||
ConvRelu(in_channels, middle_channels),
|
||||
ConvRelu(middle_channels, out_channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
||||
|
||||
|
||||
class DecoderCenter(nn.Module):
|
||||
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
|
||||
super(DecoderCenter, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if is_deconv:
|
||||
"""
|
||||
Paramaters for Deconvolution were chosen to avoid artifacts, following
|
||||
link https://distill.pub/2016/deconv-checkerboard/
|
||||
"""
|
||||
|
||||
self.block = nn.Sequential(
|
||||
ConvRelu(in_channels, middle_channels),
|
||||
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
|
||||
padding=1),
|
||||
nn.InstanceNorm2d(out_channels, affine=False),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
else:
|
||||
self.block = nn.Sequential(
|
||||
ConvRelu(in_channels, middle_channels),
|
||||
ConvRelu(middle_channels, out_channels)
|
||||
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.block(x)
|
@ -0,0 +1,108 @@
|
||||
import os
|
||||
from glob import glob
|
||||
# from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from fire import Fire
|
||||
from tqdm import tqdm
|
||||
|
||||
from aug import get_normalize
|
||||
from models.networks import get_generator
|
||||
|
||||
|
||||
class Predictor:
|
||||
def __init__(self, weights_path, model_name=''):
|
||||
with open('config/config.yaml') as cfg:
|
||||
config = yaml.load(cfg)
|
||||
model = get_generator(model_name or config['model'])
|
||||
model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
|
||||
if torch.cuda.is_available():
|
||||
self.model = model.cuda()
|
||||
else:
|
||||
self.model = model
|
||||
self.model.train(True)
|
||||
# GAN inference should be in train mode to use actual stats in norm layers,
|
||||
# it's not a bug
|
||||
self.normalize_fn = get_normalize()
|
||||
|
||||
@staticmethod
|
||||
def _array_to_batch(x):
|
||||
x = np.transpose(x, (2, 0, 1))
|
||||
x = np.expand_dims(x, 0)
|
||||
return torch.from_numpy(x)
|
||||
|
||||
def _preprocess(self, x, mask):
|
||||
x, _ = self.normalize_fn(x, x)
|
||||
if mask is None:
|
||||
mask = np.ones_like(x, dtype=np.float32)
|
||||
else:
|
||||
mask = np.round(mask.astype('float32') / 255)
|
||||
|
||||
h, w, _ = x.shape
|
||||
block_size = 32
|
||||
min_height = (h // block_size + 1) * block_size
|
||||
min_width = (w // block_size + 1) * block_size
|
||||
|
||||
pad_params = {'mode': 'constant',
|
||||
'constant_values': 0,
|
||||
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
|
||||
}
|
||||
x = np.pad(x, **pad_params)
|
||||
mask = np.pad(mask, **pad_params)
|
||||
|
||||
return map(self._array_to_batch, (x, mask)), h, w
|
||||
|
||||
@staticmethod
|
||||
def _postprocess(x):
|
||||
x, = x
|
||||
x = x.detach().cpu().float().numpy()
|
||||
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
return x.astype('uint8')
|
||||
|
||||
def __call__(self, img, mask, ignore_mask=True):
|
||||
(img, mask), h, w = self._preprocess(img, mask)
|
||||
with torch.no_grad():
|
||||
if torch.cuda.is_available():
|
||||
inputs = [img.cuda()]
|
||||
else:
|
||||
inputs = [img]
|
||||
if not ignore_mask:
|
||||
inputs += [mask]
|
||||
pred = self.model(*inputs)
|
||||
return self._postprocess(pred)[:h, :w, :]
|
||||
|
||||
def sorted_glob(pattern):
|
||||
return sorted(glob(pattern))
|
||||
|
||||
def main(img_pattern,
|
||||
mask_pattern = None,
|
||||
weights_path='best_fpn.h5',
|
||||
out_dir='submit/',
|
||||
side_by_side = False):
|
||||
|
||||
|
||||
imgs = sorted_glob(img_pattern)
|
||||
masks = sorted_glob(mask_pattern) if mask_pattern is not None else [None for _ in imgs]
|
||||
pairs = zip(imgs, masks)
|
||||
names = sorted([os.path.basename(x) for x in glob(img_pattern)])
|
||||
predictor = Predictor(weights_path=weights_path)
|
||||
|
||||
# os.makedirs(out_dir)
|
||||
for name, pair in tqdm(zip(names, pairs), total=len(names)):
|
||||
f_img, f_mask = pair
|
||||
img, mask = map(cv2.imread, (f_img, f_mask))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
pred = predictor(img, mask)
|
||||
if side_by_side:
|
||||
pred = np.hstack((img, pred))
|
||||
pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite(os.path.join(out_dir, name),
|
||||
pred)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
Fire(main)
|
@ -0,0 +1,86 @@
|
||||
from models.networks import get_generator_new
|
||||
# from aug import get_normalize
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
config = {'project': 'deblur_gan', 'warmup_num': 3, 'optimizer': {'lr': 0.0001, 'name': 'adam'},
|
||||
'val': {'preload': False, 'bounds': [0.9, 1], 'crop': 'center', 'files_b': '/datasets/my_dataset/**/*.jpg',
|
||||
'files_a': '/datasets/my_dataset/**/*.jpg', 'scope': 'geometric',
|
||||
'corrupt': [{'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5},
|
||||
{'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90}, {'name': 'motion_blur'},
|
||||
{'name': 'median_blur'}, {'name': 'gamma'}, {'name': 'rgb_shift'}, {'name': 'hsv_shift'},
|
||||
{'name': 'sharpen'}], 'preload_size': 0, 'size': 256}, 'val_batches_per_epoch': 100,
|
||||
'num_epochs': 200, 'batch_size': 1, 'experiment_desc': 'fpn', 'train_batches_per_epoch': 1000,
|
||||
'train': {'preload': False, 'bounds': [0, 0.9], 'crop': 'random', 'files_b': '/datasets/my_dataset/**/*.jpg',
|
||||
'files_a': '/datasets/my_dataset/**/*.jpg', 'preload_size': 0,
|
||||
'corrupt': [{'num_holes': 3, 'max_w_size': 25, 'max_h_size': 25, 'name': 'cutout', 'prob': 0.5},
|
||||
{'quality_lower': 70, 'name': 'jpeg', 'quality_upper': 90}, {'name': 'motion_blur'},
|
||||
{'name': 'median_blur'}, {'name': 'gamma'}, {'name': 'rgb_shift'},
|
||||
{'name': 'hsv_shift'}, {'name': 'sharpen'}], 'scope': 'geometric', 'size': 256},
|
||||
'scheduler': {'min_lr': 1e-07, 'name': 'linear', 'start_epoch': 50}, 'image_size': [256, 256],
|
||||
'phase': 'train',
|
||||
'model': {'d_name': 'double_gan', 'disc_loss': 'wgan-gp', 'blocks': 9, 'content_loss': 'perceptual',
|
||||
'adv_lambda': 0.001, 'dropout': True, 'g_name': 'fpn_inception', 'd_layers': 3,
|
||||
'learn_residual': True, 'norm_layer': 'instance'}}
|
||||
|
||||
|
||||
class Predictor:
|
||||
def __init__(self, weights_path, model_name='',cf=False):
|
||||
# model = get_generator(model_name or config['model'])
|
||||
model = get_generator_new(weights_path[0:-11])
|
||||
model.load_state_dict(torch.load(weights_path, map_location=lambda storage, loc: storage)['model'])
|
||||
if torch.cuda.is_available() and not cf:
|
||||
self.model = model.cuda()
|
||||
else:
|
||||
self.model = model
|
||||
self.model.train(True)
|
||||
# GAN inference should be in train mode to use actual stats in norm layers,
|
||||
# it's not a bug
|
||||
# self.normalize_fn = get_normalize()
|
||||
|
||||
@staticmethod
|
||||
def _array_to_batch(x):
|
||||
x = np.transpose(x, (2, 0, 1))
|
||||
x = np.expand_dims(x, 0)
|
||||
return torch.from_numpy(x)
|
||||
|
||||
def _preprocess(self, x, mask):
|
||||
# x, _ = self.normalize_fn(x, x)
|
||||
x = ((x.astype(np.float32) / 255) - 0.5) / 0.5
|
||||
if mask is None:
|
||||
mask = np.ones_like(x, dtype=np.float32)
|
||||
else:
|
||||
mask = np.round(mask.astype('float32') / 255)
|
||||
|
||||
h, w, _ = x.shape
|
||||
block_size = 32
|
||||
min_height = (h // block_size + 1) * block_size
|
||||
min_width = (w // block_size + 1) * block_size
|
||||
|
||||
pad_params = {'mode': 'constant',
|
||||
'constant_values': 0,
|
||||
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
|
||||
}
|
||||
x = np.pad(x, **pad_params)
|
||||
mask = np.pad(mask, **pad_params)
|
||||
|
||||
return map(self._array_to_batch, (x, mask)), h, w
|
||||
|
||||
@staticmethod
|
||||
def _postprocess(x):
|
||||
x, = x
|
||||
x = x.detach().cpu().float().numpy()
|
||||
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
return x.astype('uint8')
|
||||
|
||||
def __call__(self, img, mask, ignore_mask=True,cf=False):
|
||||
(img, mask), h, w = self._preprocess(img, mask)
|
||||
with torch.no_grad():
|
||||
if torch.cuda.is_available() and not cf:
|
||||
inputs = [img.cuda()]
|
||||
else:
|
||||
inputs = [img]
|
||||
if not ignore_mask:
|
||||
inputs += [mask]
|
||||
pred = self.model(*inputs)
|
||||
return self._postprocess(pred)[:h, :w, :]
|
@ -0,0 +1,59 @@
|
||||
import math
|
||||
|
||||
from torch.optim import lr_scheduler
|
||||
|
||||
|
||||
class WarmRestart(lr_scheduler.CosineAnnealingLR):
|
||||
"""This class implements Stochastic Gradient Descent with Warm Restarts(SGDR): https://arxiv.org/abs/1608.03983.
|
||||
|
||||
Set the learning rate of each parameter group using a cosine annealing schedule, When last_epoch=-1, sets initial lr as lr.
|
||||
This can't support scheduler.step(epoch). please keep epoch=None.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, T_max=30, T_mult=1, eta_min=0, last_epoch=-1):
|
||||
"""implements SGDR
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
T_max : int
|
||||
Maximum number of epochs.
|
||||
T_mult : int
|
||||
Multiplicative factor of T_max.
|
||||
eta_min : int
|
||||
Minimum learning rate. Default: 0.
|
||||
last_epoch : int
|
||||
The index of last epoch. Default: -1.
|
||||
"""
|
||||
self.T_mult = T_mult
|
||||
super().__init__(optimizer, T_max, eta_min, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch == self.T_max:
|
||||
self.last_epoch = 0
|
||||
self.T_max *= self.T_mult
|
||||
return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 for
|
||||
base_lr in self.base_lrs]
|
||||
|
||||
|
||||
class LinearDecay(lr_scheduler._LRScheduler):
|
||||
"""This class implements LinearDecay
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, num_epochs, start_epoch=0, min_lr=0, last_epoch=-1):
|
||||
"""implements LinearDecay
|
||||
|
||||
Parameters:
|
||||
----------
|
||||
|
||||
"""
|
||||
self.num_epochs = num_epochs
|
||||
self.start_epoch = start_epoch
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch < self.start_epoch:
|
||||
return self.base_lrs
|
||||
return [base_lr - ((base_lr - self.min_lr) / self.num_epochs) * (self.last_epoch - self.start_epoch) for
|
||||
base_lr in self.base_lrs]
|
@ -0,0 +1,20 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from aug import get_transforms
|
||||
|
||||
|
||||
class AugTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
def make_images():
|
||||
img = (np.random.rand(100, 100, 3) * 255).astype('uint8')
|
||||
return img.copy(), img.copy()
|
||||
|
||||
def test_aug(self):
|
||||
for scope in ('strong', 'weak'):
|
||||
for crop in ('random', 'center'):
|
||||
aug_pipeline = get_transforms(80, scope=scope, crop=crop)
|
||||
a, b = self.make_images()
|
||||
a, b = aug_pipeline(a, b)
|
||||
np.testing.assert_allclose(a, b)
|
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import unittest
|
||||
from shutil import rmtree
|
||||
from tempfile import mkdtemp
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from dataset import PairedDataset
|
||||
|
||||
|
||||
def make_img():
|
||||
return (np.random.rand(100, 100, 3) * 255).astype('uint8')
|
||||
|
||||
|
||||
class AugTest(unittest.TestCase):
|
||||
tmp_dir = mkdtemp()
|
||||
raw = os.path.join(tmp_dir, 'raw')
|
||||
gt = os.path.join(tmp_dir, 'gt')
|
||||
|
||||
def setUp(self):
|
||||
for d in (self.raw, self.gt):
|
||||
os.makedirs(d)
|
||||
|
||||
for i in range(5):
|
||||
for d in (self.raw, self.gt):
|
||||
img = make_img()
|
||||
cv2.imwrite(os.path.join(d, f'{i}.png'), img)
|
||||
|
||||
def tearDown(self):
|
||||
rmtree(self.tmp_dir)
|
||||
|
||||
def dataset_gen(self, equal=True):
|
||||
base_config = {'files_a': os.path.join(self.raw, '*.png'),
|
||||
'files_b': os.path.join(self.raw if equal else self.gt, '*.png'),
|
||||
'size': 32,
|
||||
}
|
||||
for b in ([0, 1], [0, 0.9]):
|
||||
for scope in ('strong', 'weak'):
|
||||
for crop in ('random', 'center'):
|
||||
for preload in (0, 1):
|
||||
for preload_size in (0, 64):
|
||||
config = base_config.copy()
|
||||
config['bounds'] = b
|
||||
config['scope'] = scope
|
||||
config['crop'] = crop
|
||||
config['preload'] = preload
|
||||
config['preload_size'] = preload_size
|
||||
config['verbose'] = False
|
||||
dataset = PairedDataset.from_config(config)
|
||||
yield dataset
|
||||
|
||||
def test_equal_datasets(self):
|
||||
for dataset in self.dataset_gen(equal=True):
|
||||
dataloader = DataLoader(dataset=dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
dataloader = iter(dataloader)
|
||||
batch = next(dataloader)
|
||||
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
|
||||
|
||||
np.testing.assert_allclose(a, b)
|
||||
|
||||
def test_datasets(self):
|
||||
for dataset in self.dataset_gen(equal=False):
|
||||
dataloader = DataLoader(dataset=dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
dataloader = iter(dataloader)
|
||||
batch = next(dataloader)
|
||||
a, b = map(lambda x: x.numpy(), map(batch.get, ('a', 'b')))
|
||||
|
||||
assert not np.all(a == b), 'images should not be the same'
|
@ -0,0 +1,90 @@
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
import yaml
|
||||
import os
|
||||
from torchvision import models, transforms
|
||||
from torch.autograd import Variable
|
||||
import shutil
|
||||
import glob
|
||||
import tqdm
|
||||
from util.metrics import PSNR
|
||||
from albumentations import Compose, CenterCrop, PadIfNeeded
|
||||
from PIL import Image
|
||||
from ssim.ssimlib import SSIM
|
||||
from models.networks import get_generator
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser('Test an image')
|
||||
parser.add_argument('--img_folder', required=True, help='GoPRO Folder')
|
||||
parser.add_argument('--weights_path', required=True, help='Weights path')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def prepare_dirs(path):
|
||||
if os.path.exists(path):
|
||||
shutil.rmtree(path)
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def get_gt_image(path):
|
||||
dir, filename = os.path.split(path)
|
||||
base, seq = os.path.split(dir)
|
||||
base, _ = os.path.split(base)
|
||||
img = cv2.cvtColor(cv2.imread(os.path.join(base, 'sharp', seq, filename)), cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
def test_image(model, image_path):
|
||||
img_transforms = transforms.Compose([
|
||||
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
])
|
||||
size_transform = Compose([
|
||||
PadIfNeeded(736, 1280)
|
||||
])
|
||||
crop = CenterCrop(720, 1280)
|
||||
img = cv2.imread(image_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img_s = size_transform(image=img)['image']
|
||||
img_tensor = torch.from_numpy(np.transpose(img_s / 255, (2, 0, 1)).astype('float32'))
|
||||
img_tensor = img_transforms(img_tensor)
|
||||
with torch.no_grad():
|
||||
img_tensor = Variable(img_tensor.unsqueeze(0).cuda())
|
||||
result_image = model(img_tensor)
|
||||
result_image = result_image[0].cpu().float().numpy()
|
||||
result_image = (np.transpose(result_image, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
result_image = crop(image=result_image)['image']
|
||||
result_image = result_image.astype('uint8')
|
||||
gt_image = get_gt_image(image_path)
|
||||
_, filename = os.path.split(image_path)
|
||||
psnr = PSNR(result_image, gt_image)
|
||||
pilFake = Image.fromarray(result_image)
|
||||
pilReal = Image.fromarray(gt_image)
|
||||
ssim = SSIM(pilFake).cw_ssim_value(pilReal)
|
||||
return psnr, ssim
|
||||
|
||||
|
||||
def test(model, files):
|
||||
psnr = 0
|
||||
ssim = 0
|
||||
for file in tqdm.tqdm(files):
|
||||
cur_psnr, cur_ssim = test_image(model, file)
|
||||
psnr += cur_psnr
|
||||
ssim += cur_ssim
|
||||
print("PSNR = {}".format(psnr / len(files)))
|
||||
print("SSIM = {}".format(ssim / len(files)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
with open('config/config.yaml') as cfg:
|
||||
config = yaml.load(cfg)
|
||||
model = get_generator(config['model'])
|
||||
model.load_state_dict(torch.load(args.weights_path)['model'])
|
||||
model = model.cuda()
|
||||
filenames = sorted(glob.glob(args.img_folder + '/test' + '/blur/**/*.png', recursive=True))
|
||||
test(model, filenames)
|
@ -0,0 +1,9 @@
|
||||
import cv2
|
||||
from predictorClass import Predictor
|
||||
|
||||
predictor = Predictor(weights_path='best_fpn.h5')
|
||||
img = cv2.imread('img/img.jpg')
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
pred = predictor(img, None)
|
||||
pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
|
||||
cv2.imwrite('submit/img.jpg',pred)
|
@ -0,0 +1,181 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import tqdm
|
||||
import yaml
|
||||
from joblib import cpu_count
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from adversarial_trainer import GANFactory
|
||||
from dataset import PairedDataset
|
||||
from metric_counter import MetricCounter
|
||||
from models.losses import get_loss
|
||||
from models.models import get_model
|
||||
from models.networks import get_nets
|
||||
from schedulers import LinearDecay, WarmRestart
|
||||
|
||||
cv2.setNumThreads(0)
|
||||
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, config, train: DataLoader, val: DataLoader):
|
||||
self.config = config
|
||||
self.train_dataset = train
|
||||
self.val_dataset = val
|
||||
self.adv_lambda = config['model']['adv_lambda']
|
||||
self.metric_counter = MetricCounter(config['experiment_desc'])
|
||||
self.warmup_epochs = config['warmup_num']
|
||||
|
||||
def train(self):
|
||||
self._init_params()
|
||||
for epoch in range(0, config['num_epochs']):
|
||||
if (epoch == self.warmup_epochs) and not (self.warmup_epochs == 0):
|
||||
self.netG.module.unfreeze()
|
||||
self.optimizer_G = self._get_optim(self.netG.parameters())
|
||||
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
||||
self._run_epoch(epoch)
|
||||
self._validate(epoch)
|
||||
self.scheduler_G.step()
|
||||
self.scheduler_D.step()
|
||||
|
||||
if self.metric_counter.update_best_model():
|
||||
torch.save({
|
||||
'model': self.netG.state_dict()
|
||||
}, 'best_{}.h5'.format(self.config['experiment_desc']))
|
||||
torch.save({
|
||||
'model': self.netG.state_dict()
|
||||
}, 'last_{}.h5'.format(self.config['experiment_desc']))
|
||||
print(self.metric_counter.loss_message())
|
||||
logging.debug("Experiment Name: %s, Epoch: %d, Loss: %s" % (
|
||||
self.config['experiment_desc'], epoch, self.metric_counter.loss_message()))
|
||||
|
||||
def _run_epoch(self, epoch):
|
||||
self.metric_counter.clear()
|
||||
for param_group in self.optimizer_G.param_groups:
|
||||
lr = param_group['lr']
|
||||
|
||||
epoch_size = config.get('train_batches_per_epoch') or len(self.train_dataset)
|
||||
tq = tqdm.tqdm(self.train_dataset, total=epoch_size)
|
||||
tq.set_description('Epoch {}, lr {}'.format(epoch, lr))
|
||||
i = 0
|
||||
for data in tq:
|
||||
inputs, targets = self.model.get_input(data)
|
||||
outputs = self.netG(inputs)
|
||||
loss_D = self._update_d(outputs, targets)
|
||||
self.optimizer_G.zero_grad()
|
||||
loss_content = self.criterionG(outputs, targets)
|
||||
loss_adv = self.adv_trainer.loss_g(outputs, targets)
|
||||
loss_G = loss_content + self.adv_lambda * loss_adv
|
||||
loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
self.metric_counter.add_losses(loss_G.item(), loss_content.item(), loss_D)
|
||||
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||
tq.set_postfix(loss=self.metric_counter.loss_message())
|
||||
if not i:
|
||||
self.metric_counter.add_image(img_for_vis, tag='train')
|
||||
i += 1
|
||||
if i > epoch_size:
|
||||
break
|
||||
tq.close()
|
||||
self.metric_counter.write_to_tensorboard(epoch)
|
||||
|
||||
def _validate(self, epoch):
|
||||
self.metric_counter.clear()
|
||||
epoch_size = config.get('val_batches_per_epoch') or len(self.val_dataset)
|
||||
tq = tqdm.tqdm(self.val_dataset, total=epoch_size)
|
||||
tq.set_description('Validation')
|
||||
i = 0
|
||||
for data in tq:
|
||||
inputs, targets = self.model.get_input(data)
|
||||
outputs = self.netG(inputs)
|
||||
loss_content = self.criterionG(outputs, targets)
|
||||
loss_adv = self.adv_trainer.loss_g(outputs, targets)
|
||||
loss_G = loss_content + self.adv_lambda * loss_adv
|
||||
self.metric_counter.add_losses(loss_G.item(), loss_content.item())
|
||||
curr_psnr, curr_ssim, img_for_vis = self.model.get_images_and_metrics(inputs, outputs, targets)
|
||||
self.metric_counter.add_metrics(curr_psnr, curr_ssim)
|
||||
if not i:
|
||||
self.metric_counter.add_image(img_for_vis, tag='val')
|
||||
i += 1
|
||||
if i > epoch_size:
|
||||
break
|
||||
tq.close()
|
||||
self.metric_counter.write_to_tensorboard(epoch, validation=True)
|
||||
|
||||
def _update_d(self, outputs, targets):
|
||||
if self.config['model']['d_name'] == 'no_gan':
|
||||
return 0
|
||||
self.optimizer_D.zero_grad()
|
||||
loss_D = self.adv_lambda * self.adv_trainer.loss_d(outputs, targets)
|
||||
loss_D.backward(retain_graph=True)
|
||||
self.optimizer_D.step()
|
||||
return loss_D.item()
|
||||
|
||||
def _get_optim(self, params):
|
||||
if self.config['optimizer']['name'] == 'adam':
|
||||
optimizer = optim.Adam(params, lr=self.config['optimizer']['lr'])
|
||||
elif self.config['optimizer']['name'] == 'sgd':
|
||||
optimizer = optim.SGD(params, lr=self.config['optimizer']['lr'])
|
||||
elif self.config['optimizer']['name'] == 'adadelta':
|
||||
optimizer = optim.Adadelta(params, lr=self.config['optimizer']['lr'])
|
||||
else:
|
||||
raise ValueError("Optimizer [%s] not recognized." % self.config['optimizer']['name'])
|
||||
return optimizer
|
||||
|
||||
def _get_scheduler(self, optimizer):
|
||||
if self.config['scheduler']['name'] == 'plateau':
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
|
||||
mode='min',
|
||||
patience=self.config['scheduler']['patience'],
|
||||
factor=self.config['scheduler']['factor'],
|
||||
min_lr=self.config['scheduler']['min_lr'])
|
||||
elif self.config['optimizer']['name'] == 'sgdr':
|
||||
scheduler = WarmRestart(optimizer)
|
||||
elif self.config['scheduler']['name'] == 'linear':
|
||||
scheduler = LinearDecay(optimizer,
|
||||
min_lr=self.config['scheduler']['min_lr'],
|
||||
num_epochs=self.config['num_epochs'],
|
||||
start_epoch=self.config['scheduler']['start_epoch'])
|
||||
else:
|
||||
raise ValueError("Scheduler [%s] not recognized." % self.config['scheduler']['name'])
|
||||
return scheduler
|
||||
|
||||
@staticmethod
|
||||
def _get_adversarial_trainer(d_name, net_d, criterion_d):
|
||||
if d_name == 'no_gan':
|
||||
return GANFactory.create_model('NoGAN')
|
||||
elif d_name == 'patch_gan' or d_name == 'multi_scale':
|
||||
return GANFactory.create_model('SingleGAN', net_d, criterion_d)
|
||||
elif d_name == 'double_gan':
|
||||
return GANFactory.create_model('DoubleGAN', net_d, criterion_d)
|
||||
else:
|
||||
raise ValueError("Discriminator Network [%s] not recognized." % d_name)
|
||||
|
||||
def _init_params(self):
|
||||
self.criterionG, criterionD = get_loss(self.config['model'])
|
||||
self.netG, netD = get_nets(self.config['model'])
|
||||
self.netG.cuda()
|
||||
self.adv_trainer = self._get_adversarial_trainer(self.config['model']['d_name'], netD, criterionD)
|
||||
self.model = get_model(self.config['model'])
|
||||
self.optimizer_G = self._get_optim(filter(lambda p: p.requires_grad, self.netG.parameters()))
|
||||
self.optimizer_D = self._get_optim(self.adv_trainer.get_params())
|
||||
self.scheduler_G = self._get_scheduler(self.optimizer_G)
|
||||
self.scheduler_D = self._get_scheduler(self.optimizer_D)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open('config/config.yaml', 'r') as f:
|
||||
config = yaml.load(f)
|
||||
|
||||
batch_size = config.pop('batch_size')
|
||||
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=cpu_count(), shuffle=True, drop_last=True)
|
||||
|
||||
datasets = map(config.pop, ('train', 'val'))
|
||||
datasets = map(PairedDataset.from_config, datasets)
|
||||
train, val = map(get_dataloader, datasets)
|
||||
trainer = Trainer(config, train=train, val=val)
|
||||
trainer.train()
|
@ -0,0 +1,33 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
from collections import deque
|
||||
|
||||
|
||||
class ImagePool():
|
||||
def __init__(self, pool_size):
|
||||
self.pool_size = pool_size
|
||||
self.sample_size = pool_size
|
||||
if self.pool_size > 0:
|
||||
self.num_imgs = 0
|
||||
self.images = deque()
|
||||
|
||||
def add(self, images):
|
||||
if self.pool_size == 0:
|
||||
return images
|
||||
for image in images.data:
|
||||
image = torch.unsqueeze(image, 0)
|
||||
if self.num_imgs < self.pool_size:
|
||||
self.num_imgs = self.num_imgs + 1
|
||||
self.images.append(image)
|
||||
else:
|
||||
self.images.popleft()
|
||||
self.images.append(image)
|
||||
|
||||
def query(self):
|
||||
if len(self.images) > self.sample_size:
|
||||
return_images = list(random.sample(self.images, self.sample_size))
|
||||
else:
|
||||
return_images = list(self.images)
|
||||
return torch.cat(return_images, 0)
|
@ -0,0 +1,54 @@
|
||||
import math
|
||||
from math import exp
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
def create_window(window_size, channel):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||
return window
|
||||
|
||||
|
||||
def SSIM(img1, img2):
|
||||
(_, channel, _, _) = img1.size()
|
||||
window_size = 11
|
||||
window = create_window(window_size, channel)
|
||||
|
||||
if img1.is_cuda:
|
||||
window = window.cuda(img1.get_device())
|
||||
window = window.type_as(img1)
|
||||
|
||||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01 ** 2
|
||||
C2 = 0.03 ** 2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
return ssim_map.mean()
|
||||
|
||||
|
||||
def PSNR(img1, img2):
|
||||
mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
|
||||
if mse == 0:
|
||||
return 100
|
||||
PIXEL_MAX = 1
|
||||
return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
|
@ -0,0 +1,50 @@
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import random
|
||||
|
||||
class BaseDataset(data.Dataset):
|
||||
def __init__(self):
|
||||
super(BaseDataset, self).__init__()
|
||||
|
||||
def name(self):
|
||||
return 'BaseDataset'
|
||||
|
||||
def initialize(self, opt):
|
||||
pass
|
||||
|
||||
def get_transform(opt):
|
||||
transform_list = []
|
||||
if opt.resize_or_crop == 'resize_and_crop':
|
||||
zoom = 1 + 0.1*radom.randint(0,4)
|
||||
osize = [int(400*zoom), int(600*zoom)]
|
||||
transform_list.append(transforms.Scale(osize, Image.BICUBIC))
|
||||
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
||||
elif opt.resize_or_crop == 'crop':
|
||||
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
||||
elif opt.resize_or_crop == 'scale_width':
|
||||
transform_list.append(transforms.Lambda(
|
||||
lambda img: __scale_width(img, opt.fineSize)))
|
||||
elif opt.resize_or_crop == 'scale_width_and_crop':
|
||||
transform_list.append(transforms.Lambda(
|
||||
lambda img: __scale_width(img, opt.loadSize)))
|
||||
transform_list.append(transforms.RandomCrop(opt.fineSize))
|
||||
# elif opt.resize_or_crop == 'no':
|
||||
# osize = [384, 512]
|
||||
# transform_list.append(transforms.Scale(osize, Image.BICUBIC))
|
||||
|
||||
if opt.isTrain and not opt.no_flip:
|
||||
transform_list.append(transforms.RandomHorizontalFlip())
|
||||
|
||||
transform_list += [transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5),
|
||||
(0.5, 0.5, 0.5))]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
def __scale_width(img, target_width):
|
||||
ow, oh = img.size
|
||||
if (ow == target_width):
|
||||
return img
|
||||
w = target_width
|
||||
h = int(target_width * oh / ow)
|
||||
return img.resize((w, h), Image.BICUBIC)
|
@ -0,0 +1,61 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
|
||||
class BaseModel():
|
||||
def name(self):
|
||||
return 'BaseModel'
|
||||
|
||||
def initialize(self, opt):
|
||||
self.opt = opt
|
||||
self.gpu_ids = opt.gpu_ids
|
||||
self.isTrain = opt.isTrain
|
||||
self.Tensor = torch.cuda.FloatTensor if (torch.cuda.is_available() and not opt.cFlag)else torch.Tensor
|
||||
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
||||
|
||||
def set_input(self, input):
|
||||
self.input = input
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
# used in test time, no backprop
|
||||
def test(self):
|
||||
pass
|
||||
|
||||
def get_image_paths(self):
|
||||
pass
|
||||
|
||||
def optimize_parameters(self):
|
||||
pass
|
||||
|
||||
def get_current_visuals(self):
|
||||
return self.input
|
||||
|
||||
def get_current_errors(self):
|
||||
return {}
|
||||
|
||||
def save(self, label):
|
||||
pass
|
||||
|
||||
# helper saving function that can be used by subclasses
|
||||
def save_network(self, network, network_label, epoch_label, gpu_ids):
|
||||
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
||||
save_path = os.path.join(self.save_dir, save_filename)
|
||||
torch.save(network.cpu().state_dict(), save_path)
|
||||
if len(gpu_ids) and torch.cuda.is_available():
|
||||
network.cuda(device=gpu_ids[0])
|
||||
|
||||
# helper loading function that can be used by subclasses
|
||||
def load_network(self, network, network_label, epoch_label):
|
||||
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
|
||||
save_path = os.path.join(self.save_dir, save_filename)
|
||||
if torch.cuda.is_available():
|
||||
ckpt = torch.load(save_path)
|
||||
else:
|
||||
ckpt = torch.load(save_path, map_location=torch.device("cpu"))
|
||||
ckpt = {key.replace("module.", ""): value for key, value in ckpt.items()}
|
||||
network.load_state_dict(ckpt)
|
||||
|
||||
def update_learning_rate():
|
||||
pass
|
@ -0,0 +1,38 @@
|
||||
|
||||
def create_model(opt):
|
||||
model = None
|
||||
print(opt.model)
|
||||
if opt.model == 'cycle_gan':
|
||||
assert(opt.dataset_mode == 'unaligned')
|
||||
from .cycle_gan_model import CycleGANModel
|
||||
model = CycleGANModel()
|
||||
elif opt.model == 'pix2pix':
|
||||
assert(opt.dataset_mode == 'pix2pix')
|
||||
from .pix2pix_model import Pix2PixModel
|
||||
model = Pix2PixModel()
|
||||
elif opt.model == 'pair':
|
||||
# assert(opt.dataset_mode == 'pair')
|
||||
# from .pair_model import PairModel
|
||||
from .Unet_L1 import PairModel
|
||||
model = PairModel()
|
||||
elif opt.model == 'single':
|
||||
# assert(opt.dataset_mode == 'unaligned')
|
||||
from .single_model import SingleModel
|
||||
model = SingleModel()
|
||||
elif opt.model == 'temp':
|
||||
# assert(opt.dataset_mode == 'unaligned')
|
||||
from .temp_model import TempModel
|
||||
model = TempModel()
|
||||
elif opt.model == 'UNIT':
|
||||
assert(opt.dataset_mode == 'unaligned')
|
||||
from .unit_model import UNITModel
|
||||
model = UNITModel()
|
||||
elif opt.model == 'test':
|
||||
assert(opt.dataset_mode == 'single')
|
||||
from .test_model import TestModel
|
||||
model = TestModel()
|
||||
else:
|
||||
raise ValueError("Model [%s] not recognized." % opt.model)
|
||||
model.initialize(opt)
|
||||
print("model [%s] was created" % (model.name()))
|
||||
return model
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,496 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from torch.autograd import Variable
|
||||
import enlighten_util.util as util
|
||||
from collections import OrderedDict
|
||||
from torch.autograd import Variable
|
||||
import itertools
|
||||
import enlighten_util.util as util
|
||||
from enlighten_util.image_pool import ImagePool
|
||||
from .base_model import BaseModel
|
||||
import random
|
||||
from . import networks
|
||||
import sys
|
||||
|
||||
|
||||
class SingleModel(BaseModel):
|
||||
def name(self):
|
||||
return 'SingleGANModel'
|
||||
|
||||
def initialize(self, opt):
|
||||
BaseModel.initialize(self, opt)
|
||||
|
||||
nb = opt.batchSize
|
||||
size = opt.fineSize
|
||||
self.opt = opt
|
||||
self.input_A = self.Tensor(nb, opt.input_nc, size, size)
|
||||
self.input_B = self.Tensor(nb, opt.output_nc, size, size)
|
||||
self.input_img = self.Tensor(nb, opt.input_nc, size, size)
|
||||
self.input_A_gray = self.Tensor(nb, 1, size, size)
|
||||
|
||||
if opt.vgg > 0:
|
||||
self.vgg_loss = networks.PerceptualLoss(opt)
|
||||
if self.opt.IN_vgg:
|
||||
self.vgg_patch_loss = networks.PerceptualLoss(opt)
|
||||
self.vgg_patch_loss.cuda()
|
||||
self.vgg_loss.cuda()
|
||||
self.vgg = networks.load_vgg16("./model", self.gpu_ids)
|
||||
self.vgg.eval()
|
||||
for param in self.vgg.parameters():
|
||||
param.requires_grad = False
|
||||
elif opt.fcn > 0:
|
||||
self.fcn_loss = networks.SemanticLoss(opt)
|
||||
self.fcn_loss.cuda()
|
||||
self.fcn = networks.load_fcn("./model")
|
||||
self.fcn.eval()
|
||||
for param in self.fcn.parameters():
|
||||
param.requires_grad = False
|
||||
# load/define networks
|
||||
# The naming conversion is different from those used in the paper
|
||||
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
|
||||
|
||||
skip = True if opt.skip > 0 else False
|
||||
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
|
||||
opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=skip, opt=opt)
|
||||
# self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
|
||||
# opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, self.gpu_ids, skip=False, opt=opt)
|
||||
|
||||
if self.isTrain:
|
||||
use_sigmoid = opt.no_lsgan
|
||||
self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
|
||||
opt.which_model_netD,
|
||||
opt.n_layers_D, opt.norm, use_sigmoid, self.gpu_ids, False)
|
||||
if self.opt.patchD:
|
||||
self.netD_P = networks.define_D(opt.input_nc, opt.ndf,
|
||||
opt.which_model_netD,
|
||||
opt.n_layers_patchD, opt.norm, use_sigmoid, self.gpu_ids, True)
|
||||
if not self.isTrain or opt.continue_train:
|
||||
which_epoch = opt.which_epoch
|
||||
self.load_network(self.netG_A, 'G_A', which_epoch)
|
||||
# self.load_network(self.netG_B, 'G_B', which_epoch)
|
||||
if self.isTrain:
|
||||
self.load_network(self.netD_A, 'D_A', which_epoch)
|
||||
if self.opt.patchD:
|
||||
self.load_network(self.netD_P, 'D_P', which_epoch)
|
||||
|
||||
if self.isTrain:
|
||||
self.old_lr = opt.lr
|
||||
# self.fake_A_pool = ImagePool(opt.pool_size)
|
||||
self.fake_B_pool = ImagePool(opt.pool_size)
|
||||
# define loss functions
|
||||
if opt.use_wgan:
|
||||
self.criterionGAN = networks.DiscLossWGANGP()
|
||||
else:
|
||||
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
|
||||
if opt.use_mse:
|
||||
self.criterionCycle = torch.nn.MSELoss()
|
||||
else:
|
||||
self.criterionCycle = torch.nn.L1Loss()
|
||||
self.criterionL1 = torch.nn.L1Loss()
|
||||
self.criterionIdt = torch.nn.L1Loss()
|
||||
# initialize optimizers
|
||||
self.optimizer_G = torch.optim.Adam(self.netG_A.parameters(),
|
||||
lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
if self.opt.patchD:
|
||||
self.optimizer_D_P = torch.optim.Adam(self.netD_P.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
||||
|
||||
print('---------- Networks initialized -------------')
|
||||
networks.print_network(self.netG_A)
|
||||
# networks.print_network(self.netG_B)
|
||||
if self.isTrain:
|
||||
networks.print_network(self.netD_A)
|
||||
if self.opt.patchD:
|
||||
networks.print_network(self.netD_P)
|
||||
# networks.print_network(self.netD_B)
|
||||
if opt.isTrain:
|
||||
self.netG_A.train()
|
||||
# self.netG_B.train()
|
||||
else:
|
||||
self.netG_A.eval()
|
||||
# self.netG_B.eval()
|
||||
print('-----------------------------------------------')
|
||||
|
||||
def set_input(self, input):
|
||||
AtoB = self.opt.which_direction == 'AtoB'
|
||||
input_A = input['A' if AtoB else 'B']
|
||||
input_B = input['B' if AtoB else 'A']
|
||||
input_img = input['input_img']
|
||||
input_A_gray = input['A_gray']
|
||||
self.input_A.resize_(input_A.size()).copy_(input_A)
|
||||
self.input_A_gray.resize_(input_A_gray.size()).copy_(input_A_gray)
|
||||
self.input_B.resize_(input_B.size()).copy_(input_B)
|
||||
self.input_img.resize_(input_img.size()).copy_(input_img)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
|
||||
|
||||
|
||||
def test(self):
|
||||
self.real_A = Variable(self.input_A, volatile=True)
|
||||
self.real_A_gray = Variable(self.input_A_gray, volatile=True)
|
||||
if self.opt.noise > 0:
|
||||
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
|
||||
self.real_A = self.real_A + self.noise
|
||||
if self.opt.input_linear:
|
||||
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
|
||||
# print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
|
||||
if self.opt.skip == 1:
|
||||
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)
|
||||
else:
|
||||
self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray)
|
||||
# self.rec_A = self.netG_B.forward(self.fake_B)
|
||||
|
||||
self.real_B = Variable(self.input_B, volatile=True)
|
||||
|
||||
|
||||
def predict(self):
|
||||
self.real_A = Variable(self.input_A, volatile=True)
|
||||
self.real_A_gray = Variable(self.input_A_gray, volatile=True)
|
||||
if self.opt.noise > 0:
|
||||
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
|
||||
self.real_A = self.real_A + self.noise
|
||||
if self.opt.input_linear:
|
||||
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
|
||||
# print(np.transpose(self.real_A.data[0].cpu().float().numpy(),(1,2,0))[:2][:2][:])
|
||||
if self.opt.skip == 1:
|
||||
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_A, self.real_A_gray)
|
||||
else:
|
||||
self.fake_B = self.netG_A.forward(self.real_A, self.real_A_gray)
|
||||
# self.rec_A = self.netG_B.forward(self.fake_B)
|
||||
|
||||
real_A = util.tensor2im(self.real_A.data)
|
||||
fake_B = util.tensor2im(self.fake_B.data)
|
||||
A_gray = util.atten2im(self.real_A_gray.data)
|
||||
# rec_A = util.tensor2im(self.rec_A.data)
|
||||
# if self.opt.skip == 1:
|
||||
# latent_real_A = util.tensor2im(self.latent_real_A.data)
|
||||
# latent_show = util.latent2im(self.latent_real_A.data)
|
||||
# max_image = util.max2im(self.fake_B.data, self.latent_real_A.data)
|
||||
# return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
# ('latent_show', latent_show), ('max_image', max_image), ('A_gray', A_gray)])
|
||||
# else:
|
||||
# return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
|
||||
# return OrderedDict([('fake_B', fake_B)])
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
|
||||
|
||||
# get image paths
|
||||
def get_image_paths(self):
|
||||
return self.image_paths
|
||||
|
||||
def backward_D_basic(self, netD, real, fake, use_ragan):
|
||||
# Real
|
||||
pred_real = netD.forward(real)
|
||||
pred_fake = netD.forward(fake.detach())
|
||||
if self.opt.use_wgan:
|
||||
loss_D_real = pred_real.mean()
|
||||
loss_D_fake = pred_fake.mean()
|
||||
loss_D = loss_D_fake - loss_D_real + self.criterionGAN.calc_gradient_penalty(netD,
|
||||
real.data, fake.data)
|
||||
elif self.opt.use_ragan and use_ragan:
|
||||
loss_D = (self.criterionGAN(pred_real - torch.mean(pred_fake), True) +
|
||||
self.criterionGAN(pred_fake - torch.mean(pred_real), False)) / 2
|
||||
else:
|
||||
loss_D_real = self.criterionGAN(pred_real, True)
|
||||
loss_D_fake = self.criterionGAN(pred_fake, False)
|
||||
loss_D = (loss_D_real + loss_D_fake) * 0.5
|
||||
# loss_D.backward()
|
||||
return loss_D
|
||||
|
||||
def backward_D_A(self):
|
||||
fake_B = self.fake_B_pool.query(self.fake_B)
|
||||
fake_B = self.fake_B
|
||||
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, True)
|
||||
self.loss_D_A.backward()
|
||||
|
||||
def backward_D_P(self):
|
||||
if self.opt.hybrid_loss:
|
||||
loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, False)
|
||||
if self.opt.patchD_3 > 0:
|
||||
for i in range(self.opt.patchD_3):
|
||||
loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], False)
|
||||
self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1)
|
||||
else:
|
||||
self.loss_D_P = loss_D_P
|
||||
else:
|
||||
loss_D_P = self.backward_D_basic(self.netD_P, self.real_patch, self.fake_patch, True)
|
||||
if self.opt.patchD_3 > 0:
|
||||
for i in range(self.opt.patchD_3):
|
||||
loss_D_P += self.backward_D_basic(self.netD_P, self.real_patch_1[i], self.fake_patch_1[i], True)
|
||||
self.loss_D_P = loss_D_P/float(self.opt.patchD_3 + 1)
|
||||
else:
|
||||
self.loss_D_P = loss_D_P
|
||||
if self.opt.D_P_times2:
|
||||
self.loss_D_P = self.loss_D_P*2
|
||||
self.loss_D_P.backward()
|
||||
|
||||
# def backward_D_B(self):
|
||||
# fake_A = self.fake_A_pool.query(self.fake_A)
|
||||
# self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
|
||||
def forward(self):
|
||||
self.real_A = Variable(self.input_A)
|
||||
self.real_B = Variable(self.input_B)
|
||||
self.real_A_gray = Variable(self.input_A_gray)
|
||||
self.real_img = Variable(self.input_img)
|
||||
if self.opt.noise > 0:
|
||||
self.noise = Variable(torch.cuda.FloatTensor(self.real_A.size()).normal_(mean=0, std=self.opt.noise/255.))
|
||||
self.real_A = self.real_A + self.noise
|
||||
if self.opt.input_linear:
|
||||
self.real_A = (self.real_A - torch.min(self.real_A))/(torch.max(self.real_A) - torch.min(self.real_A))
|
||||
if self.opt.skip == 1:
|
||||
self.fake_B, self.latent_real_A = self.netG_A.forward(self.real_img, self.real_A_gray)
|
||||
else:
|
||||
self.fake_B = self.netG_A.forward(self.real_img, self.real_A_gray)
|
||||
if self.opt.patchD:
|
||||
w = self.real_A.size(3)
|
||||
h = self.real_A.size(2)
|
||||
w_offset = random.randint(0, max(0, w - self.opt.patchSize - 1))
|
||||
h_offset = random.randint(0, max(0, h - self.opt.patchSize - 1))
|
||||
|
||||
self.fake_patch = self.fake_B[:,:, h_offset:h_offset + self.opt.patchSize,
|
||||
w_offset:w_offset + self.opt.patchSize]
|
||||
self.real_patch = self.real_B[:,:, h_offset:h_offset + self.opt.patchSize,
|
||||
w_offset:w_offset + self.opt.patchSize]
|
||||
self.input_patch = self.real_A[:,:, h_offset:h_offset + self.opt.patchSize,
|
||||
w_offset:w_offset + self.opt.patchSize]
|
||||
if self.opt.patchD_3 > 0:
|
||||
self.fake_patch_1 = []
|
||||
self.real_patch_1 = []
|
||||
self.input_patch_1 = []
|
||||
w = self.real_A.size(3)
|
||||
h = self.real_A.size(2)
|
||||
for i in range(self.opt.patchD_3):
|
||||
w_offset_1 = random.randint(0, max(0, w - self.opt.patchSize - 1))
|
||||
h_offset_1 = random.randint(0, max(0, h - self.opt.patchSize - 1))
|
||||
self.fake_patch_1.append(self.fake_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
|
||||
w_offset_1:w_offset_1 + self.opt.patchSize])
|
||||
self.real_patch_1.append(self.real_B[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
|
||||
w_offset_1:w_offset_1 + self.opt.patchSize])
|
||||
self.input_patch_1.append(self.real_A[:,:, h_offset_1:h_offset_1 + self.opt.patchSize,
|
||||
w_offset_1:w_offset_1 + self.opt.patchSize])
|
||||
|
||||
# w_offset_2 = random.randint(0, max(0, w - self.opt.patchSize - 1))
|
||||
# h_offset_2 = random.randint(0, max(0, h - self.opt.patchSize - 1))
|
||||
# self.fake_patch_2 = self.fake_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
|
||||
# w_offset_2:w_offset_2 + self.opt.patchSize]
|
||||
# self.real_patch_2 = self.real_B[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
|
||||
# w_offset_2:w_offset_2 + self.opt.patchSize]
|
||||
# self.input_patch_2 = self.real_A[:,:, h_offset_2:h_offset_2 + self.opt.patchSize,
|
||||
# w_offset_2:w_offset_2 + self.opt.patchSize]
|
||||
|
||||
def backward_G(self, epoch):
|
||||
pred_fake = self.netD_A.forward(self.fake_B)
|
||||
if self.opt.use_wgan:
|
||||
self.loss_G_A = -pred_fake.mean()
|
||||
elif self.opt.use_ragan:
|
||||
pred_real = self.netD_A.forward(self.real_B)
|
||||
|
||||
self.loss_G_A = (self.criterionGAN(pred_real - torch.mean(pred_fake), False) +
|
||||
self.criterionGAN(pred_fake - torch.mean(pred_real), True)) / 2
|
||||
|
||||
else:
|
||||
self.loss_G_A = self.criterionGAN(pred_fake, True)
|
||||
|
||||
loss_G_A = 0
|
||||
if self.opt.patchD:
|
||||
pred_fake_patch = self.netD_P.forward(self.fake_patch)
|
||||
if self.opt.hybrid_loss:
|
||||
loss_G_A += self.criterionGAN(pred_fake_patch, True)
|
||||
else:
|
||||
pred_real_patch = self.netD_P.forward(self.real_patch)
|
||||
|
||||
loss_G_A += (self.criterionGAN(pred_real_patch - torch.mean(pred_fake_patch), False) +
|
||||
self.criterionGAN(pred_fake_patch - torch.mean(pred_real_patch), True)) / 2
|
||||
if self.opt.patchD_3 > 0:
|
||||
for i in range(self.opt.patchD_3):
|
||||
pred_fake_patch_1 = self.netD_P.forward(self.fake_patch_1[i])
|
||||
if self.opt.hybrid_loss:
|
||||
loss_G_A += self.criterionGAN(pred_fake_patch_1, True)
|
||||
else:
|
||||
pred_real_patch_1 = self.netD_P.forward(self.real_patch_1[i])
|
||||
|
||||
loss_G_A += (self.criterionGAN(pred_real_patch_1 - torch.mean(pred_fake_patch_1), False) +
|
||||
self.criterionGAN(pred_fake_patch_1 - torch.mean(pred_real_patch_1), True)) / 2
|
||||
|
||||
if not self.opt.D_P_times2:
|
||||
self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)
|
||||
else:
|
||||
self.loss_G_A += loss_G_A/float(self.opt.patchD_3 + 1)*2
|
||||
else:
|
||||
if not self.opt.D_P_times2:
|
||||
self.loss_G_A += loss_G_A
|
||||
else:
|
||||
self.loss_G_A += loss_G_A*2
|
||||
|
||||
if epoch < 0:
|
||||
vgg_w = 0
|
||||
else:
|
||||
vgg_w = 1
|
||||
if self.opt.vgg > 0:
|
||||
self.loss_vgg_b = self.vgg_loss.compute_vgg_loss(self.vgg,
|
||||
self.fake_B, self.real_A) * self.opt.vgg if self.opt.vgg > 0 else 0
|
||||
if self.opt.patch_vgg:
|
||||
if not self.opt.IN_vgg:
|
||||
loss_vgg_patch = self.vgg_loss.compute_vgg_loss(self.vgg,
|
||||
self.fake_patch, self.input_patch) * self.opt.vgg
|
||||
else:
|
||||
loss_vgg_patch = self.vgg_patch_loss.compute_vgg_loss(self.vgg,
|
||||
self.fake_patch, self.input_patch) * self.opt.vgg
|
||||
if self.opt.patchD_3 > 0:
|
||||
for i in range(self.opt.patchD_3):
|
||||
if not self.opt.IN_vgg:
|
||||
loss_vgg_patch += self.vgg_loss.compute_vgg_loss(self.vgg,
|
||||
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg
|
||||
else:
|
||||
loss_vgg_patch += self.vgg_patch_loss.compute_vgg_loss(self.vgg,
|
||||
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.vgg
|
||||
self.loss_vgg_b += loss_vgg_patch/float(self.opt.patchD_3 + 1)
|
||||
else:
|
||||
self.loss_vgg_b += loss_vgg_patch
|
||||
self.loss_G = self.loss_G_A + self.loss_vgg_b*vgg_w
|
||||
elif self.opt.fcn > 0:
|
||||
self.loss_fcn_b = self.fcn_loss.compute_fcn_loss(self.fcn,
|
||||
self.fake_B, self.real_A) * self.opt.fcn if self.opt.fcn > 0 else 0
|
||||
if self.opt.patchD:
|
||||
loss_fcn_patch = self.fcn_loss.compute_vgg_loss(self.fcn,
|
||||
self.fake_patch, self.input_patch) * self.opt.fcn
|
||||
if self.opt.patchD_3 > 0:
|
||||
for i in range(self.opt.patchD_3):
|
||||
loss_fcn_patch += self.fcn_loss.compute_vgg_loss(self.fcn,
|
||||
self.fake_patch_1[i], self.input_patch_1[i]) * self.opt.fcn
|
||||
self.loss_fcn_b += loss_fcn_patch/float(self.opt.patchD_3 + 1)
|
||||
else:
|
||||
self.loss_fcn_b += loss_fcn_patch
|
||||
self.loss_G = self.loss_G_A + self.loss_fcn_b*vgg_w
|
||||
# self.loss_G = self.L1_AB + self.L1_BA
|
||||
self.loss_G.backward()
|
||||
|
||||
|
||||
# def optimize_parameters(self, epoch):
|
||||
# # forward
|
||||
# self.forward()
|
||||
# # G_A and G_B
|
||||
# self.optimizer_G.zero_grad()
|
||||
# self.backward_G(epoch)
|
||||
# self.optimizer_G.step()
|
||||
# # D_A
|
||||
# self.optimizer_D_A.zero_grad()
|
||||
# self.backward_D_A()
|
||||
# self.optimizer_D_A.step()
|
||||
# if self.opt.patchD:
|
||||
# self.forward()
|
||||
# self.optimizer_D_P.zero_grad()
|
||||
# self.backward_D_P()
|
||||
# self.optimizer_D_P.step()
|
||||
# D_B
|
||||
# self.optimizer_D_B.zero_grad()
|
||||
# self.backward_D_B()
|
||||
# self.optimizer_D_B.step()
|
||||
def optimize_parameters(self, epoch):
|
||||
# forward
|
||||
self.forward()
|
||||
# G_A and G_B
|
||||
self.optimizer_G.zero_grad()
|
||||
self.backward_G(epoch)
|
||||
self.optimizer_G.step()
|
||||
# D_A
|
||||
self.optimizer_D_A.zero_grad()
|
||||
self.backward_D_A()
|
||||
if not self.opt.patchD:
|
||||
self.optimizer_D_A.step()
|
||||
else:
|
||||
# self.forward()
|
||||
self.optimizer_D_P.zero_grad()
|
||||
self.backward_D_P()
|
||||
self.optimizer_D_A.step()
|
||||
self.optimizer_D_P.step()
|
||||
|
||||
|
||||
def get_current_errors(self, epoch):
|
||||
D_A = self.loss_D_A.data[0]
|
||||
D_P = self.loss_D_P.data[0] if self.opt.patchD else 0
|
||||
G_A = self.loss_G_A.data[0]
|
||||
if self.opt.vgg > 0:
|
||||
vgg = self.loss_vgg_b.data[0]/self.opt.vgg if self.opt.vgg > 0 else 0
|
||||
return OrderedDict([('D_A', D_A), ('G_A', G_A), ("vgg", vgg), ("D_P", D_P)])
|
||||
elif self.opt.fcn > 0:
|
||||
fcn = self.loss_fcn_b.data[0]/self.opt.fcn if self.opt.fcn > 0 else 0
|
||||
return OrderedDict([('D_A', D_A), ('G_A', G_A), ("fcn", fcn), ("D_P", D_P)])
|
||||
|
||||
|
||||
def get_current_visuals(self):
|
||||
real_A = util.tensor2im(self.real_A.data)
|
||||
fake_B = util.tensor2im(self.fake_B.data)
|
||||
real_B = util.tensor2im(self.real_B.data)
|
||||
if self.opt.skip > 0:
|
||||
latent_real_A = util.tensor2im(self.latent_real_A.data)
|
||||
latent_show = util.latent2im(self.latent_real_A.data)
|
||||
if self.opt.patchD:
|
||||
fake_patch = util.tensor2im(self.fake_patch.data)
|
||||
real_patch = util.tensor2im(self.real_patch.data)
|
||||
if self.opt.patch_vgg:
|
||||
input_patch = util.tensor2im(self.input_patch.data)
|
||||
if not self.opt.self_attention:
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
|
||||
('fake_patch', fake_patch), ('input_patch', input_patch)])
|
||||
else:
|
||||
self_attention = util.atten2im(self.real_A_gray.data)
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
|
||||
('fake_patch', fake_patch), ('input_patch', input_patch), ('self_attention', self_attention)])
|
||||
else:
|
||||
if not self.opt.self_attention:
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
|
||||
('fake_patch', fake_patch)])
|
||||
else:
|
||||
self_attention = util.atten2im(self.real_A_gray.data)
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
('latent_show', latent_show), ('real_B', real_B), ('real_patch', real_patch),
|
||||
('fake_patch', fake_patch), ('self_attention', self_attention)])
|
||||
else:
|
||||
if not self.opt.self_attention:
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('latent_real_A', latent_real_A),
|
||||
('latent_show', latent_show), ('real_B', real_B)])
|
||||
else:
|
||||
self_attention = util.atten2im(self.real_A_gray.data)
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
|
||||
('latent_real_A', latent_real_A), ('latent_show', latent_show),
|
||||
('self_attention', self_attention)])
|
||||
else:
|
||||
if not self.opt.self_attention:
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B)])
|
||||
else:
|
||||
self_attention = util.atten2im(self.real_A_gray.data)
|
||||
return OrderedDict([('real_A', real_A), ('fake_B', fake_B), ('real_B', real_B),
|
||||
('self_attention', self_attention)])
|
||||
|
||||
def save(self, label):
|
||||
self.save_network(self.netG_A, 'G_A', label, self.gpu_ids)
|
||||
self.save_network(self.netD_A, 'D_A', label, self.gpu_ids)
|
||||
if self.opt.patchD:
|
||||
self.save_network(self.netD_P, 'D_P', label, self.gpu_ids)
|
||||
# self.save_network(self.netG_B, 'G_B', label, self.gpu_ids)
|
||||
# self.save_network(self.netD_B, 'D_B', label, self.gpu_ids)
|
||||
|
||||
def update_learning_rate(self):
|
||||
|
||||
if self.opt.new_lr:
|
||||
lr = self.old_lr/2
|
||||
else:
|
||||
lrd = self.opt.lr / self.opt.niter_decay
|
||||
lr = self.old_lr - lrd
|
||||
for param_group in self.optimizer_D_A.param_groups:
|
||||
param_group['lr'] = lr
|
||||
if self.opt.patchD:
|
||||
for param_group in self.optimizer_D_P.param_groups:
|
||||
param_group['lr'] = lr
|
||||
for param_group in self.optimizer_G.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
print('update learning rate: %f -> %f' % (self.old_lr, lr))
|
||||
self.old_lr = lr
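# Illustrative decay schedule (the values are assumptions, not defaults): with
# opt.lr = 1e-4, opt.niter_decay = 100 and opt.new_lr disabled, every call above
# subtracts 1e-4 / 100 = 1e-6, so the rate falls linearly from 1e-4 to 0 over
# 100 calls; with opt.new_lr enabled, each call halves the rate instead.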
|
@ -0,0 +1,32 @@
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.autograd import Variable
|
||||
class ImagePool():
|
||||
def __init__(self, pool_size):
|
||||
self.pool_size = pool_size
|
||||
if self.pool_size > 0:
|
||||
self.num_imgs = 0
|
||||
self.images = []
|
||||
|
||||
def query(self, images):
|
||||
if self.pool_size == 0:
|
||||
return images
|
||||
return_images = []
|
||||
for image in images.data:
|
||||
image = torch.unsqueeze(image, 0)
|
||||
if self.num_imgs < self.pool_size:
|
||||
self.num_imgs = self.num_imgs + 1
|
||||
self.images.append(image)
|
||||
return_images.append(image)
|
||||
else:
|
||||
p = random.uniform(0, 1)
|
||||
if p > 0.5:
|
||||
random_id = random.randint(0, self.pool_size-1)
|
||||
tmp = self.images[random_id].clone()
|
||||
self.images[random_id] = image
|
||||
return_images.append(tmp)
|
||||
else:
|
||||
return_images.append(image)
|
||||
return_images = Variable(torch.cat(return_images, 0))
|
||||
return return_images
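# Minimal usage sketch for ImagePool (shapes are illustrative assumptions). The pool
# stores up to pool_size past fakes and, with probability 0.5, hands the caller an
# older fake instead of the newest one, which stabilises discriminator training.
import torch
from torch.autograd import Variable

pool = ImagePool(pool_size=4)
for step in range(6):
    fake_batch = Variable(torch.randn(2, 3, 8, 8))  # stand-in for generator output
    mixed = pool.query(fake_batch)                   # same batch size, mixed history
    assert mixed.size(0) == fake_batch.size(0)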
|
@ -0,0 +1,182 @@
|
||||
# from __future__ import print_function
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import inspect, re
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import collections
|
||||
from torch.optim import lr_scheduler
|
||||
import torch.nn.init as init
import math
from torch.autograd import Variable
|
||||
|
||||
|
||||
# Converts a Tensor into a Numpy array
|
||||
# |imtype|: the desired type of the converted numpy array
|
||||
def tensor2im(image_tensor, imtype=np.uint8):
|
||||
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
image_numpy = np.maximum(image_numpy, 0)
|
||||
image_numpy = np.minimum(image_numpy, 255)
|
||||
return image_numpy.astype(imtype)
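# Example mapping (illustrative): an input value of -1.0 becomes 0, 0.0 becomes 127
# and 1.0 becomes 255, so a (1, 3, H, W) tensor in [-1, 1] yields an (H, W, 3) uint8 array.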
|
||||
|
||||
def atten2im(image_tensor, imtype=np.uint8):
|
||||
image_tensor = image_tensor[0]
|
||||
image_tensor = torch.cat((image_tensor, image_tensor, image_tensor), 0)
|
||||
image_numpy = image_tensor.cpu().float().numpy()
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
|
||||
image_numpy = image_numpy/(image_numpy.max()/255.0)
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
def latent2im(image_tensor, imtype=np.uint8):
|
||||
# image_tensor = (image_tensor - torch.min(image_tensor))/(torch.max(image_tensor)-torch.min(image_tensor))
|
||||
image_numpy = image_tensor[0].cpu().float().numpy()
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0))) * 255.0
|
||||
image_numpy = np.maximum(image_numpy, 0)
|
||||
image_numpy = np.minimum(image_numpy, 255)
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
def max2im(image_1, image_2, imtype=np.uint8):
|
||||
image_1 = image_1[0].cpu().float().numpy()
|
||||
image_2 = image_2[0].cpu().float().numpy()
|
||||
image_1 = (np.transpose(image_1, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
image_2 = (np.transpose(image_2, (1, 2, 0))) * 255.0
|
||||
output = np.maximum(image_1, image_2)
|
||||
output = np.maximum(output, 0)
|
||||
output = np.minimum(output, 255)
|
||||
return output.astype(imtype)
|
||||
|
||||
def variable2im(image_tensor, imtype=np.uint8):
|
||||
image_numpy = image_tensor[0].data.cpu().float().numpy()
|
||||
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
|
||||
return image_numpy.astype(imtype)
|
||||
|
||||
|
||||
def diagnose_network(net, name='network'):
|
||||
mean = 0.0
|
||||
count = 0
|
||||
for param in net.parameters():
|
||||
if param.grad is not None:
|
||||
mean += torch.mean(torch.abs(param.grad.data))
|
||||
count += 1
|
||||
if count > 0:
|
||||
mean = mean / count
|
||||
print(name)
|
||||
print(mean)
|
||||
|
||||
|
||||
def save_image(image_numpy, image_path):
|
||||
image_pil = Image.fromarray(image_numpy)
|
||||
image_pil.save(image_path)
|
||||
|
||||
def info(object, spacing=10, collapse=1):
|
||||
"""Print methods and doc strings.
|
||||
Takes module, class, list, dictionary, or string."""
|
||||
methodList = [e for e in dir(object) if callable(getattr(object, e))]
|
||||
processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
|
||||
print( "\n".join(["%s %s" %
|
||||
(method.ljust(spacing),
|
||||
processFunc(str(getattr(object, method).__doc__)))
|
||||
for method in methodList]) )
|
||||
|
||||
def varname(p):
|
||||
for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
|
||||
m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
|
||||
if m:
|
||||
return m.group(1)
|
||||
|
||||
def print_numpy(x, val=True, shp=False):
|
||||
x = x.astype(np.float64)
|
||||
if shp:
|
||||
print('shape,', x.shape)
|
||||
if val:
|
||||
x = x.flatten()
|
||||
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
||||
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
||||
|
||||
|
||||
def mkdirs(paths):
|
||||
if isinstance(paths, list) and not isinstance(paths, str):
|
||||
for path in paths:
|
||||
mkdir(path)
|
||||
else:
|
||||
mkdir(paths)
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
|
||||
def get_model_list(dirname, key):
|
||||
if os.path.exists(dirname) is False:
|
||||
return None
|
||||
gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if
|
||||
os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
|
||||
if gen_models is None:
|
||||
return None
|
||||
gen_models.sort()
|
||||
last_model_name = gen_models[-1]
|
||||
return last_model_name
|
||||
|
||||
|
||||
def load_vgg16(model_dir):
|
||||
""" Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """
|
||||
if not os.path.exists(model_dir):
|
||||
os.mkdir(model_dir)
|
||||
if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')):
|
||||
if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')):
|
||||
os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7'))
|
||||
vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7'))
|
||||
vgg = Vgg16()
|
||||
for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()):
|
||||
dst.data[:] = src
|
||||
torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight'))
|
||||
vgg = Vgg16()
|
||||
vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight')))
|
||||
return vgg
|
||||
|
||||
|
||||
def vgg_preprocess(batch):
|
||||
tensortype = type(batch.data)
|
||||
(r, g, b) = torch.chunk(batch, 3, dim = 1)
|
||||
batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR
|
||||
batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255]
|
||||
mean = tensortype(batch.data.size())
|
||||
mean[:, 0, :, :] = 103.939
|
||||
mean[:, 1, :, :] = 116.779
|
||||
mean[:, 2, :, :] = 123.680
|
||||
batch = batch.sub(Variable(mean)) # subtract mean
|
||||
return batch
|
||||
|
||||
|
||||
def get_scheduler(optimizer, hyperparameters, iterations=-1):
|
||||
if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant':
|
||||
scheduler = None # constant scheduler
|
||||
elif hyperparameters['lr_policy'] == 'step':
|
||||
scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'],
|
||||
gamma=hyperparameters['gamma'], last_epoch=iterations)
|
||||
else:
|
||||
raise NotImplementedError('learning rate policy [%s] is not implemented' % hyperparameters['lr_policy'])
|
||||
return scheduler
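# Minimal usage sketch (the dict keys are assumptions about the expected config):
# a 'step' policy halves the learning rate every 30 epochs.
import torch
net = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = get_scheduler(optimizer, {'lr_policy': 'step', 'step_size': 30, 'gamma': 0.5})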
|
||||
|
||||
|
||||
def weights_init(init_type='gaussian'):
|
||||
def init_fun(m):
|
||||
classname = m.__class__.__name__
|
||||
if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
|
||||
# print m.__class__.__name__
|
||||
if init_type == 'gaussian':
|
||||
init.normal(m.weight.data, 0.0, 0.02)
|
||||
elif init_type == 'xavier':
|
||||
init.xavier_normal(m.weight.data, gain=math.sqrt(2))
|
||||
elif init_type == 'kaiming':
|
||||
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
|
||||
elif init_type == 'orthogonal':
|
||||
init.orthogonal(m.weight.data, gain=math.sqrt(2))
|
||||
elif init_type == 'default':
|
||||
pass
|
||||
else:
|
||||
assert 0, "Unsupported initialization: {}".format(init_type)
|
||||
if hasattr(m, 'bias') and m.bias is not None:
|
||||
init.constant(m.bias.data, 0.0)
|
||||
|
||||
return init_fun
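# Usage sketch (assumes a PyTorch version where the deprecated init.normal /
# init.constant aliases used above are still available): Module.apply walks every
# sub-module and calls the returned init_fun on it.
import torch.nn as nn
net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 3, 3))
net.apply(weights_init('gaussian'))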
|
@ -0,0 +1,2 @@
|
||||
from .modules import *
|
||||
from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
|
@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : __init__.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
||||
from .replicate import DataParallelWithCallback, patch_replication_callback
|
@ -0,0 +1,329 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : batchnorm.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
import collections
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
||||
|
||||
from .comm import SyncMaster
|
||||
|
||||
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
||||
|
||||
|
||||
def _sum_ft(tensor):
|
||||
"""sum over the first and last dimention"""
|
||||
return tensor.sum(dim=0).sum(dim=-1)
|
||||
|
||||
|
||||
def _unsqueeze_ft(tensor):
|
||||
"""add new dementions at the front and the tail"""
|
||||
return tensor.unsqueeze(0).unsqueeze(-1)
|
||||
|
||||
|
||||
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
||||
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
||||
|
||||
|
||||
class _SynchronizedBatchNorm(_BatchNorm):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
|
||||
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
||||
|
||||
self._sync_master = SyncMaster(self._data_parallel_master)
|
||||
|
||||
self._is_parallel = False
|
||||
self._parallel_id = None
|
||||
self._slave_pipe = None
|
||||
|
||||
# custom batch norm statistics
|
||||
self._moving_average_fraction = 1. - momentum
|
||||
self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
|
||||
self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
|
||||
self.register_buffer('_running_iter', torch.ones(1))
|
||||
self._tmp_running_mean = self.running_mean.clone() * self._running_iter
|
||||
self._tmp_running_var = self.running_var.clone() * self._running_iter
|
||||
|
||||
def forward(self, input):
|
||||
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
||||
if not (self._is_parallel and self.training):
|
||||
return F.batch_norm(
|
||||
input, self.running_mean, self.running_var, self.weight, self.bias,
|
||||
self.training, self.momentum, self.eps)
|
||||
|
||||
# Resize the input to (B, C, -1).
|
||||
input_shape = input.size()
|
||||
input = input.view(input.size(0), self.num_features, -1)
|
||||
|
||||
# Compute the sum and square-sum.
|
||||
sum_size = input.size(0) * input.size(2)
|
||||
input_sum = _sum_ft(input)
|
||||
input_ssum = _sum_ft(input ** 2)
|
||||
|
||||
# Reduce-and-broadcast the statistics.
|
||||
if self._parallel_id == 0:
|
||||
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||
else:
|
||||
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
||||
|
||||
# Compute the output.
|
||||
if self.affine:
|
||||
# MJY:: Fuse the multiplication for speed.
|
||||
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
||||
else:
|
||||
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
||||
|
||||
# Reshape it.
|
||||
return output.view(input_shape)
|
||||
|
||||
def __data_parallel_replicate__(self, ctx, copy_id):
|
||||
self._is_parallel = True
|
||||
self._parallel_id = copy_id
|
||||
|
||||
# parallel_id == 0 means master device.
|
||||
if self._parallel_id == 0:
|
||||
ctx.sync_master = self._sync_master
|
||||
else:
|
||||
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
||||
|
||||
def _data_parallel_master(self, intermediates):
|
||||
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
||||
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
||||
|
||||
to_reduce = [i[1][:2] for i in intermediates]
|
||||
to_reduce = [j for i in to_reduce for j in i] # flatten
|
||||
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
||||
|
||||
sum_size = sum([i[1].sum_size for i in intermediates])
|
||||
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
||||
|
||||
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
||||
|
||||
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
||||
|
||||
outputs = []
|
||||
for i, rec in enumerate(intermediates):
|
||||
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
||||
|
||||
return outputs
|
||||
|
||||
def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
|
||||
"""return *dest* by `dest := dest*alpha + delta*beta + bias`"""
|
||||
return dest * alpha + delta * beta + bias
|
||||
|
||||
def _compute_mean_std(self, sum_, ssum, size):
|
||||
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
||||
also maintains the moving average on the master device."""
|
||||
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
||||
mean = sum_ / size
|
||||
sumvar = ssum - sum_ * mean
|
||||
unbias_var = sumvar / (size - 1)
|
||||
bias_var = sumvar / size
|
||||
|
||||
self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
|
||||
self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
|
||||
self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)
|
||||
|
||||
self.running_mean = self._tmp_running_mean / self._running_iter
|
||||
self.running_var = self._tmp_running_var / self._running_iter
|
||||
|
||||
return mean, bias_var.clamp(self.eps) ** -0.5
|
||||
|
||||
|
||||
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
||||
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
||||
mini-batch.
|
||||
|
||||
.. math::
|
||||
|
||||
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||
|
||||
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
||||
standard-deviation are reduced across all devices during training.
|
||||
|
||||
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||
training, PyTorch's implementation normalizes the tensor on each device using
|
||||
the statistics only on that device, which accelerates the computation and
|
||||
is also easy to implement, but the statistics might be inaccurate.
|
||||
Instead, in this synchronized version, the statistics will be computed
|
||||
over all training samples distributed on multiple devices.
|
||||
|
||||
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
|
||||
as the built-in PyTorch implementation.
|
||||
|
||||
The mean and standard-deviation are calculated per-dimension over
|
||||
the mini-batches and gamma and beta are learnable parameter vectors
|
||||
of size C (where C is the input size).
|
||||
|
||||
During training, this layer keeps a running estimate of its computed mean
|
||||
and variance. The running sum is kept with a default momentum of 0.1.
|
||||
|
||||
During evaluation, this running mean/variance is used for normalization.
|
||||
|
||||
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
||||
|
||||
Args:
|
||||
num_features: num_features from an expected input of size
|
||||
`batch_size x num_features [x width]`
|
||||
eps: a value added to the denominator for numerical stability.
|
||||
Default: 1e-5
|
||||
momentum: the value used for the running_mean and running_var
|
||||
computation. Default: 0.1
|
||||
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||
affine parameters. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
||||
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
||||
|
||||
Examples:
|
||||
>>> # With Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm1d(100)
|
||||
>>> # Without Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
||||
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
||||
>>> output = m(input)
|
||||
"""
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
if input.dim() != 2 and input.dim() != 3:
|
||||
raise ValueError('expected 2D or 3D input (got {}D input)'
|
||||
.format(input.dim()))
|
||||
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
||||
|
||||
|
||||
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
||||
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
||||
of 3d inputs
|
||||
|
||||
.. math::
|
||||
|
||||
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||
|
||||
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
||||
standard-deviation are reduced across all devices during training.
|
||||
|
||||
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||
training, PyTorch's implementation normalizes the tensor on each device using
|
||||
the statistics only on that device, which accelerates the computation and
|
||||
is also easy to implement, but the statistics might be inaccurate.
|
||||
Instead, in this synchronized version, the statistics will be computed
|
||||
over all training samples distributed on multiple devices.
|
||||
|
||||
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
|
||||
as the built-in PyTorch implementation.
|
||||
|
||||
The mean and standard-deviation are calculated per-dimension over
|
||||
the mini-batches and gamma and beta are learnable parameter vectors
|
||||
of size C (where C is the input size).
|
||||
|
||||
During training, this layer keeps a running estimate of its computed mean
|
||||
and variance. The running sum is kept with a default momentum of 0.1.
|
||||
|
||||
During evaluation, this running mean/variance is used for normalization.
|
||||
|
||||
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
||||
|
||||
Args:
|
||||
num_features: num_features from an expected input of
|
||||
size batch_size x num_features x height x width
|
||||
eps: a value added to the denominator for numerical stability.
|
||||
Default: 1e-5
|
||||
momentum: the value used for the running_mean and running_var
|
||||
computation. Default: 0.1
|
||||
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||
affine parameters. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C, H, W)`
|
||||
- Output: :math:`(N, C, H, W)` (same shape as input)
|
||||
|
||||
Examples:
|
||||
>>> # With Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm2d(100)
|
||||
>>> # Without Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
||||
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
||||
>>> output = m(input)
|
||||
"""
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
if input.dim() != 4:
|
||||
raise ValueError('expected 4D input (got {}D input)'
|
||||
.format(input.dim()))
|
||||
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
||||
|
||||
|
||||
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
||||
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
||||
of 4d inputs
|
||||
|
||||
.. math::
|
||||
|
||||
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
||||
|
||||
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
||||
standard-deviation are reduced across all devices during training.
|
||||
|
||||
For example, when one uses `nn.DataParallel` to wrap the network during
|
||||
training, PyTorch's implementation normalizes the tensor on each device using
|
||||
the statistics only on that device, which accelerates the computation and
|
||||
is also easy to implement, but the statistics might be inaccurate.
|
||||
Instead, in this synchronized version, the statistics will be computed
|
||||
over all training samples distributed on multiple devices.
|
||||
|
||||
Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
|
||||
as the built-in PyTorch implementation.
|
||||
|
||||
The mean and standard-deviation are calculated per-dimension over
|
||||
the mini-batches and gamma and beta are learnable parameter vectors
|
||||
of size C (where C is the input size).
|
||||
|
||||
During training, this layer keeps a running estimate of its computed mean
|
||||
and variance. The running sum is kept with a default momentum of 0.1.
|
||||
|
||||
During evaluation, this running mean/variance is used for normalization.
|
||||
|
||||
Because the BatchNorm is done over the `C` dimension, computing statistics
|
||||
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
||||
or Spatio-temporal BatchNorm
|
||||
|
||||
Args:
|
||||
num_features: num_features from an expected input of
|
||||
size batch_size x num_features x depth x height x width
|
||||
eps: a value added to the denominator for numerical stability.
|
||||
Default: 1e-5
|
||||
momentum: the value used for the running_mean and running_var
|
||||
computation. Default: 0.1
|
||||
affine: a boolean value that when set to ``True``, gives the layer learnable
|
||||
affine parameters. Default: ``True``
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(N, C, D, H, W)`
|
||||
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
||||
|
||||
Examples:
|
||||
>>> # With Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm3d(100)
|
||||
>>> # Without Learnable Parameters
|
||||
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
||||
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
||||
>>> output = m(input)
|
||||
"""
|
||||
|
||||
def _check_input_dim(self, input):
|
||||
if input.dim() != 5:
|
||||
raise ValueError('expected 5D input (got {}D input)'
|
||||
.format(input.dim()))
|
||||
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
@ -0,0 +1,131 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : comm.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
import queue
|
||||
import collections
|
||||
import threading
|
||||
|
||||
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
||||
|
||||
|
||||
class FutureResult(object):
|
||||
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
||||
|
||||
def __init__(self):
|
||||
self._result = None
|
||||
self._lock = threading.Lock()
|
||||
self._cond = threading.Condition(self._lock)
|
||||
|
||||
def put(self, result):
|
||||
with self._lock:
|
||||
assert self._result is None, 'Previous result hasn\'t been fetched.'
|
||||
self._result = result
|
||||
self._cond.notify()
|
||||
|
||||
def get(self):
|
||||
with self._lock:
|
||||
if self._result is None:
|
||||
self._cond.wait()
|
||||
|
||||
res = self._result
|
||||
self._result = None
|
||||
return res
|
||||
|
||||
|
||||
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
||||
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
||||
|
||||
|
||||
class SlavePipe(_SlavePipeBase):
|
||||
"""Pipe for master-slave communication."""
|
||||
|
||||
def run_slave(self, msg):
|
||||
self.queue.put((self.identifier, msg))
|
||||
ret = self.result.get()
|
||||
self.queue.put(True)
|
||||
return ret
|
||||
|
||||
|
||||
class SyncMaster(object):
|
||||
"""An abstract `SyncMaster` object.
|
||||
|
||||
- During the replication, as data parallel will trigger a callback on each module, all slave devices should
|
||||
call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master.
|
||||
- During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected,
|
||||
and passed to a registered callback.
|
||||
- After receiving the messages, the master device should gather the information and determine the message to be passed
|
||||
back to each slave device.
|
||||
"""
|
||||
|
||||
def __init__(self, master_callback):
|
||||
"""
|
||||
|
||||
Args:
|
||||
master_callback: a callback to be invoked after having collected messages from slave devices.
|
||||
"""
|
||||
self._master_callback = master_callback
|
||||
self._queue = queue.Queue()
|
||||
self._registry = collections.OrderedDict()
|
||||
self._activated = False
|
||||
|
||||
def register_slave(self, identifier):
|
||||
"""
|
||||
Register a slave device.
|
||||
|
||||
Args:
|
||||
identifier: an identifier, usually the device id.
|
||||
|
||||
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
||||
|
||||
"""
|
||||
if self._activated:
|
||||
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
||||
self._activated = False
|
||||
self._registry.clear()
|
||||
future = FutureResult()
|
||||
self._registry[identifier] = _MasterRegistry(future)
|
||||
return SlavePipe(identifier, self._queue, future)
|
||||
|
||||
def run_master(self, master_msg):
|
||||
"""
|
||||
Main entry for the master device in each forward pass.
|
||||
The messages are first collected from each device (including the master device), and then
|
||||
a callback will be invoked to compute the message to be sent back to each device
|
||||
(including the master device).
|
||||
|
||||
Args:
|
||||
master_msg: the message that the master wants to send to itself. This will be placed as the first
|
||||
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
||||
|
||||
Returns: the message to be sent back to the master device.
|
||||
|
||||
"""
|
||||
self._activated = True
|
||||
|
||||
intermediates = [(0, master_msg)]
|
||||
for i in range(self.nr_slaves):
|
||||
intermediates.append(self._queue.get())
|
||||
|
||||
results = self._master_callback(intermediates)
|
||||
assert results[0][0] == 0, 'The first result should belong to the master.'
|
||||
|
||||
for i, res in results:
|
||||
if i == 0:
|
||||
continue
|
||||
self._registry[i].result.put(res)
|
||||
|
||||
for i in range(self.nr_slaves):
|
||||
assert self._queue.get() is True
|
||||
|
||||
return results[0][1]
|
||||
|
||||
@property
|
||||
def nr_slaves(self):
|
||||
return len(self._registry)
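# Single-process sketch of the handshake (threads stand in for GPU replicas; the
# callback is a toy that just increments every message).
import threading

def _toy_callback(intermediates):
    # intermediates: [(identifier, message), ...], master's message first
    return [(idx, msg + 1) for idx, msg in intermediates]

master = SyncMaster(_toy_callback)
pipe = master.register_slave(1)
worker = threading.Thread(target=lambda: print('slave received', pipe.run_slave(10)))  # -> 11
worker.start()
print('master received', master.run_master(0))  # -> 1
worker.join()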
|
@ -0,0 +1,94 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : replicate.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
import functools
|
||||
|
||||
from torch.nn.parallel.data_parallel import DataParallel
|
||||
|
||||
__all__ = [
|
||||
'CallbackContext',
|
||||
'execute_replication_callbacks',
|
||||
'DataParallelWithCallback',
|
||||
'patch_replication_callback'
|
||||
]
|
||||
|
||||
|
||||
class CallbackContext(object):
|
||||
pass
|
||||
|
||||
|
||||
def execute_replication_callbacks(modules):
|
||||
"""
|
||||
Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
|
||||
|
||||
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
||||
|
||||
Note that, as all replicated modules are isomorphic, we assign each sub-module a context
|
||||
(shared among multiple copies of this module on different devices).
|
||||
Through this context, different copies can share some information.
|
||||
|
||||
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
||||
of any slave copies.
|
||||
"""
|
||||
master_copy = modules[0]
|
||||
nr_modules = len(list(master_copy.modules()))
|
||||
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
||||
|
||||
for i, module in enumerate(modules):
|
||||
for j, m in enumerate(module.modules()):
|
||||
if hasattr(m, '__data_parallel_replicate__'):
|
||||
m.__data_parallel_replicate__(ctxs[j], i)
|
||||
|
||||
|
||||
class DataParallelWithCallback(DataParallel):
|
||||
"""
|
||||
Data Parallel with a replication callback.
|
||||
|
||||
A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
||||
original `replicate` function.
|
||||
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
||||
|
||||
Examples:
|
||||
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
||||
# sync_bn.__data_parallel_replicate__ will be invoked.
|
||||
"""
|
||||
|
||||
def replicate(self, module, device_ids):
|
||||
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
||||
execute_replication_callbacks(modules)
|
||||
return modules
|
||||
|
||||
|
||||
def patch_replication_callback(data_parallel):
|
||||
"""
|
||||
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
||||
Useful when you have a customized `DataParallel` implementation.
|
||||
|
||||
Examples:
|
||||
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
||||
> patch_replication_callback(sync_bn)
|
||||
# this is equivalent to
|
||||
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
||||
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
||||
"""
|
||||
|
||||
assert isinstance(data_parallel, DataParallel)
|
||||
|
||||
old_replicate = data_parallel.replicate
|
||||
|
||||
@functools.wraps(old_replicate)
|
||||
def new_replicate(module, device_ids):
|
||||
modules = old_replicate(module, device_ids)
|
||||
execute_replication_callbacks(modules)
|
||||
return modules
|
||||
|
||||
data_parallel.replicate = new_replicate
|
@ -0,0 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# File : unittest.py
|
||||
# Author : Jiayuan Mao
|
||||
# Email : maojiayuan@gmail.com
|
||||
# Date : 27/01/2018
|
||||
#
|
||||
# This file is part of Synchronized-BatchNorm-PyTorch.
|
||||
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
||||
# Distributed under MIT License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from torch.autograd import Variable
|
||||
|
||||
|
||||
def as_numpy(v):
|
||||
if isinstance(v, Variable):
|
||||
v = v.data
|
||||
return v.cpu().numpy()
|
||||
|
||||
|
||||
class TorchTestCase(unittest.TestCase):
|
||||
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
||||
npa, npb = as_numpy(a), as_numpy(b)
|
||||
self.assertTrue(
|
||||
np.allclose(npa, npb, atol=atol),
|
||||
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
||||
)
|
@ -0,0 +1 @@
|
||||
from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
|
@ -0,0 +1,112 @@
|
||||
# -*- coding: utf8 -*-
|
||||
|
||||
import torch.cuda as cuda
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import collections
|
||||
from torch.nn.parallel._functions import Gather
|
||||
|
||||
|
||||
__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']
|
||||
|
||||
|
||||
def async_copy_to(obj, dev, main_stream=None):
|
||||
if torch.is_tensor(obj):
|
||||
v = obj.cuda(dev, non_blocking=True)
|
||||
if main_stream is not None:
|
||||
v.data.record_stream(main_stream)
|
||||
return v
|
||||
elif isinstance(obj, collections.abc.Mapping):
|
||||
return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
|
||||
elif isinstance(obj, collections.abc.Sequence):
|
||||
return [async_copy_to(o, dev, main_stream) for o in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
def dict_gather(outputs, target_device, dim=0):
|
||||
"""
|
||||
Gathers variables from different GPUs on a specified device
|
||||
(-1 means the CPU), with dictionary support.
|
||||
"""
|
||||
def gather_map(outputs):
|
||||
out = outputs[0]
|
||||
if torch.is_tensor(out):
|
||||
# MJY(20180330) HACK:: force nr_dims > 0
|
||||
if out.dim() == 0:
|
||||
outputs = [o.unsqueeze(0) for o in outputs]
|
||||
return Gather.apply(target_device, dim, *outputs)
|
||||
elif out is None:
|
||||
return None
|
||||
elif isinstance(out, collections.abc.Mapping):
|
||||
return {k: gather_map([o[k] for o in outputs]) for k in out}
|
||||
elif isinstance(out, collections.abc.Sequence):
|
||||
return type(out)(map(gather_map, zip(*outputs)))
|
||||
return gather_map(outputs)
|
||||
|
||||
|
||||
class DictGatherDataParallel(nn.DataParallel):
|
||||
def gather(self, outputs, output_device):
|
||||
return dict_gather(outputs, output_device, dim=self.dim)
|
||||
|
||||
|
||||
class UserScatteredDataParallel(DictGatherDataParallel):
|
||||
def scatter(self, inputs, kwargs, device_ids):
|
||||
assert len(inputs) == 1
|
||||
inputs = inputs[0]
|
||||
inputs = _async_copy_stream(inputs, device_ids)
|
||||
inputs = [[i] for i in inputs]
|
||||
assert len(kwargs) == 0
|
||||
kwargs = [{} for _ in range(len(inputs))]
|
||||
|
||||
return inputs, kwargs
|
||||
|
||||
|
||||
def user_scattered_collate(batch):
|
||||
return batch
|
||||
|
||||
|
||||
def _async_copy(inputs, device_ids):
|
||||
nr_devs = len(device_ids)
|
||||
assert type(inputs) in (tuple, list)
|
||||
assert len(inputs) == nr_devs
|
||||
|
||||
outputs = []
|
||||
for i, dev in zip(inputs, device_ids):
|
||||
with cuda.device(dev):
|
||||
outputs.append(async_copy_to(i, dev))
|
||||
|
||||
return tuple(outputs)
|
||||
|
||||
|
||||
def _async_copy_stream(inputs, device_ids):
|
||||
nr_devs = len(device_ids)
|
||||
assert type(inputs) in (tuple, list)
|
||||
assert len(inputs) == nr_devs
|
||||
|
||||
outputs = []
|
||||
streams = [_get_stream(d) for d in device_ids]
|
||||
for i, dev, stream in zip(inputs, device_ids, streams):
|
||||
with cuda.device(dev):
|
||||
main_stream = cuda.current_stream()
|
||||
with cuda.stream(stream):
|
||||
outputs.append(async_copy_to(i, dev, main_stream=main_stream))
|
||||
main_stream.wait_stream(stream)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
"""Adapted from: torch/nn/parallel/_functions.py"""
|
||||
# background streams used for copying
|
||||
_streams = None
|
||||
|
||||
|
||||
def _get_stream(device):
|
||||
"""Gets a background stream for copying between CPU and GPU"""
|
||||
global _streams
|
||||
if device == -1:
|
||||
return None
|
||||
if _streams is None:
|
||||
_streams = [None] * cuda.device_count()
|
||||
if _streams[device] is None: _streams[device] = cuda.Stream(device)
|
||||
return _streams[device]
|
@ -0,0 +1,266 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from builtins import *
|
||||
|
||||
def resize_like(x, target, mode='bilinear'):
|
||||
return F.interpolate(x, target.shape[-2:], mode=mode, align_corners=False)
|
||||
|
||||
def get_norm(name, out_channels):
|
||||
if name == 'batch':
|
||||
norm = nn.BatchNorm2d(out_channels)
|
||||
elif name == 'instance':
|
||||
norm = nn.InstanceNorm2d(out_channels)
|
||||
else:
|
||||
norm = None
|
||||
return norm
|
||||
|
||||
|
||||
def get_activation(name):
|
||||
if name == 'relu':
|
||||
activation = nn.ReLU()
|
||||
elif name == 'elu':
|
||||
activation = nn.ELU()
|
||||
elif name == 'leaky_relu':
|
||||
activation = nn.LeakyReLU(negative_slope=0.2)
|
||||
elif name == 'tanh':
|
||||
activation = nn.Tanh()
|
||||
elif name == 'sigmoid':
|
||||
activation = nn.Sigmoid()
|
||||
else:
|
||||
activation = None
|
||||
return activation
|
||||
|
||||
|
||||
class Conv2dSame(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
||||
super().__init__()
|
||||
|
||||
padding = self.conv_same_pad(kernel_size, stride)
|
||||
if type(padding) is not tuple:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size, stride, padding)
|
||||
else:
|
||||
self.conv = nn.Sequential(
|
||||
nn.ConstantPad2d(padding*2, 0),
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size, stride, 0)
|
||||
)
|
||||
|
||||
def conv_same_pad(self, ksize, stride):
|
||||
if (ksize - stride) % 2 == 0:
|
||||
return (ksize - stride) // 2
|
||||
else:
|
||||
left = (ksize - stride) // 2
|
||||
right = left + 1
|
||||
return left, right
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
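# Quick shape check (sizes are illustrative): the asymmetric padding keeps
# out_size == ceil(in_size / stride), i.e. TensorFlow-style 'same' behaviour.
import torch
conv = Conv2dSame(3, 8, kernel_size=5, stride=2)
print(conv(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 16, 16])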
|
||||
|
||||
|
||||
class ConvTranspose2dSame(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
||||
super().__init__()
|
||||
|
||||
padding, output_padding = self.deconv_same_pad(kernel_size, stride)
|
||||
self.trans_conv = nn.ConvTranspose2d(
|
||||
in_channels, out_channels, kernel_size, stride,
|
||||
padding, output_padding)
|
||||
|
||||
def deconv_same_pad(self, ksize, stride):
|
||||
pad = (ksize - stride + 1) // 2
|
||||
outpad = 2 * pad + stride - ksize
|
||||
return pad, outpad
|
||||
|
||||
def forward(self, x):
|
||||
return self.trans_conv(x)
|
||||
|
||||
|
||||
class UpBlock(nn.Module):
|
||||
|
||||
def __init__(self, mode='nearest', scale=2, channel=None, kernel_size=4):
|
||||
super().__init__()
|
||||
|
||||
self.mode = mode
|
||||
if mode == 'deconv':
|
||||
self.up = ConvTranspose2dSame(
|
||||
channel, channel, kernel_size, stride=scale)
|
||||
else:
|
||||
def upsample(x):
|
||||
return F.interpolate(x, scale_factor=scale, mode=mode)
|
||||
self.up = upsample
|
||||
|
||||
def forward(self, x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class EncodeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, in_channels, out_channels, kernel_size, stride,
|
||||
normalization=None, activation=None):
|
||||
super().__init__()
|
||||
|
||||
self.c_in = in_channels
|
||||
self.c_out = out_channels
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
Conv2dSame(self.c_in, self.c_out, kernel_size, stride))
|
||||
if normalization:
|
||||
layers.append(get_norm(normalization, self.c_out))
|
||||
if activation:
|
||||
layers.append(get_activation(activation))
|
||||
self.encode = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return self.encode(x)
|
||||
|
||||
|
||||
class DecodeBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, c_from_up, c_from_down, c_out, mode='nearest',
|
||||
kernel_size=4, scale=2, normalization='batch', activation='relu'):
|
||||
super().__init__()
|
||||
|
||||
self.c_from_up = c_from_up
|
||||
self.c_from_down = c_from_down
|
||||
self.c_in = c_from_up + c_from_down
|
||||
self.c_out = c_out
|
||||
|
||||
self.up = UpBlock(mode, scale, c_from_up, kernel_size=scale)
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
Conv2dSame(self.c_in, self.c_out, kernel_size, stride=1))
|
||||
if normalization:
|
||||
layers.append(get_norm(normalization, self.c_out))
|
||||
if activation:
|
||||
layers.append(get_activation(activation))
|
||||
self.decode = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x, concat=None):
|
||||
out = self.up(x)
|
||||
if self.c_from_down > 0:
|
||||
out = torch.cat([out, concat], dim=1)
|
||||
out = self.decode(out)
|
||||
return out
|
||||
|
||||
|
||||
class BlendBlock(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, c_in, c_out, ksize_mid=3, norm='batch', act='leaky_relu'):
|
||||
super().__init__()
|
||||
c_mid = max(c_in // 2, 32)
|
||||
self.blend = nn.Sequential(
|
||||
Conv2dSame(c_in, c_mid, 1, 1),
|
||||
get_norm(norm, c_mid),
|
||||
get_activation(act),
|
||||
Conv2dSame(c_mid, c_out, ksize_mid, 1),
|
||||
get_norm(norm, c_out),
|
||||
get_activation(act),
|
||||
Conv2dSame(c_out, c_out, 1, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.blend(x)
|
||||
|
||||
|
||||
class FusionBlock(nn.Module):
|
||||
def __init__(self, c_feat, c_alpha=1):
|
||||
super().__init__()
|
||||
c_img = 3
|
||||
self.map2img = nn.Sequential(
|
||||
Conv2dSame(c_feat, c_img, 1, 1),
|
||||
nn.Sigmoid())
|
||||
self.blend = BlendBlock(c_img*2, c_alpha)
|
||||
|
||||
def forward(self, img_miss, feat_de):
|
||||
img_miss = resize_like(img_miss, feat_de)
|
||||
raw = self.map2img(feat_de)
|
||||
alpha = self.blend(torch.cat([img_miss, raw], dim=1))
|
||||
result = alpha * raw + (1 - alpha) * img_miss
|
||||
return result, alpha, raw
|
||||
|
||||
|
||||
class DFNet(nn.Module):
|
||||
def __init__(
|
||||
self, c_img=3, c_mask=1, c_alpha=3,
|
||||
mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
|
||||
en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8,
|
||||
blend_layers=[0, 1, 2, 3, 4, 5]):
|
||||
super().__init__()
|
||||
|
||||
c_init = c_img + c_mask
|
||||
|
||||
self.n_en = len(en_ksize)
|
||||
self.n_de = len(de_ksize)
|
||||
assert self.n_en == self.n_de, (
|
||||
'The number of layers in the Encoder and Decoder must be equal.')
|
||||
assert self.n_en >= 1, (
|
||||
'The number of layers in the Encoder and Decoder must be at least 1.')
|
||||
|
||||
assert 0 in blend_layers, 'Layer 0 must be blended.'
|
||||
|
||||
self.en = []
|
||||
c_in = c_init
|
||||
self.en.append(
|
||||
EncodeBlock(c_in, 64, en_ksize[0], 2, None, None))
|
||||
for k_en in en_ksize[1:]:
|
||||
c_in = self.en[-1].c_out
|
||||
c_out = min(c_in*2, 512)
|
||||
self.en.append(EncodeBlock(
|
||||
c_in, c_out, k_en, stride=2,
|
||||
normalization=norm, activation=act_en))
|
||||
|
||||
# register parameters
|
||||
for i, en in enumerate(self.en):
|
||||
self.__setattr__('en_{}'.format(i), en)
|
||||
|
||||
self.de = []
|
||||
self.fuse = []
|
||||
for i, k_de in enumerate(de_ksize):
|
||||
|
||||
c_from_up = self.en[-1].c_out if i == 0 else self.de[-1].c_out
|
||||
c_out = c_from_down = self.en[-i-1].c_in
|
||||
layer_idx = self.n_de - i - 1
|
||||
|
||||
self.de.append(DecodeBlock(
|
||||
c_from_up, c_from_down, c_out, mode, k_de, scale=2,
|
||||
normalization=norm, activation=act_de))
|
||||
if layer_idx in blend_layers:
|
||||
self.fuse.append(FusionBlock(c_out, c_alpha))
|
||||
else:
|
||||
self.fuse.append(None)
|
||||
|
||||
# register parameters
|
||||
for i, de in enumerate(self.de[::-1]):
|
||||
self.__setattr__('de_{}'.format(i), de)
|
||||
for i, fuse in enumerate(self.fuse[::-1]):
|
||||
if fuse:
|
||||
self.__setattr__('fuse_{}'.format(i), fuse)
|
||||
|
||||
def forward(self, img_miss, mask):
|
||||
|
||||
out = torch.cat([img_miss, mask], dim=1)
|
||||
|
||||
out_en = [out]
|
||||
for encode in self.en:
|
||||
out = encode(out)
|
||||
out_en.append(out)
|
||||
|
||||
results = []
|
||||
alphas = []
|
||||
raws = []
|
||||
for i, (decode, fuse) in enumerate(zip(self.de, self.fuse)):
|
||||
out = decode(out, out_en[-i-2])
|
||||
if fuse:
|
||||
result, alpha, raw = fuse(img_miss, out)
|
||||
results.append(result)
|
||||
|
||||
return results[::-1]
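# Minimal forward-pass sketch (input sizes are assumptions; the eight stride-2
# encoders need spatial dims divisible by 2**8). One fused prediction is returned
# per blend layer, finest resolution first.
import torch
net = DFNet().eval()
img_miss = torch.randn(1, 3, 256, 256)   # masked input image
mask = torch.ones(1, 1, 256, 256)        # binary mask, same spatial size
with torch.no_grad():
    results = net(img_miss, mask)
print(len(results), results[0].shape)    # 6 outputs, first is (1, 3, 256, 256)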
|
@ -0,0 +1,71 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from DFNet_core import get_norm, get_activation, Conv2dSame, ConvTranspose2dSame, UpBlock, EncodeBlock, DecodeBlock
|
||||
from builtins import *
|
||||
|
||||
|
||||
class RefinementNet(nn.Module):
|
||||
def __init__(
|
||||
self, c_img=19, c_mask=1,
|
||||
mode='nearest', norm='batch', act_en='relu', act_de='leaky_relu',
|
||||
en_ksize=[7, 5, 5, 3, 3, 3, 3, 3], de_ksize=[3]*8):
|
||||
super(RefinementNet, self).__init__()
|
||||
|
||||
c_in = c_img + c_mask
|
||||
|
||||
self.en1 = EncodeBlock(c_in, 96, en_ksize[0], 2, None, None)
|
||||
self.en2 = EncodeBlock(96, 192, en_ksize[1], stride=2, normalization=norm, activation=act_en)
|
||||
self.en3 = EncodeBlock(192, 384, en_ksize[2], stride=2, normalization=norm, activation=act_en)
|
||||
self.en4 = EncodeBlock(384, 512, en_ksize[3], stride=2, normalization=norm, activation=act_en)
|
||||
self.en5 = EncodeBlock(512, 512, en_ksize[4], stride=2, normalization=norm, activation=act_en)
|
||||
self.en6 = EncodeBlock(512, 512, en_ksize[5], stride=2, normalization=norm, activation=act_en)
|
||||
self.en7 = EncodeBlock(512, 512, en_ksize[6], stride=2, normalization=norm, activation=act_en)
|
||||
self.en8 = EncodeBlock(512, 512, en_ksize[7], stride=2, normalization=norm, activation=act_en)
|
||||
|
||||
self.de1 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de2 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de3 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de4 = DecodeBlock(512, 512, 512, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de5 = DecodeBlock(512, 384, 384, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de6 = DecodeBlock(384, 192, 192, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de7 = DecodeBlock(192, 96, 96, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
self.de8 = DecodeBlock(96, 20, 20, mode, 3, scale=2,normalization=norm, activation=act_de)
|
||||
|
||||
self.last_conv = nn.Sequential(Conv2dSame(c_in, 3, 1, 1), nn.Sigmoid())
|
||||
|
||||
def forward(self, img, mask):
|
||||
out = torch.cat([mask, img], dim=1)
|
||||
out_en = [out]
|
||||
|
||||
out = self.en1(out)
|
||||
out_en.append(out)
|
||||
out = self.en2(out)
|
||||
out_en.append(out)
|
||||
out = self.en3(out)
|
||||
out_en.append(out)
|
||||
out = self.en4(out)
|
||||
out_en.append(out)
|
||||
out = self.en5(out)
|
||||
out_en.append(out)
|
||||
out = self.en6(out)
|
||||
out_en.append(out)
|
||||
out = self.en7(out)
|
||||
out_en.append(out)
|
||||
out = self.en8(out)
|
||||
out_en.append(out)
|
||||
|
||||
|
||||
out = self.de1(out, out_en[-0-2])
|
||||
out = self.de2(out, out_en[-1-2])
|
||||
out = self.de3(out, out_en[-2-2])
|
||||
out = self.de4(out, out_en[-3-2])
|
||||
out = self.de5(out, out_en[-4-2])
|
||||
out = self.de6(out, out_en[-5-2])
|
||||
out = self.de7(out, out_en[-6-2])
|
||||
out = self.de8(out, out_en[-7-2])
|
||||
|
||||
output = self.last_conv(out)
|
||||
|
||||
output = mask * output + (1 - mask) * img[:, :3]
|
||||
|
||||
return output
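# Forward-pass sketch (sizes are assumptions). The 19-channel input carries the RGB
# image in its first three channels, which the final compositing step reuses via
# img[:, :3]; the remaining channels are whatever guidance the caller concatenates.
import torch
refiner = RefinementNet().eval()
img = torch.randn(1, 19, 256, 256)
mask = torch.ones(1, 1, 256, 256)
with torch.no_grad():
    refined = refiner(img, mask)
print(refined.shape)  # torch.Size([1, 3, 256, 256])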
|
@ -0,0 +1,73 @@
|
||||
'''
|
||||
Minor Modification from https://github.com/SaoYan/DnCNN-PyTorch SaoYan
|
||||
Re-implemented by Yuqian Zhou
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DnCNN(nn.Module):
|
||||
'''
|
||||
Original DnCNN model without input conditions
|
||||
'''
|
||||
def __init__(self, channels, num_of_layers=17):
|
||||
super(DnCNN, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
def forward(self, input_x):
|
||||
out = self.dncnn(input_x)
|
||||
return out
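# Residual-learning sketch (shapes are illustrative): the network predicts the noise,
# so the clean estimate is the input minus the output, exactly as the denoiser()
# helper later in this diff computes Out = clamp(INoisy - Res, 0., 1.).
import torch
model = DnCNN(channels=3, num_of_layers=17).eval()
noisy = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    denoised = torch.clamp(noisy - model(noisy), 0., 1.)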
|
||||
|
||||
|
||||
class Estimation_direct(nn.Module):
|
||||
'''
|
||||
Noise estimator, with original 3 layers
|
||||
'''
|
||||
def __init__(self, input_channels = 1, output_channels = 3, num_of_layers=3):
|
||||
super(Estimation_direct, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=input_channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.dncnn(input)
|
||||
return x
|
||||
|
||||
|
||||
class DnCNN_c(nn.Module):
|
||||
def __init__(self, channels, num_of_layers=17, num_of_est=3):
|
||||
super(DnCNN_c, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=channels+ num_of_est, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
def forward(self, x, c):
|
||||
input_x = torch.cat([x, c], dim=1)
|
||||
out = self.dncnn(input_x)
|
||||
return out
|
@ -0,0 +1,138 @@
|
||||
import cv2
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Variable
|
||||
from models_denoise import *
|
||||
from utils import *
|
||||
from PIL import Image
|
||||
import scipy.io as sio
|
||||
|
||||
#the limitation range of each type of noise level: [0]Gaussian [1]Impulse
|
||||
limit_set = [[0,75], [0, 80]]
|
||||
|
||||
def img_normalize(data):
|
||||
return data/255.
|
||||
|
||||
def denoiser(Img, c, pss, model, model_est, opt, cFlag):
|
||||
|
||||
w, h, _ = Img.shape
|
||||
Img = pixelshuffle(Img, pss)
|
||||
Img = img_normalize(np.float32(Img))
|
||||
|
||||
noise_level_list = np.zeros((2 * c,1)) #two noise types with c channels each
|
||||
if opt.cond == 0: #if we use the ground truth of noise for denoising, and only one single noise type
|
||||
noise_level_list = np.array(opt.test_noise_level)
|
||||
elif opt.cond == 2: #if we use an external fixed input condition for denoising
|
||||
noise_level_list = np.array(opt.ext_test_noise_level)
|
||||
|
||||
#Clean Image Tensor for evaluation
|
||||
ISource = np2ts(Img)
|
||||
# noisy image and true residual
|
||||
if opt.real_n == 0 and opt.spat_n == 0: #no spatial noise setting, and synthetic noise
|
||||
noisy_img = generate_comp_noisy(Img, np.array(opt.test_noise_level) / 255.)
|
||||
if opt.color == 0:
|
||||
noisy_img = np.expand_dims(noisy_img[:,:,0], 2)
|
||||
elif opt.real_n == 1 or opt.real_n == 2: #testing real noisy images
|
||||
noisy_img = Img
|
||||
elif opt.spat_n == 1:
|
||||
noisy_img = generate_noisy(Img, 2, 0, 20, 40)
|
||||
INoisy = np2ts(noisy_img, opt.color)
|
||||
INoisy = torch.clamp(INoisy, 0., 1.)
|
||||
True_Res = INoisy - ISource
|
||||
if torch.cuda.is_available() and not cFlag:
|
||||
ISource, INoisy, True_Res = Variable(ISource.cuda(),volatile=True), Variable(INoisy.cuda(),volatile=True), Variable(True_Res.cuda(),volatile=True)
|
||||
else:
|
||||
ISource, INoisy, True_Res = Variable(ISource,volatile=True), Variable(INoisy,volatile=True), Variable(True_Res,volatile=True)
|
||||
|
||||
|
||||
|
||||
if opt.mode == "MC":
|
||||
# obtain the corresponding input map
|
||||
if opt.cond == 0 or opt.cond == 2: #if we use the ground-truth noise level or a given fixed level
|
||||
#normalize noise level map to [0,1]
|
||||
noise_level_list_n = np.zeros((2*c, 1))
|
||||
print(c)
|
||||
for noise_type in range(2):
|
||||
for chn in range(c):
|
||||
noise_level_list_n[noise_type * c + chn] = normalize(noise_level_list[noise_type * 3 + chn], 1, limit_set[noise_type][0], limit_set[noise_type][1]) #normalize the level value
|
||||
#generate noise maps
|
||||
noise_map = np.zeros((1, 2 * c, Img.shape[0], Img.shape[1])) #initialize the noise map
|
||||
noise_map[0, :, :, :] = np.reshape(np.tile(noise_level_list_n, Img.shape[0] * Img.shape[1]), (2*c, Img.shape[0], Img.shape[1]))
|
||||
NM_tensor = torch.from_numpy(noise_map).type(torch.FloatTensor)
|
||||
NM_tensor = Variable(NM_tensor.cuda(),volatile=True)
|
||||
#use the estimated noise-level map for blind denoising
|
||||
elif opt.cond == 1: #if we use the estimated map directly
|
||||
NM_tensor = torch.clamp(model_est(INoisy), 0., 1.)
|
||||
if opt.refine == 1: #if we need to refine the map before putting it to the denoiser
|
||||
NM_tensor_bundle = level_refine(NM_tensor, opt.refine_opt, 2*c, cFlag) #refine_opt can be max, freq and their average
|
||||
NM_tensor = NM_tensor_bundle[0]
|
||||
noise_estimation_table = np.reshape(NM_tensor_bundle[1], (2 * c,))
|
||||
if opt.zeroout == 1:
|
||||
NM_tensor = zeroing_out_maps(NM_tensor, opt.keep_ind)
|
||||
Res = model(INoisy, NM_tensor)
|
||||
|
||||
elif opt.mode == "B":
|
||||
Res = model(INoisy)
|
||||
|
||||
Out = torch.clamp(INoisy-Res, 0., 1.) #Output image after denoising
|
||||
|
||||
#get the maximum denoising result
|
||||
max_NM_tensor = level_refine(NM_tensor, 1, 2*c, cFlag)[0]
|
||||
max_Res = model(INoisy, max_NM_tensor)
|
||||
max_Out = torch.clamp(INoisy - max_Res, 0., 1.)
|
||||
max_out_numpy = visual_va2np(max_Out, opt.color, opt.ps, pss, 1, opt.rescale, w, h, c)
|
||||
del max_Out
|
||||
del max_Res
|
||||
del max_NM_tensor
|
||||
|
||||
if (opt.ps == 1 or opt.ps == 2) and pss!=1: #pixelshuffle multi-scale
|
||||
#create batch of images with one substitution
|
||||
mosaic_den = visual_va2np(Out, opt.color, 1, pss, 1, opt.rescale, w, h, c)
|
||||
out_numpy = np.zeros((pss ** 2, c, w, h))
|
||||
#compute all the images in the ps scale set
|
||||
for row in range(pss):
|
||||
for column in range(pss):
|
||||
re_test = visual_va2np(Out, opt.color, 1, pss, 1, opt.rescale, w, h, c, 1, visual_va2np(INoisy, opt.color), [row, column])/255.
|
||||
#cv2.imwrite(os.path.join(opt.out_dir,file_name + '_%d_%d.png' % (row, column)), re_test[:,:,::-1]*255.)
|
||||
re_test = np.expand_dims(re_test, 0)
|
||||
if opt.color == 0: #if gray image
|
||||
re_test = np.expand_dims(re_test[:, :, :, 0], 3)
|
||||
re_test_tensor = torch.from_numpy(np.transpose(re_test, (0,3,1,2))).type(torch.FloatTensor)
|
||||
if torch.cuda.is_available() and not cFlag:
|
||||
re_test_tensor = Variable(re_test_tensor.cuda(),volatile=True)
|
||||
else:
|
||||
re_test_tensor = Variable(re_test_tensor, volatile=True)
|
||||
|
||||
re_NM_tensor = torch.clamp(model_est(re_test_tensor), 0., 1.)
|
||||
if opt.refine == 1: #if we need to refine the map before putting it to the denoiser
|
||||
re_NM_tensor_bundle = level_refine(re_NM_tensor, opt.refine_opt, 2*c, cFlag) #refine_opt can be max, freq and their average
|
||||
re_NM_tensor = re_NM_tensor_bundle[0]
|
||||
re_Res = model(re_test_tensor, re_NM_tensor)
|
||||
Out2 = torch.clamp(re_test_tensor - re_Res, 0., 1.)
|
||||
out_numpy[row*pss+column,:,:,:] = Out2.data.cpu().numpy()
|
||||
del Out2
|
||||
del re_Res
|
||||
del re_test_tensor
|
||||
del re_NM_tensor
|
||||
del re_test
|
||||
|
||||
out_numpy = np.mean(out_numpy, 0)
|
||||
out_numpy = np.transpose(out_numpy, (1,2,0)) * 255.0
|
||||
elif opt.ps == 0 or pss==1: #other cases
|
||||
out_numpy = visual_va2np(Out, opt.color, 0, 1, 1, opt.rescale, w, h, c)
|
||||
|
||||
out_numpy = out_numpy.astype(np.float32) #details
|
||||
max_out_numpy = max_out_numpy.astype(np.float32) #background
|
||||
|
||||
#merging the details and background to balance the effect
|
||||
k = opt.k
|
||||
merge_out_numpy = (1-k)*out_numpy + k*max_out_numpy
|
||||
merge_out_numpy = merge_out_numpy.astype(np.float32)
|
||||
|
||||
return merge_out_numpy
|
||||
|
||||
|
@ -0,0 +1,73 @@
|
||||
'''
|
||||
Minor modification of https://github.com/SaoYan/DnCNN-PyTorch (by SaoYan)
|
||||
Re-implemented by Yuqian Zhou
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class DnCNN(nn.Module):
|
||||
'''
|
||||
Original DnCNN model without input conditions
|
||||
'''
|
||||
def __init__(self, channels, num_of_layers=17):
|
||||
super(DnCNN, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
def forward(self, input_x):
|
||||
out = self.dncnn(input_x)
|
||||
return out
|
||||
|
||||
|
||||
class Estimation_direct(nn.Module):
|
||||
'''
|
||||
Noise estimator, with original 3 layers
|
||||
'''
|
||||
def __init__(self, input_channels = 1, output_channels = 3, num_of_layers=3):
|
||||
super(Estimation_direct, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=input_channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, input):
|
||||
x = self.dncnn(input)
|
||||
return x
|
||||
|
||||
|
||||
class DnCNN_c(nn.Module):
|
||||
def __init__(self, channels, num_of_layers=17, num_of_est=3):
|
||||
super(DnCNN_c, self).__init__()
|
||||
kernel_size = 3
|
||||
padding = 1
|
||||
features = 64
|
||||
layers = []
|
||||
layers.append(nn.Conv2d(in_channels=channels+ num_of_est, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
for _ in range(num_of_layers-2):
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
layers.append(nn.BatchNorm2d(features))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
|
||||
self.dncnn = nn.Sequential(*layers)
|
||||
def forward(self, x, c):
|
||||
input_x = torch.cat([x, c], dim=1)
|
||||
out = self.dncnn(input_x)
|
||||
return out
|
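# --- Illustrative usage sketch (not part of the original file) ---
# Minimal smoke test of the three networks above, assuming a 3-channel input
# and a 6-channel noise map (2 noise types x 3 channels) as used by the
# denoiser elsewhere in this commit; all shapes below are placeholders.
if __name__ == "__main__":
    x = torch.rand(1, 3, 64, 64)                      # dummy noisy image
    est = Estimation_direct(3, 6)                     # noise estimator: 3 -> 6 channels
    nm = torch.clamp(est(x), 0., 1.)                  # estimated noise-level map
    dncnn_c = DnCNN_c(channels=3, num_of_layers=17, num_of_est=6)
    residual = dncnn_c(x, nm)                         # conditional model predicts the residual
    denoised = torch.clamp(x - residual, 0., 1.)
    print(nm.shape, residual.shape, denoised.shape)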
@ -0,0 +1,532 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
# from skimage.measure.simple_metrics import compare_psnr
|
||||
from torch.autograd import Variable
|
||||
import cv2
|
||||
import scipy.ndimage
|
||||
import scipy.io as sio
|
||||
# import matplotlib as mpl
|
||||
# mpl.use('Agg')
|
||||
# import matplotlib.pyplot as plt
|
||||
|
||||
def weights_init_kaiming(m):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
|
||||
elif classname.find('Linear') != -1:
|
||||
nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
|
||||
elif classname.find('BatchNorm') != -1:
|
||||
# nn.init.uniform(m.weight.data, 1.0, 0.02)
|
||||
m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
|
||||
nn.init.constant(m.bias.data, 0.0)
|
||||
|
||||
# def batch_PSNR(img, imclean, data_range):
|
||||
# Img = img.data.cpu().numpy().astype(np.float32)
|
||||
# Iclean = imclean.data.cpu().numpy().astype(np.float32)
|
||||
# PSNR = 0
|
||||
# for i in range(Img.shape[0]):
|
||||
# PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
|
||||
# return (PSNR/Img.shape[0])
|
||||
|
||||
def data_augmentation(image, mode):
|
||||
out = np.transpose(image, (1,2,0))
|
||||
if mode == 0:
|
||||
# original
|
||||
out = out
|
||||
elif mode == 1:
|
||||
# flip up and down
|
||||
out = np.flipud(out)
|
||||
elif mode == 2:
|
||||
# rotate counterwise 90 degree
|
||||
out = np.rot90(out)
|
||||
elif mode == 3:
|
||||
# rotate 90 degree and flip up and down
|
||||
out = np.rot90(out)
|
||||
out = np.flipud(out)
|
||||
elif mode == 4:
|
||||
# rotate 180 degree
|
||||
out = np.rot90(out, k=2)
|
||||
elif mode == 5:
|
||||
# rotate 180 degree and flip
|
||||
out = np.rot90(out, k=2)
|
||||
out = np.flipud(out)
|
||||
elif mode == 6:
|
||||
# rotate 270 degree
|
||||
out = np.rot90(out, k=3)
|
||||
elif mode == 7:
|
||||
# rotate 270 degree and flip
|
||||
out = np.rot90(out, k=3)
|
||||
out = np.flipud(out)
|
||||
return np.transpose(out, (2,0,1))
|
||||
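# --- Illustrative sketch (not part of the original file) ---
# data_augmentation() takes a CHW array and applies one of eight flip/rotation
# modes (mode 0 is the identity); with square spatial dims the shape is preserved.
if __name__ == "__main__":
    dummy = np.random.rand(3, 8, 8).astype(np.float32)
    for mode in range(8):
        assert data_augmentation(dummy, mode).shape == dummy.shape
    print("all 8 augmentation modes keep the CHW shape")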
|
||||
|
||||
def visual_va2np(Out, mode=1, ps=0, pss=1, scal=1, rescale=0, w=10, h=10, c=3, refill=0, refill_img=0, refill_ind=[0, 0]):
|
||||
if mode == 0 or mode == 1 or mode==3:
|
||||
out_numpy = Out.data.squeeze(0).cpu().numpy()
|
||||
elif mode == 2:
|
||||
out_numpy = Out.data.squeeze(1).cpu().numpy()
|
||||
if out_numpy.shape[0] == 1:
|
||||
out_numpy = np.tile(out_numpy, (3, 1, 1))
|
||||
if mode == 0 or mode == 1:
|
||||
out_numpy = (np.transpose(out_numpy, (1, 2, 0))) * 255.0 * scal
|
||||
else:
|
||||
out_numpy = (np.transpose(out_numpy, (1, 2, 0)))
|
||||
|
||||
if ps == 1:
|
||||
out_numpy = reverse_pixelshuffle(out_numpy, pss, refill, refill_img, refill_ind)
|
||||
if rescale == 1:
|
||||
out_numpy = cv2.resize(out_numpy, (h, w))
|
||||
#print(out_numpy.shape)
|
||||
return out_numpy
|
||||
|
||||
def temp_ps_4comb(Out, In):
|
||||
pass
|
||||
|
||||
def np2ts(x, mode=0): #now assume the input only has one channel which is ignored
|
||||
w, h, c= x.shape
|
||||
x_ts = x.transpose(2, 0, 1)
|
||||
x_ts = torch.from_numpy(x_ts).type(torch.FloatTensor)
|
||||
if mode == 0 or mode == 1:
|
||||
x_ts = x_ts.unsqueeze(0)
|
||||
elif mode == 2:
|
||||
x_ts = x_ts.unsqueeze(1)
|
||||
return x_ts
|
||||
|
||||
def np2ts_4d(x):
|
||||
x_ts = x.transpose(0, 3, 1, 2)
|
||||
x_ts = torch.from_numpy(x_ts).type(torch.FloatTensor)
|
||||
return x_ts
|
||||
|
||||
def get_salient_noise_in_maps(lm, thre = 0., chn=3):
|
||||
'''
|
||||
Description: To find out the most frequent estimated noise level in the images
|
||||
----------
|
||||
[Input]
|
||||
a multi-channel tensor of noise map
|
||||
|
||||
[Output]
|
||||
A list of noise level value
|
||||
'''
|
||||
lm_numpy = lm.data.cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (0, 2, 3, 1)))
|
||||
nl_list = np.zeros((lm_numpy.shape[0], chn,1))
|
||||
for n in range(lm_numpy.shape[0]):
|
||||
for c in range(chn):
|
||||
selected_lm = np.reshape(lm_numpy[n,:,:,c], (lm_numpy.shape[1]*lm_numpy.shape[2], 1))
|
||||
selected_lm = selected_lm[selected_lm>thre]
|
||||
if selected_lm.shape[0] == 0:
|
||||
nl_list[n, c] = 0
|
||||
else:
|
||||
hist = np.histogram(selected_lm, density=True)
|
||||
nl_ind = np.argmax(hist[0])
|
||||
#print(nl_ind)
|
||||
#print(hist[0])
|
||||
#print(hist[1])
|
||||
nl = ( hist[1][nl_ind] + hist[1][nl_ind+1] ) / 2.
|
||||
nl_list[n, c] = nl
|
||||
return nl_list
|
||||
|
||||
def get_cdf_noise_in_maps(lm, thre=0.8, chn=3):
|
||||
'''
|
||||
Description: To find the CDF-thresholded estimated noise level in the images
|
||||
----------
|
||||
[Input]
|
||||
a multi-channel tensor of noise map
|
||||
|
||||
[Output]
|
||||
A list of noise level value
|
||||
'''
|
||||
lm_numpy = lm.data.cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (0, 2, 3, 1)))
|
||||
nl_list = np.zeros((lm_numpy.shape[0], chn,1))
|
||||
for n in range(lm_numpy.shape[0]):
|
||||
for c in range(chn):
|
||||
selected_lm = np.reshape(lm_numpy[n,:,:,c], (lm_numpy.shape[1]*lm_numpy.shape[2], 1))
|
||||
H, x = np.histogram(selected_lm, density=True)
|
||||
dx = x[1]-x[0]
|
||||
F = np.cumsum(H)*dx
|
||||
F_ind = np.where(F>0.9)[0][0]
|
||||
nl_list[n, c] = x[F_ind]
|
||||
print(nl_list[n,c])
|
||||
return nl_list
|
||||
|
||||
def get_pdf_in_maps(lm, mark, chn=1):
|
||||
'''
|
||||
Description: get the binned noise-estimation pdf of each channel
|
||||
----------
|
||||
[Input]
|
||||
a multi-channel tensor of noise map and channel dimension
|
||||
chn: the channel number for gaussian
|
||||
[Output]
|
||||
Binned PDF of each sample and each channel
|
||||
'''
|
||||
lm_numpy = lm.data.cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (0, 2, 3, 1)))
|
||||
pdf_list = np.zeros((lm_numpy.shape[0], chn, 10))
|
||||
for n in range(lm_numpy.shape[0]):
|
||||
for c in range(chn):
|
||||
selected_lm = np.reshape(lm_numpy[n,:,:,c], (lm_numpy.shape[1]*lm_numpy.shape[2], 1))
|
||||
H, x = np.histogram(selected_lm, range=(0.,1.), bins=10, density=True)
|
||||
dx = x[1]-x[0]
|
||||
F = H * dx
|
||||
pdf_list[n, c, :] = F
|
||||
#sio.savemat(mark + str(c) + '.mat',{'F':F})
|
||||
# plt.bar(range(10), F)
|
||||
#plt.savefig(mark + str(c) + '.png')
|
||||
# plt.close()
|
||||
return pdf_list
|
||||
|
||||
def get_pdf_matching_score(F1, F2):
|
||||
'''
|
||||
Description: Given two sets of PDFs, get the overall matching score for each channel
|
||||
-----------
|
||||
[Input] F1, F2
|
||||
[Output] score for each channel
|
||||
'''
|
||||
return np.mean((F1-F2)**2)
|
||||
|
||||
def decide_scale_factor(noisy_image, estimation_model, color=1, thre = 0, plot_flag = 1, stopping = 4, mark=''):
|
||||
'''
|
||||
Description: Given a noisy image and the noise estimation model, keep multi-scaling the image
|
||||
using pixel-shuffle methods, and estimate the pdf and cdf of AWGN channel
|
||||
Compare the changes of the density function and decide the optimal scaling factor
|
||||
------------
|
||||
[Input] noisy_image, estimation_model, plot_flag, stopping
|
||||
[Output] plot the middle vector
|
||||
score_seq: the matching score sequence between the two subsequent pdf
|
||||
opt_scale: the optimal scaling factor
|
||||
'''
|
||||
if color == 1:
|
||||
c = 3
|
||||
elif color == 0:
|
||||
c = 1
|
||||
score_seq = []
|
||||
Pre_CDF = None
|
||||
flag = 0
|
||||
for pss in range(1, stopping+1): #scaling factor from 1 to the limit
|
||||
noisy_image = pixelshuffle(noisy_image, pss)
|
||||
INoisy = np2ts(noisy_image, color)
|
||||
INoisy = Variable(INoisy.cuda(), volatile=True)
|
||||
EMap = torch.clamp(estimation_model(INoisy), 0., 1.)
|
||||
EPDF = get_pdf_in_maps(EMap, mark + str(pss), c)[0]
|
||||
if flag != 0:
|
||||
score = get_pdf_matching_score(EPDF, Pre_PDF) #TODO: How to match these two
|
||||
print(score)
|
||||
score_seq.append(score)
|
||||
if score <= thre:
|
||||
print('optimal scale is %d:' % (pss-1))
|
||||
return (pss-1, score_seq)
|
||||
Pre_PDF = EPDF
|
||||
flag = 1
|
||||
return (stopping, score_seq)
|
||||
|
||||
|
||||
|
||||
def get_max_noise_in_maps(lm, chn=3):
|
||||
'''
|
||||
Description: To find the maximum noise level in the images
|
||||
----------
|
||||
[Input]
|
||||
a multi-channel tensor of noise map
|
||||
|
||||
[Output]
|
||||
A list of noise level value
|
||||
'''
|
||||
lm_numpy = lm.data.cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (0, 2, 3, 1)))
|
||||
nl_list = np.zeros((lm_numpy.shape[0], chn, 1))
|
||||
for n in range(lm_numpy.shape[0]):
|
||||
for c in range(chn):
|
||||
nl = np.amax(lm_numpy[n, :, :, c])
|
||||
nl_list[n, c] = nl
|
||||
return nl_list
|
||||
|
||||
def get_smooth_maps(lm, dilk = 50, gsd = 10):
|
||||
'''
|
||||
Description: To return the refined maps after dilation and gaussian blur
|
||||
[Input] a multi-channel tensor of noise map
|
||||
[Output] a multi-channel tensor of refined noise map
|
||||
'''
|
||||
kernel = np.ones((dilk, dilk))
|
||||
lm_numpy = lm.data.squeeze(0).cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (1, 2, 0)))
|
||||
ref_lm_numpy = lm_numpy.copy() #a refined map
|
||||
for c in range(lm_numpy.shape[2]):
|
||||
nmap = lm_numpy[:, :, c]
|
||||
nmap_dilation = cv2.dilate(nmap, kernel, iterations=1)
|
||||
ref_lm_numpy[:, :, c] = nmap_dilation
|
||||
#ref_lm_numpy[:, :, c] = scipy.ndimage.filters.gaussian_filter(nmap_dilation, gsd)
|
||||
RF_tensor = np2ts(ref_lm_numpy)
|
||||
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
return RF_tensor
|
||||
def zeroing_out_maps(lm, keep=0):
|
||||
'''
|
||||
Only Keep one channel and zero out other channels
|
||||
[Input] a multi-channel tensor of noise map
|
||||
[Output] a multi-channel tensor of noise map after zeroing out items
|
||||
'''
|
||||
lm_numpy = lm.data.squeeze(0).cpu().numpy()
|
||||
lm_numpy = (np.transpose(lm_numpy, (1, 2, 0)))
|
||||
ref_lm_numpy = lm_numpy.copy() #a refined map
|
||||
for c in range(lm_numpy.shape[2]):
|
||||
if np.isin(c,keep)==0:
|
||||
ref_lm_numpy[:, :, c] = 0.
|
||||
print(ref_lm_numpy)
|
||||
RF_tensor = np2ts(ref_lm_numpy)
|
||||
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
|
||||
return RF_tensor
|
||||
|
||||
def level_refine(NM_tensor, ref_mode, chn=3,cFlag=False):
|
||||
'''
|
||||
Description: To refine the estimated noise level maps
|
||||
[Input] the noise map tensor, and a refinement mode
|
||||
Mode:
|
||||
[0] Get the most salient (the most frequent estimated noise level)
|
||||
[1] Get the maximum value of noise level
|
||||
[2] Gaussian smooth the noise level map to make the regional estimation more smooth
|
||||
[3] Get the average maximum value of the noise level
|
||||
[5] Get the CDF thresholded value
|
||||
|
||||
[Output] a refined map tensor with four channels
|
||||
'''
|
||||
#RF_tensor = NM_tensor.clone() #get a clone version of NM tensor without changing the original one
|
||||
if ref_mode == 0 or ref_mode == 1 or ref_mode == 4 or ref_mode==5: #if we use a single value for the map
|
||||
if ref_mode == 0 or ref_mode == 4:
|
||||
nl_list = get_salient_noise_in_maps(NM_tensor, 0., chn)
|
||||
|
||||
if ref_mode == 4: #half the estimation
|
||||
nl_list = nl_list - nl_list
|
||||
print(nl_list)
|
||||
elif ref_mode == 1:
|
||||
nl_list = get_max_noise_in_maps(NM_tensor, chn)
|
||||
elif ref_mode == 5:
|
||||
nl_list = get_cdf_noise_in_maps(NM_tensor, 0.999, chn)
|
||||
|
||||
noise_map = np.zeros((NM_tensor.shape[0], chn, NM_tensor.size()[2], NM_tensor.size()[3])) #initialize the noise map before concatenating
|
||||
for n in range(NM_tensor.shape[0]):
|
||||
noise_map[n,:,:,:] = np.reshape(np.tile(nl_list[n], NM_tensor.size()[2] * NM_tensor.size()[3]),
|
||||
(chn, NM_tensor.size()[2], NM_tensor.size()[3]))
|
||||
RF_tensor = torch.from_numpy(noise_map).type(torch.FloatTensor)
|
||||
if torch.cuda.is_available() and not cFlag:
|
||||
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
|
||||
else:
|
||||
RF_tensor = Variable(RF_tensor,volatile=True)
|
||||
|
||||
elif ref_mode == 2:
|
||||
RF_tensor = get_smooth_maps(NM_tensor, 10, 5)
|
||||
elif ref_mode == 3:
|
||||
lb = get_salient_noise_in_maps(NM_tensor)
|
||||
up = get_max_noise_in_maps(NM_tensor)
|
||||
nl_list = ( lb + up ) * 0.5
|
||||
noise_map = np.zeros((1, chn, NM_tensor.size()[2], NM_tensor.size()[3])) #initialize the noise map before concatenating
|
||||
noise_map[0, :, :, :] = np.reshape(np.tile(nl_list, NM_tensor.size()[2] * NM_tensor.size()[3]),
|
||||
(chn, NM_tensor.size()[2], NM_tensor.size()[3]))
|
||||
RF_tensor = torch.from_numpy(noise_map).type(torch.FloatTensor)
|
||||
RF_tensor = Variable(RF_tensor.cuda(),volatile=True)
|
||||
|
||||
|
||||
return (RF_tensor, nl_list)
|
||||
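# --- Illustrative sketch (not part of the original file) ---
# level_refine() collapses a spatial noise-level map into per-channel constants;
# ref_mode=0 keeps the most frequent level, ref_mode=1 the maximum. cFlag=True
# forces the CPU path so this sketch runs without CUDA; the map is random dummy data.
if __name__ == "__main__":
    nm = Variable(torch.rand(1, 6, 32, 32))
    refined, levels = level_refine(nm, ref_mode=1, chn=6, cFlag=True)
    print(refined.shape, levels.ravel())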
|
||||
def normalize(a, len_v, min_v, max_v):
|
||||
'''
|
||||
normalize the sequence of factors
|
||||
'''
|
||||
norm_a = np.reshape(a, (len_v,1))
|
||||
norm_a = (norm_a - float(min_v)) / float(max_v - min_v)
|
||||
return norm_a
|
||||
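# --- Illustrative sketch (not part of the original file) ---
# normalize() maps raw noise levels into [0,1] given the limits of the noise type;
# the AWGN range [0, 75] below is an assumed example, not a value from this file.
if __name__ == "__main__":
    print(normalize(np.array([0., 37.5, 75.]), 3, 0, 75))  # -> [[0.], [0.5], [1.]]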
|
||||
def generate_training_noisy_image(current_image, s_or_m, limit_set, c, val=0):
|
||||
noise_level_list = np.zeros((c, 1))
|
||||
if s_or_m == 0: #single noise type
|
||||
if val == 0:
|
||||
for chn in range(c):
|
||||
noise_level_list[chn] = np.random.uniform(limit_set[0][0], limit_set[0][1])
|
||||
elif val == 1:
|
||||
for chn in range(c):
|
||||
noise_level_list[chn] = 35
|
||||
noisy_img = generate_noisy(current_image, 0, noise_level_list /255.)
|
||||
|
||||
return (noisy_img, noise_level_list)
|
||||
|
||||
def generate_ground_truth_noise_map(noise_map, n, noise_level_list, limit_set, c, pn, pw, ph):
|
||||
for chn in range(c):
|
||||
noise_level_list[chn] = normalize(noise_level_list[chn], 1, limit_set[0][0], limit_set[0][1]) #normalize the level value
|
||||
noise_map[n, :, :, :] = np.reshape(np.tile(noise_level_list, pw * ph), (c, pw, ph)) #total number of channels
|
||||
return noise_map
|
||||
|
||||
#Add noise to the original images
|
||||
def generate_noisy(image, noise_type, noise_level_list=0, sigma_s=20, sigma_c=40):
|
||||
'''
|
||||
Description: To generate noisy images of different types
|
||||
----------
|
||||
[Input]
|
||||
image : ndarray of float type in [0,1]; just one image; currently supports gray or color input (w,h,c)
|
||||
noise_type: 0,1,2,3
|
||||
noise_level_list: pre-defined noise level for each channel, without normalization: only information of 3 channels
|
||||
[0]'AWGN' Multi-channel Gaussian-distributed additive noise
|
||||
[1]'RVIN' Replaces random pixels with random values; noise_level: ratio of the changed pixels
|
||||
[2]'Gaussian-Poisson' GP noise approximator, the combination of signal-dependent and signal-independent noise
|
||||
[Output]
|
||||
A noisy image
|
||||
'''
|
||||
w, h, c = image.shape
|
||||
#Some unused noise type: Poisson and Uniform
|
||||
#if noise_type == *:
|
||||
#vals = len(np.unique(image))
|
||||
#vals = 2 ** np.ceil(np.log2(vals))
|
||||
#noisy = np.random.poisson(image * vals) / float(vals)
|
||||
|
||||
#if noise_type == *:
|
||||
#uni = np.random.uniform(-factor,factor,(w, h, c))
|
||||
#uni = uni.reshape(w, h, c)
|
||||
#noisy = image + uni
|
||||
|
||||
noisy = image.copy()
|
||||
|
||||
if noise_type == 0: #MC-AWGN model
|
||||
gauss = np.zeros((w, h, c))
|
||||
for chn in range(c):
|
||||
gauss[:,:,chn] = np.random.normal(0, noise_level_list[chn], (w, h))
|
||||
noisy = image + gauss
|
||||
elif noise_type == 1: #MC-RVIN model
|
||||
for chn in range(c): #process each channel separately
|
||||
prob_map = np.random.uniform(0.0, 1.0, (w, h))
|
||||
noise_map = np.random.uniform(0.0, 1.0, (w, h))
|
||||
noisy_chn = noisy[: , :, chn]
|
||||
noisy_chn[ prob_map < noise_level_list[chn] ] = noise_map[ prob_map < noise_level_list[chn] ]
|
||||
|
||||
elif noise_type == 2:
|
||||
#sigma_s = np.random.uniform(0.0, 0.16, (3,))
|
||||
#sigma_c = np.random.uniform(0.0, 0.06, (3,))
|
||||
sigma_c = [sigma_c]*3
|
||||
sigma_s = [sigma_s]*3
|
||||
sigma_s = np.reshape(sigma_s, (1, 1, c)) #reshape the sigma factor to [1,1,c] to multiply with the image
|
||||
noise_s_map = np.multiply(sigma_s, image) #according to x or temp_x? (i.e. the clean image or the irradiance)
|
||||
#print(noise_s_map) # different from the official code, here we use the original clean image x to compute the variance
|
||||
noise_s = np.random.randn(w, h, c) * noise_s_map #use the new variance to shift the normal distribution
|
||||
noisy = image + noise_s
|
||||
#add signal_independent noise to L
|
||||
noise_c = np.zeros((w, h, c))
|
||||
for chn in range(3):
|
||||
noise_c [:, :, chn] = np.random.normal(0, sigma_c[chn], (w, h))
|
||||
noisy = noisy + noise_c
|
||||
|
||||
return noisy
|
||||
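# --- Illustrative sketch (not part of the original file) ---
# Adding multi-channel AWGN (noise_type=0) to a dummy [0,1] float image; the
# per-channel sigmas are placeholders given on the [0,1] scale (i.e. already /255).
if __name__ == "__main__":
    clean = np.random.rand(64, 64, 3).astype(np.float32)
    sigmas = [15 / 255., 25 / 255., 35 / 255.]        # per-channel AWGN std
    noisy = generate_noisy(clean, 0, sigmas)
    print(noisy.shape, float(np.std(noisy - clean)))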
|
||||
|
||||
#generate AWGN-RVIN noise together
|
||||
def generate_comp_noisy(image, noise_level_list):
|
||||
|
||||
'''
|
||||
Description: To generate mixed AWGN and RVIN noise together
|
||||
----------
|
||||
[Input]
|
||||
image: a float image between [0,1]
|
||||
noise_level_list: AWGN and RVIN noise level
|
||||
[Output]
|
||||
A noisy image
|
||||
'''
|
||||
w, h, c = image.shape
|
||||
noisy = image.copy()
|
||||
for chn in range(c):
|
||||
mix_thre = noise_level_list[c+chn] #get the mix ratio of AWGN and RVIN
|
||||
gau_std = noise_level_list[chn] #get the gaussian std
|
||||
prob_map = np.random.uniform( 0, 1, (w, h) ) #the prob map
|
||||
noise_map = np.random.uniform( 0, 1, (w, h) ) #the noisy map
|
||||
noisy_chn = noisy[: ,: ,chn]
|
||||
noisy_chn[prob_map < mix_thre ] = noise_map[prob_map < mix_thre ]
|
||||
gauss = np.random.normal(0, gau_std, (w, h))
|
||||
noisy_chn[prob_map >= mix_thre ] = noisy_chn[prob_map >= mix_thre ] + gauss[prob_map >= mix_thre]
|
||||
|
||||
return noisy
|
||||
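# --- Illustrative sketch (not part of the original file) ---
# generate_comp_noisy() expects 2*c levels: the first c entries are AWGN stds on
# the [0,1] scale, the last c are RVIN corruption ratios; the values are placeholders.
if __name__ == "__main__":
    clean = np.random.rand(64, 64, 3)
    levels = np.array([15 / 255., 25 / 255., 35 / 255., 0.1, 0.1, 0.1])
    print(generate_comp_noisy(clean, levels).shape)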
|
||||
def generate_denoise(image, model, noise_level_list):
|
||||
'''
|
||||
Description: Generate Denoised Blur Images
|
||||
----------
|
||||
[Input]
|
||||
image:
|
||||
model:
|
||||
noise_level_list:
|
||||
|
||||
[Output]
|
||||
A blur image patch
|
||||
'''
|
||||
#input images
|
||||
ISource = np2ts(image)
|
||||
ISource = torch.clamp(ISource, 0., 1.)
|
||||
ISource = Variable(ISource.cuda(),volatile=True)
|
||||
#input denoise conditions
|
||||
noise_map = np.zeros((1, 6, image.shape[0], image.shape[1])) #initialize the noise map before concatenating
|
||||
noise_map[0, :, :, :] = np.reshape(np.tile(noise_level_list, image.shape[0] * image.shape[1]), (6, image.shape[0], image.shape[1]))
|
||||
NM_tensor = torch.from_numpy(noise_map).type(torch.FloatTensor)
|
||||
NM_tensor = Variable(NM_tensor.cuda(),volatile=True)
|
||||
#generate blur images
|
||||
Res = model(ISource, NM_tensor)
|
||||
Out = torch.clamp(ISource-Res, 0., 1.)
|
||||
out_numpy = Out.data.squeeze(0).cpu().numpy()
|
||||
out_numpy = np.transpose(out_numpy, (1, 2, 0))
|
||||
return out_numpy
|
||||
|
||||
|
||||
#TODO: two pixel shuffle functions to process the images
|
||||
def pixelshuffle(image, scale):
|
||||
'''
|
||||
Description: Given an image, return a reversible sub-sampling
|
||||
[Input]: Image ndarray float
|
||||
[Return]: A mosaic image of shuffled pixels
|
||||
'''
|
||||
if scale == 1:
|
||||
return image
|
||||
w, h ,c = image.shape
|
||||
mosaic = np.array([])
|
||||
for ws in range(scale):
|
||||
band = np.array([])
|
||||
for hs in range(scale):
|
||||
temp = image[ws::scale, hs::scale, :] #get the sub-sampled image
|
||||
band = np.concatenate((band, temp), axis = 1) if band.size else temp
|
||||
mosaic = np.concatenate((mosaic, band), axis = 0) if mosaic.size else band
|
||||
return mosaic
|
||||
|
||||
def reverse_pixelshuffle(image, scale, fill=0, fill_image=0, ind=[0,0]):
|
||||
'''
|
||||
Description: Given a mosaic image of sub-sampling, recombine it into a full image
|
||||
[Input]: Image
|
||||
[Return]: Recombine it using different portions of pixels
|
||||
'''
|
||||
w, h, c = image.shape
|
||||
real = np.zeros((w, h, c)) #real image
|
||||
wf = 0
|
||||
hf = 0
|
||||
for ws in range(scale):
|
||||
hf = 0
|
||||
for hs in range(scale):
|
||||
temp = real[ws::scale, hs::scale, :]
|
||||
wc, hc, cc = temp.shape #get the shape of the current sub-image
|
||||
if fill==1 and ws==ind[0] and hs==ind[1]:
|
||||
real[ws::scale, hs::scale, :] = fill_image[wf:wf+wc, hf:hf+hc, :]
|
||||
else:
|
||||
real[ws::scale, hs::scale, :] = image[wf:wf+wc, hf:hf+hc, :]
|
||||
hf = hf + hc
|
||||
wf = wf + wc
|
||||
return real
|
||||
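# --- Illustrative sketch (not part of the original file) ---
# pixelshuffle() tiles the scale*scale sub-sampled images into one mosaic and
# reverse_pixelshuffle() undoes it, so the round trip should be lossless.
if __name__ == "__main__":
    img = np.random.rand(64, 64, 3)
    mosaic = pixelshuffle(img, 2)
    restored = reverse_pixelshuffle(mosaic, 2)
    print(np.allclose(img, restored))  # expected: True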
|
||||
def scal2map(level, h, w, min_v=0., max_v=255.):
|
||||
'''
|
||||
Change a single normalized noise level value to a map
|
||||
[Input]: level: a scalar noise level (0-1), h, w
|
||||
[Return]: a pytorch tensor of the concatenated noise level map
|
||||
'''
|
||||
#get a tensor from the input level
|
||||
level_tensor = torch.from_numpy(np.reshape(level, (1,1))).type(torch.FloatTensor)
|
||||
#make the noise level to a map
|
||||
level_tensor = level_tensor.view(level_tensor.size(0), level_tensor.size(1), 1, 1)
|
||||
level_tensor = level_tensor.repeat(1, 1, h, w)
|
||||
return level_tensor
|
||||
|
||||
def scal2map_spatial(level1, level2, h, w):
|
||||
stdN_t1 = scal2map(level1, int(h/2), w)
|
||||
stdN_t2 = scal2map(level2, h-int(h/2), w)
|
||||
stdN_tensor = torch.cat([stdN_t1, stdN_t2], dim=2)
|
||||
return stdN_tensor
|
@ -0,0 +1,104 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import glob
|
||||
import random
|
||||
import cv2
|
||||
|
||||
random.seed(1143)
|
||||
|
||||
|
||||
def populate_train_list(orig_images_path, hazy_images_path):
|
||||
|
||||
|
||||
train_list = []
|
||||
val_list = []
|
||||
|
||||
image_list_haze = glob.glob(hazy_images_path + "*.jpg")
|
||||
|
||||
tmp_dict = {}
|
||||
|
||||
for image in image_list_haze:
|
||||
image = image.split("/")[-1]
|
||||
key = image.split("_")[0] + "_" + image.split("_")[1] + ".jpg"
|
||||
if key in tmp_dict.keys():
|
||||
tmp_dict[key].append(image)
|
||||
else:
|
||||
tmp_dict[key] = []
|
||||
tmp_dict[key].append(image)
|
||||
|
||||
|
||||
train_keys = []
|
||||
val_keys = []
|
||||
|
||||
len_keys = len(tmp_dict.keys())
|
||||
for i in range(len_keys):
|
||||
if i < len_keys*9/10:
|
||||
train_keys.append(list(tmp_dict.keys())[i])
|
||||
else:
|
||||
val_keys.append(list(tmp_dict.keys())[i])
|
||||
|
||||
|
||||
for key in list(tmp_dict.keys()):
|
||||
|
||||
if key in train_keys:
|
||||
for hazy_image in tmp_dict[key]:
|
||||
|
||||
train_list.append([orig_images_path + key, hazy_images_path + hazy_image])
|
||||
|
||||
|
||||
else:
|
||||
for hazy_image in tmp_dict[key]:
|
||||
|
||||
val_list.append([orig_images_path + key, hazy_images_path + hazy_image])
|
||||
|
||||
|
||||
|
||||
random.shuffle(train_list)
|
||||
random.shuffle(val_list)
|
||||
|
||||
return train_list, val_list
|
||||
|
||||
|
||||
|
||||
class dehazing_loader(data.Dataset):
|
||||
|
||||
def __init__(self, orig_images_path, hazy_images_path, mode='train'):
|
||||
|
||||
self.train_list, self.val_list = populate_train_list(orig_images_path, hazy_images_path)
|
||||
|
||||
if mode == 'train':
|
||||
self.data_list = self.train_list
|
||||
print("Total training examples:", len(self.train_list))
|
||||
else:
|
||||
self.data_list = self.val_list
|
||||
print("Total validation examples:", len(self.val_list))
|
||||
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
data_orig_path, data_hazy_path = self.data_list[index]
|
||||
|
||||
data_orig = Image.open(data_orig_path)
|
||||
data_hazy = Image.open(data_hazy_path)
|
||||
|
||||
data_orig = data_orig.resize((480,640), Image.ANTIALIAS)
|
||||
data_hazy = data_hazy.resize((480,640), Image.ANTIALIAS)
|
||||
|
||||
data_orig = (np.asarray(data_orig)/255.0)
|
||||
data_hazy = (np.asarray(data_hazy)/255.0)
|
||||
|
||||
data_orig = torch.from_numpy(data_orig).float()
|
||||
data_hazy = torch.from_numpy(data_hazy).float()
|
||||
|
||||
return data_orig.permute(2,0,1), data_hazy.permute(2,0,1)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_list)
|
||||
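# --- Illustrative usage sketch (not part of the original file) ---
# The loader pairs each hazy image "<scene>_<id>_<level>.jpg" with its clean
# original "<scene>_<id>.jpg"; the directory paths below are placeholders and
# must end with "/" because they are concatenated with file names directly.
if __name__ == "__main__":
    dataset = dehazing_loader("data/original/", "data/hazy/", mode="train")
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
    for clean_batch, hazy_batch in loader:
        print(clean_batch.shape, hazy_batch.shape)  # (B, 3, 640, 480) each
        break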
|
@ -0,0 +1,50 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
class dehaze_net(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(dehaze_net, self).__init__()
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.e_conv1 = nn.Conv2d(3,3,1,1,0,bias=True)
|
||||
self.e_conv2 = nn.Conv2d(3,3,3,1,1,bias=True)
|
||||
self.e_conv3 = nn.Conv2d(6,3,5,1,2,bias=True)
|
||||
self.e_conv4 = nn.Conv2d(6,3,7,1,3,bias=True)
|
||||
self.e_conv5 = nn.Conv2d(12,3,3,1,1,bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
source = []
|
||||
source.append(x)
|
||||
|
||||
x1 = self.relu(self.e_conv1(x))
|
||||
x2 = self.relu(self.e_conv2(x1))
|
||||
|
||||
concat1 = torch.cat((x1,x2), 1)
|
||||
x3 = self.relu(self.e_conv3(concat1))
|
||||
|
||||
concat2 = torch.cat((x2, x3), 1)
|
||||
x4 = self.relu(self.e_conv4(concat2))
|
||||
|
||||
concat3 = torch.cat((x1,x2,x3,x4),1)
|
||||
x5 = self.relu(self.e_conv5(concat3))
|
||||
|
||||
clean_image = self.relu((x5 * x) - x5 + 1)
|
||||
|
||||
return clean_image
|
||||
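# --- Illustrative usage sketch (not part of the original file) ---
# AOD-Net-style estimator: the network predicts K(x) (x5 above) and recovers the
# clean image as relu(K*x - K + 1). Quick forward pass on a dummy hazy image:
if __name__ == "__main__":
    net = dehaze_net()
    hazy = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        out = net(hazy)
    print(out.shape)  # torch.Size([1, 3, 256, 256])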
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from rife_model.warplayer import warp
|
||||
|
||||
|
||||
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
)
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(ResBlock, self).__init__()
|
||||
if in_planes == out_planes and stride == 1:
|
||||
self.conv0 = nn.Identity()
|
||||
else:
|
||||
self.conv0 = nn.Conv2d(in_planes, out_planes,
|
||||
3, stride, 1, bias=False)
|
||||
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
|
||||
self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
|
||||
self.relu1 = nn.PReLU(1)
|
||||
self.relu2 = nn.PReLU(out_planes)
|
||||
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
|
||||
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
w = x.mean(3, True).mean(2, True)
|
||||
w = self.relu1(self.fc1(w))
|
||||
w = torch.sigmoid(self.fc2(w))
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
|
||||
def __init__(self, in_planes, scale=1, c=64):
|
||||
super(IFBlock, self).__init__()
|
||||
self.scale = scale
|
||||
self.conv0 = conv(in_planes, c, 3, 2, 1)
|
||||
self.res0 = ResBlock(c, c)
|
||||
self.res1 = ResBlock(c, c)
|
||||
self.res2 = ResBlock(c, c)
|
||||
self.res3 = ResBlock(c, c)
|
||||
self.res4 = ResBlock(c, c)
|
||||
self.res5 = ResBlock(c, c)
|
||||
self.conv1 = nn.Conv2d(c, 8, 3, 1, 1)
|
||||
self.up = nn.PixelShuffle(2)
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale != 1:
|
||||
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
|
||||
align_corners=False)
|
||||
x = self.conv0(x)
|
||||
x = self.res0(x)
|
||||
x = self.res1(x)
|
||||
x = self.res2(x)
|
||||
x = self.res3(x)
|
||||
x = self.res4(x)
|
||||
x = self.res5(x)
|
||||
x = self.conv1(x)
|
||||
flow = self.up(x)
|
||||
if self.scale != 1:
|
||||
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
|
||||
align_corners=False)
|
||||
return flow
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self, cFlag):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(6, scale=4, c=192)
|
||||
self.block1 = IFBlock(8, scale=2, c=128)
|
||||
self.block2 = IFBlock(8, scale=1, c=64)
|
||||
self.cFlag = cFlag
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False)
|
||||
flow0 = self.block0(x)
|
||||
F1 = flow0
|
||||
warped_img0 = warp(x[:, :3], F1, self.cFlag)
|
||||
warped_img1 = warp(x[:, 3:], -F1, self.cFlag)
|
||||
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1), 1))
|
||||
F2 = (flow0 + flow1)
|
||||
warped_img0 = warp(x[:, :3], F2, self.cFlag)
|
||||
warped_img1 = warp(x[:, 3:], -F2, self.cFlag)
|
||||
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1))
|
||||
F3 = (flow0 + flow1 + flow2)
|
||||
return F3, [F1, F2, F3]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
|
||||
img1 = torch.tensor(np.random.normal(
|
||||
0, 1, (3, 3, 256, 256))).float().to(device)
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flownet = IFNet(cFlag=False)  # cFlag=False: use GPU if available
|
||||
flow, _ = flownet(imgs)
|
||||
print(flow.shape)
|
@ -0,0 +1,115 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from model.warplayer import warp
|
||||
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
)
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=False),
|
||||
nn.BatchNorm2d(out_planes),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=1):
|
||||
super(ResBlock, self).__init__()
|
||||
if in_planes == out_planes and stride == 1:
|
||||
self.conv0 = nn.Identity()
|
||||
else:
|
||||
self.conv0 = nn.Conv2d(in_planes, out_planes,
|
||||
3, stride, 1, bias=False)
|
||||
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
|
||||
self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
|
||||
self.relu1 = nn.PReLU(1)
|
||||
self.relu2 = nn.PReLU(out_planes)
|
||||
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
|
||||
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
w = x.mean(3, True).mean(2, True)
|
||||
w = self.relu1(self.fc1(w))
|
||||
w = torch.sigmoid(self.fc2(w))
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
|
||||
def __init__(self, in_planes, scale=1, c=64):
|
||||
super(IFBlock, self).__init__()
|
||||
self.scale = scale
|
||||
self.conv0 = conv(in_planes, c, 3, 1, 1)
|
||||
self.res0 = ResBlock(c, c)
|
||||
self.res1 = ResBlock(c, c)
|
||||
self.res2 = ResBlock(c, c)
|
||||
self.res3 = ResBlock(c, c)
|
||||
self.res4 = ResBlock(c, c)
|
||||
self.res5 = ResBlock(c, c)
|
||||
self.conv1 = nn.Conv2d(c, 2, 3, 1, 1)
|
||||
self.up = nn.PixelShuffle(2)
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale != 1:
|
||||
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
|
||||
align_corners=False)
|
||||
x = self.conv0(x)
|
||||
x = self.res0(x)
|
||||
x = self.res1(x)
|
||||
x = self.res2(x)
|
||||
x = self.res3(x)
|
||||
x = self.res4(x)
|
||||
x = self.res5(x)
|
||||
x = self.conv1(x)
|
||||
flow = x # self.up(x)
|
||||
if self.scale != 1:
|
||||
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
|
||||
align_corners=False)
|
||||
return flow
|
||||
|
||||
|
||||
class IFNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(IFNet, self).__init__()
|
||||
self.block0 = IFBlock(6, scale=4, c=192)
|
||||
self.block1 = IFBlock(8, scale=2, c=128)
|
||||
self.block2 = IFBlock(8, scale=1, c=64)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False)
|
||||
flow0 = self.block0(x)
|
||||
F1 = flow0
|
||||
warped_img0 = warp(x[:, :3], F1)
|
||||
warped_img1 = warp(x[:, 3:], -F1)
|
||||
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1), 1))
|
||||
F2 = (flow0 + flow1)
|
||||
warped_img0 = warp(x[:, :3], F2)
|
||||
warped_img1 = warp(x[:, 3:], -F2)
|
||||
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1))
|
||||
F3 = (flow0 + flow1 + flow2)
|
||||
return F3, [F1, F2, F3]
|
||||
|
||||
if __name__ == '__main__':
|
||||
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
|
||||
img1 = torch.tensor(np.random.normal(
|
||||
0, 1, (3, 3, 256, 256))).float().to(device)
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flownet = IFNet()
|
||||
flow, _ = flownet(imgs)
|
||||
print(flow.shape)
|
@ -0,0 +1,262 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from torch.optim import AdamW
|
||||
import torch.optim as optim
|
||||
import itertools
|
||||
from rife_model.warplayer import warp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from rife_model.IFNet import *
|
||||
import torch.nn.functional as F
|
||||
from rife_model.loss import *
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
|
||||
kernel_size=4, stride=2, padding=1, bias=True),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=2):
|
||||
super(ResBlock, self).__init__()
|
||||
if in_planes == out_planes and stride == 1:
|
||||
self.conv0 = nn.Identity()
|
||||
else:
|
||||
self.conv0 = nn.Conv2d(in_planes, out_planes,
|
||||
3, stride, 1, bias=False)
|
||||
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
|
||||
self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
|
||||
self.relu1 = nn.PReLU(1)
|
||||
self.relu2 = nn.PReLU(out_planes)
|
||||
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
|
||||
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
w = x.mean(3, True).mean(2, True)
|
||||
w = self.relu1(self.fc1(w))
|
||||
w = torch.sigmoid(self.fc2(w))
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
|
||||
c = 16
|
||||
|
||||
|
||||
class ContextNet(nn.Module):
|
||||
def __init__(self, cFlag):
|
||||
super(ContextNet, self).__init__()
|
||||
self.conv1 = ResBlock(3, c)
|
||||
self.conv2 = ResBlock(c, 2 * c)
|
||||
self.conv3 = ResBlock(2 * c, 4 * c)
|
||||
self.conv4 = ResBlock(4 * c, 8 * c)
|
||||
self.cFlag = cFlag
|
||||
|
||||
def forward(self, x, flow):
|
||||
x = self.conv1(x)
|
||||
f1 = warp(x, flow, self.cFlag)
|
||||
x = self.conv2(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f2 = warp(x, flow, self.cFlag)
|
||||
x = self.conv3(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f3 = warp(x, flow, self.cFlag)
|
||||
x = self.conv4(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f4 = warp(x, flow, self.cFlag)
|
||||
return [f1, f2, f3, f4]
|
||||
|
||||
|
||||
class FusionNet(nn.Module):
|
||||
def __init__(self, cFlag):
|
||||
super(FusionNet, self).__init__()
|
||||
self.down0 = ResBlock(8, 2 * c)
|
||||
self.down1 = ResBlock(4 * c, 4 * c)
|
||||
self.down2 = ResBlock(8 * c, 8 * c)
|
||||
self.down3 = ResBlock(16 * c, 16 * c)
|
||||
self.up0 = deconv(32 * c, 8 * c)
|
||||
self.up1 = deconv(16 * c, 4 * c)
|
||||
self.up2 = deconv(8 * c, 2 * c)
|
||||
self.up3 = deconv(4 * c, c)
|
||||
self.conv = nn.Conv2d(c, 4, 3, 1, 1)
|
||||
self.cFlag = cFlag
|
||||
|
||||
def forward(self, img0, img1, flow, c0, c1, flow_gt):
|
||||
warped_img0 = warp(img0, flow, self.cFlag)
|
||||
warped_img1 = warp(img1, -flow, self.cFlag)
|
||||
if flow_gt == None:
|
||||
warped_img0_gt, warped_img1_gt = None, None
|
||||
else:
|
||||
warped_img0_gt = warp(img0, flow_gt[:, :2])
|
||||
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
|
||||
s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1))
|
||||
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
|
||||
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
|
||||
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
|
||||
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
|
||||
x = self.up1(torch.cat((x, s2), 1))
|
||||
x = self.up2(torch.cat((x, s1), 1))
|
||||
x = self.up3(torch.cat((x, s0), 1))
|
||||
x = self.conv(x)
|
||||
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, c_flag, local_rank=-1):
|
||||
self.flownet = IFNet(c_flag)
|
||||
self.contextnet = ContextNet(c_flag)
|
||||
self.fusionnet = FusionNet(c_flag)
|
||||
self.device(c_flag)
|
||||
self.optimG = AdamW(itertools.chain(
|
||||
self.flownet.parameters(),
|
||||
self.contextnet.parameters(),
|
||||
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
|
||||
self.schedulerG = optim.lr_scheduler.CyclicLR(
|
||||
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
|
||||
self.epe = EPE()
|
||||
self.ter = Ternary(c_flag)
|
||||
self.sobel = SOBEL(c_flag)
|
||||
if local_rank != -1:
|
||||
self.flownet = DDP(self.flownet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
self.contextnet = DDP(self.contextnet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
self.fusionnet = DDP(self.fusionnet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
|
||||
def train(self):
|
||||
self.flownet.train()
|
||||
self.contextnet.train()
|
||||
self.fusionnet.train()
|
||||
|
||||
def eval(self):
|
||||
self.flownet.eval()
|
||||
self.contextnet.eval()
|
||||
self.fusionnet.eval()
|
||||
|
||||
def device(self, c_flag):
|
||||
if torch.cuda.is_available() and not c_flag:
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
self.flownet.to(device)
|
||||
self.contextnet.to(device)
|
||||
self.fusionnet.to(device)
|
||||
|
||||
def load_model(self, path, rank=0):
|
||||
def convert(param):
|
||||
return {
|
||||
k.replace("module.", ""): v
|
||||
for k, v in param.items()
|
||||
if "module." in k
|
||||
}
|
||||
|
||||
if rank == 0:
|
||||
self.flownet.load_state_dict(
|
||||
convert(torch.load('{}/flownet.pkl'.format(path), map_location=torch.device("cpu"))))
|
||||
self.contextnet.load_state_dict(
|
||||
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=torch.device("cpu"))))
|
||||
self.fusionnet.load_state_dict(
|
||||
convert(torch.load('{}/unet.pkl'.format(path), map_location=torch.device("cpu"))))
|
||||
|
||||
def save_model(self, path, rank=0):
|
||||
if rank == 0:
|
||||
torch.save(self.flownet.state_dict(),
|
||||
'{}/flownet.pkl'.format(path))
|
||||
torch.save(self.contextnet.state_dict(),
|
||||
'{}/contextnet.pkl'.format(path))
|
||||
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
||||
|
||||
def predict(self, imgs, flow, training=True, flow_gt=None):
|
||||
img0 = imgs[:, :3]
|
||||
img1 = imgs[:, 3:]
|
||||
c0 = self.contextnet(img0, flow)
|
||||
c1 = self.contextnet(img1, -flow)
|
||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
||||
align_corners=False) * 2.0
|
||||
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
|
||||
img0, img1, flow, c0, c1, flow_gt)
|
||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
||||
mask = torch.sigmoid(refine_output[:, 3:4])
|
||||
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
|
||||
pred = merged_img + res
|
||||
pred = torch.clamp(pred, 0, 1)
|
||||
if training:
|
||||
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
|
||||
else:
|
||||
return pred
|
||||
|
||||
def inference(self, img0, img1):
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flow, _ = self.flownet(imgs)
|
||||
return self.predict(imgs, flow, training=False).detach()
|
||||
|
||||
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
|
||||
for param_group in self.optimG.param_groups:
|
||||
param_group['lr'] = learning_rate
|
||||
if training:
|
||||
self.train()
|
||||
else:
|
||||
self.eval()
|
||||
flow, flow_list = self.flownet(imgs)
|
||||
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
|
||||
imgs, flow, flow_gt=flow_gt)
|
||||
loss_ter = self.ter(pred, gt).mean()
|
||||
if training:
|
||||
with torch.no_grad():
|
||||
loss_flow = torch.abs(warped_img0_gt - gt).mean()
|
||||
loss_mask = torch.abs(
|
||||
merged_img - gt).sum(1, True).float().detach()
|
||||
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False).detach()
|
||||
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5).detach()
|
||||
loss_cons = 0
|
||||
for i in range(3):
|
||||
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
|
||||
loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
|
||||
loss_cons = loss_cons.mean() * 0.01
|
||||
else:
|
||||
loss_cons = torch.tensor([0])
|
||||
loss_flow = torch.abs(warped_img0 - gt).mean()
|
||||
loss_mask = 1
|
||||
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
|
||||
if training:
|
||||
self.optimG.zero_grad()
|
||||
loss_G = loss_l1 + loss_cons + loss_ter
|
||||
loss_G.backward()
|
||||
self.optimG.step()
|
||||
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
|
||||
img1 = torch.tensor(np.random.normal(
|
||||
0, 1, (3, 3, 256, 256))).float().to(device)
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
model = Model(c_flag=False)  # c_flag=False: use GPU if available
|
||||
model.eval()
|
||||
print(model.inference(img0, img1).shape)
|
@ -0,0 +1,250 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from torch.optim import AdamW
|
||||
import torch.optim as optim
|
||||
import itertools
|
||||
from model.warplayer import warp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from model.IFNet2F import *
|
||||
import torch.nn.functional as F
|
||||
from model.loss import *
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
|
||||
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
|
||||
return nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
|
||||
kernel_size=4, stride=2, padding=1, bias=True),
|
||||
nn.PReLU(out_planes)
|
||||
)
|
||||
|
||||
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=True),
|
||||
)
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes, stride=2):
|
||||
super(ResBlock, self).__init__()
|
||||
if in_planes == out_planes and stride == 1:
|
||||
self.conv0 = nn.Identity()
|
||||
else:
|
||||
self.conv0 = nn.Conv2d(in_planes, out_planes,
|
||||
3, stride, 1, bias=False)
|
||||
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
|
||||
self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
|
||||
self.relu1 = nn.PReLU(1)
|
||||
self.relu2 = nn.PReLU(out_planes)
|
||||
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
|
||||
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
w = x.mean(3, True).mean(2, True)
|
||||
w = self.relu1(self.fc1(w))
|
||||
w = torch.sigmoid(self.fc2(w))
|
||||
x = self.relu2(x * w + y)
|
||||
return x
|
||||
|
||||
c = 16
|
||||
|
||||
class ContextNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(ContextNet, self).__init__()
|
||||
self.conv1 = ResBlock(3, c, 1)
|
||||
self.conv2 = ResBlock(c, 2*c)
|
||||
self.conv3 = ResBlock(2*c, 4*c)
|
||||
self.conv4 = ResBlock(4*c, 8*c)
|
||||
|
||||
def forward(self, x, flow):
|
||||
x = self.conv1(x)
|
||||
f1 = warp(x, flow)
|
||||
x = self.conv2(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f2 = warp(x, flow)
|
||||
x = self.conv3(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f3 = warp(x, flow)
|
||||
x = self.conv4(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False) * 0.5
|
||||
f4 = warp(x, flow)
|
||||
return [f1, f2, f3, f4]
|
||||
|
||||
|
||||
class FusionNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(FusionNet, self).__init__()
|
||||
self.down0 = ResBlock(8, 2*c, 1)
|
||||
self.down1 = ResBlock(4*c, 4*c)
|
||||
self.down2 = ResBlock(8*c, 8*c)
|
||||
self.down3 = ResBlock(16*c, 16*c)
|
||||
self.up0 = deconv(32*c, 8*c)
|
||||
self.up1 = deconv(16*c, 4*c)
|
||||
self.up2 = deconv(8*c, 2*c)
|
||||
self.up3 = deconv(4*c, c)
|
||||
self.conv = nn.Conv2d(c, 4, 3, 2, 1)
|
||||
|
||||
def forward(self, img0, img1, flow, c0, c1, flow_gt):
|
||||
warped_img0 = warp(img0, flow)
|
||||
warped_img1 = warp(img1, -flow)
|
||||
if flow_gt == None:
|
||||
warped_img0_gt, warped_img1_gt = None, None
|
||||
else:
|
||||
warped_img0_gt = warp(img0, flow_gt[:, :2])
|
||||
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
|
||||
s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1))
|
||||
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
|
||||
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
|
||||
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
|
||||
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
|
||||
x = self.up1(torch.cat((x, s2), 1))
|
||||
x = self.up2(torch.cat((x, s1), 1))
|
||||
x = self.up3(torch.cat((x, s0), 1))
|
||||
x = self.conv(x)
|
||||
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self, local_rank=-1):
|
||||
self.flownet = IFNet()
|
||||
self.contextnet = ContextNet()
|
||||
self.fusionnet = FusionNet()
|
||||
self.device()
|
||||
self.optimG = AdamW(itertools.chain(
|
||||
self.flownet.parameters(),
|
||||
self.contextnet.parameters(),
|
||||
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
|
||||
self.schedulerG = optim.lr_scheduler.CyclicLR(
|
||||
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
|
||||
self.epe = EPE()
|
||||
self.ter = Ternary()
|
||||
self.sobel = SOBEL()
|
||||
if local_rank != -1:
|
||||
self.flownet = DDP(self.flownet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
self.contextnet = DDP(self.contextnet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
self.fusionnet = DDP(self.fusionnet, device_ids=[
|
||||
local_rank], output_device=local_rank)
|
||||
|
||||
def train(self):
|
||||
self.flownet.train()
|
||||
self.contextnet.train()
|
||||
self.fusionnet.train()
|
||||
|
||||
def eval(self):
|
||||
self.flownet.eval()
|
||||
self.contextnet.eval()
|
||||
self.fusionnet.eval()
|
||||
|
||||
def device(self):
|
||||
self.flownet.to(device)
|
||||
self.contextnet.to(device)
|
||||
self.fusionnet.to(device)
|
||||
|
||||
def load_model(self, path, rank=0):
|
||||
def convert(param):
|
||||
return {
|
||||
k.replace("module.", ""): v
|
||||
for k, v in param.items()
|
||||
if "module." in k
|
||||
}
|
||||
if rank == 0:
|
||||
self.flownet.load_state_dict(
|
||||
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
|
||||
self.contextnet.load_state_dict(
|
||||
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
|
||||
self.fusionnet.load_state_dict(
|
||||
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
|
||||
|
||||
def save_model(self, path, rank=0):
|
||||
if rank == 0:
|
||||
torch.save(self.flownet.state_dict(),
|
||||
'{}/flownet.pkl'.format(path))
|
||||
torch.save(self.contextnet.state_dict(),
|
||||
'{}/contextnet.pkl'.format(path))
|
||||
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
|
||||
|
||||
def predict(self, imgs, flow, training=True, flow_gt=None):
|
||||
img0 = imgs[:, :3]
|
||||
img1 = imgs[:, 3:]
|
||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
||||
align_corners=False) * 2.0
|
||||
c0 = self.contextnet(img0, flow)
|
||||
c1 = self.contextnet(img1, -flow)
|
||||
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
|
||||
img0, img1, flow, c0, c1, flow_gt)
|
||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
||||
mask = torch.sigmoid(refine_output[:, 3:4])
|
||||
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
|
||||
pred = merged_img + res
|
||||
pred = torch.clamp(pred, 0, 1)
|
||||
if training:
|
||||
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
|
||||
else:
|
||||
return pred
|
||||
|
||||
def inference(self, img0, img1):
|
||||
with torch.no_grad():
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flow, _ = self.flownet(imgs)
|
||||
return self.predict(imgs, flow, training=False).detach()
|
||||
|
||||
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        flow, flow_list = self.flownet(imgs)
        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()
        if training:
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(
                    merged_img - gt).sum(1, True).float().detach()
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
                loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
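

# Illustrative helper (a sketch added for exposition, not part of the upstream
# RIFE code): inference() synthesizes the frame halfway between its two inputs,
# so repeated passes over consecutive pairs yield a denser sequence
# (t = 0.25, 0.5, 0.75 after two passes).
def interpolate_recursive(model, img0, img1, passes=2):
    frames = [img0, img1]
    for _ in range(passes):
        denser = [frames[0]]
        for a, b in zip(frames[:-1], frames[1:]):
            # insert the synthesized midpoint between every consecutive pair
            denser += [model.inference(a, b), b]
        frames = denser
    return frames  # includes the two original end frames
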
if __name__ == '__main__':
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    imgs = torch.cat((img0, img1), 1)
    model = Model()
    model.eval()
    # inference() expects the two frames separately, not the concatenated tensor.
    print(model.inference(img0, img1).shape)
@ -0,0 +1,90 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EPE(nn.Module):
    def __init__(self):
        super(EPE, self).__init__()

    def forward(self, flow, gt, loss_mask):
        loss_map = (flow - gt.detach()) ** 2
        loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
        return (loss_map * loss_mask)
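

# Ternary is a census-transform loss: every 7x7 neighbourhood is compared against
# its centre pixel and normalised, and the loss is the mean soft Hamming distance
# between the transformed versions of the two images, with a one-pixel border
# masked out by valid_mask().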
class Ternary(nn.Module):
    def __init__(self, cFlag):
        super(Ternary, self).__init__()
        patch_size = 7
        out_channels = patch_size * patch_size
        self.w = np.eye(out_channels).reshape(
            (patch_size, patch_size, 1, out_channels))
        self.w = np.transpose(self.w, (3, 2, 0, 1))
        self.device = torch.device("cuda" if torch.cuda.is_available() and not cFlag else "cpu")
        self.w = torch.tensor(self.w).float().to(self.device)

    def transform(self, img):
        patches = F.conv2d(img, self.w, padding=3, bias=None)
        transf = patches - img
        transf_norm = transf / torch.sqrt(0.81 + transf**2)
        return transf_norm

    def rgb2gray(self, rgb):
        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
        gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
        return gray

    def hamming(self, t1, t2):
        dist = (t1 - t2) ** 2
        dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
        return dist_norm

    def valid_mask(self, t, padding):
        n, _, h, w = t.size()
        inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
        mask = F.pad(inner, [padding] * 4)
        return mask

    def forward(self, img0, img1):
        img0 = self.transform(self.rgb2gray(img0))
        img1 = self.transform(self.rgb2gray(img1))
        return self.hamming(img0, img1) * self.valid_mask(img0, 1)
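

# SOBEL is a gradient loss: prediction and ground truth are filtered with
# horizontal and vertical Sobel kernels (each channel treated as an independent
# single-channel image) and the loss is the sum of absolute differences of the
# resulting gradient maps.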
class SOBEL(nn.Module):
    def __init__(self, cFlag):
        super(SOBEL, self).__init__()
        self.kernelX = torch.tensor([
            [1, 0, -1],
            [2, 0, -2],
            [1, 0, -1],
        ]).float()
        self.kernelY = self.kernelX.clone().T
        self.device = torch.device("cuda" if torch.cuda.is_available() and not cFlag else "cpu")
        self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(self.device)
        self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(self.device)

    def forward(self, pred, gt):
        N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
        img_stack = torch.cat(
            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
        sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
        sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
        pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
        pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]

        L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
        loss = (L1X+L1Y)
        return loss

if __name__ == '__main__':
    # Quick self-test. `device` must be defined here because the module-level
    # definition above is commented out, and Ternary requires its force-CPU flag;
    # cFlag=False lets it use CUDA when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    ternary_loss = Ternary(cFlag=False)
    print(ternary_loss(img0, img1).shape)
@ -0,0 +1,23 @@
import torch
import torch.nn as nn

backwarp_tenGrid = {}

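
# warp() backward-warps tenInput by the flow field tenFlow using grid_sample.
# The base sampling grid, normalised to [-1, 1], is built once per
# (device, flow size) pair and cached in backwarp_tenGrid; the flow is rescaled
# from pixel units into the same normalised coordinates before being added to
# that grid. Setting cFlag forces CPU execution.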
def warp(tenInput, tenFlow, cFlag):
    device = torch.device("cuda" if torch.cuda.is_available() and not cFlag else "cpu")
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=torch.clamp(g, -1, 1), mode='bilinear',
                                           padding_mode='zeros', align_corners=True)
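

if __name__ == '__main__':
    # Sanity-check sketch (added for illustration): with an all-zero flow field
    # the sampling grid is exactly the identity grid, so the warped output
    # should reproduce the input.
    x = torch.rand(1, 3, 64, 64)
    zero_flow = torch.zeros(1, 2, 64, 64)
    assert torch.allclose(warp(x, zero_flow, cFlag=True), x, atol=1e-5)
    print(warp(x, zero_flow, cFlag=True).shape)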