mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-01-26 19:29:52 +00:00
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:33:44 2023 +0800 add README commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:12:16 2023 +0800 fix Face Enhancement distortion commit 6d52de5368c6cfbd9342465b5238725c186e00b9 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 18:27:25 2023 +0800 better? args handling commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:40:19 2023 +0800 bug fix related to Scratch commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:24:52 2023 +0800 implement step 2 ~ 4 commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 11:55:20 2023 +0800 process scratch commit 3b18f7b042 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 11:57:20 2023 +0800 "init" commit d0148e0e82 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 10:34:39 2023 +0800 clone repo
174 lines
6.1 KiB
Python
174 lines
6.1 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
import torch.nn.utils.spectral_norm as spectral_norm
|
|
from .normalization import SPADE
|
|
|
|
|
|
# ResNet block that uses SPADE.
|
|
# It differs from the ResNet block of pix2pixHD in that
|
|
# it takes in the segmentation map as input, learns the skip connection if necessary,
|
|
# and applies normalization first and then convolution.
|
|
# This appears to be the standard architecture for unconditional or
|
|
# class-conditional GANs built from residual blocks.
|
|
# The code was inspired by https://github.com/LMescheder/GAN_stability.
|
|
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE modules.

    The residual branch is two rounds of
    ``SPADE norm -> LeakyReLU(0.2) -> 3x3 conv``. The output is that branch
    added to a shortcut. The shortcut is the identity when ``fin == fout``.
    Otherwise it is a SPADE-normalized 1x1 convolution without bias.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object. This block reads ``opt.norm_G`` (if it contains
            the substring ``"spectral"``, every conv is wrapped with spectral
            norm) and ``opt.semantic_nc``, and passes ``opt`` on to ``SPADE``.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()

        # A learned shortcut is only needed when the channel count changes.
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt

        # Convolutions: all are created first, before any spectral-norm wrapping.
        # This keeps parameter initialisation in the same order as upstream.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # Optionally wrap each convolution in spectral normalization.
        # Wrapping happens in place, so the attribute names are unchanged.
        if "spectral" in opt.norm_G:
            wrapped = ["conv_0", "conv_1"]
            if self.learned_shortcut:
                wrapped.append("conv_s")
            for attr in wrapped:
                setattr(self, attr, spectral_norm(getattr(self, attr)))

        # SPADE layers get the config string with the "spectral" token removed.
        spade_cfg = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(spade_cfg, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_cfg, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_cfg, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return ``shortcut(x) + residual(x)``.

        ``seg`` (the segmentation map) and ``degraded_image`` are passed
        unchanged to every SPADE layer. NOTE(review): ``degraded_image`` is
        presumably the degraded input photo; confirm against the SPADE
        implementation.
        """
        skip = self.shortcut(x, seg, degraded_image)

        branch = x
        for norm, conv in ((self.norm_0, self.conv_0), (self.norm_1, self.conv_1)):
            branch = conv(self.actvn(norm(branch, seg, degraded_image)))

        return skip + branch

    def shortcut(self, x, seg, degraded_image):
        """Identity when channels match, else ``conv_s(norm_s(x, ...))``."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg, degraded_image))

    def actvn(self, x):
        """LeakyReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
|
|
|
|
|
|
# ResNet block used in pix2pixHD
|
|
# We keep the same architecture as pix2pixHD.
|
|
class ResnetBlock(nn.Module):
    """Residual block as used in pix2pixHD: ``x + conv_block(x)``.

    ``conv_block`` is
    ``ReflectionPad -> norm_layer(Conv) -> activation -> ReflectionPad -> norm_layer(Conv)``,
    and every convolution maps ``dim`` channels to ``dim`` channels.

    Args:
        dim: number of input and output channels.
        norm_layer: callable applied to each ``nn.Conv2d``. Its return value
            is placed directly into an ``nn.Sequential``, so it must be a module.
        activation: module inserted between the two convolutions. When omitted
            (or ``None``), a new ``nn.ReLU(False)`` is created for this block.
        kernel_size: convolution kernel size. Each side is reflection-padded by
            ``(kernel_size - 1) // 2``. Only odd sizes preserve the spatial size,
            and the residual addition in ``forward`` requires that.
    """

    def __init__(self, dim, norm_layer, activation=None, kernel_size=3):
        super().__init__()

        if activation is None:
            # Build the default per instance. A module object used as a default
            # argument is created once at import time, so every block would
            # share (and register as a child) the same ReLU.
            activation = nn.ReLU(False)

        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        """Add the convolutional branch to the input (identity skip)."""
        y = self.conv_block(x)
        out = x + y
        return out
|
|
|
|
|
|
# VGG architecture, used for the perceptual loss with a pretrained VGG network
|
|
class VGG19(torch.nn.Module):
    """Pretrained VGG19 feature extractor split into five consecutive slices.

    Each ``sliceK`` is an ``nn.Sequential`` holding a contiguous range of
    ``torchvision.models.vgg19(...).features`` layers. Inside each slice the
    child names are the original layer indices as strings, so the parameter
    names stay stable:

    - ``slice1``: layers 0-1
    - ``slice2``: layers 2-6
    - ``slice3``: layers 7-11
    - ``slice4``: layers 12-20
    - ``slice5``: layers 21-29

    Args:
        requires_grad: when False (the default), all parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=``
        # in favour of ``weights=``; confirm against the installed version.
        features = torchvision.models.vgg19(pretrained=True).features

        # (attribute name, first layer index, one past the last layer index)
        layout = (
            ("slice1", 0, 2),
            ("slice2", 2, 7),
            ("slice3", 7, 12),
            ("slice4", 12, 21),
            ("slice5", 21, 30),
        )
        for attr, start, stop in layout:
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), features[idx])
            setattr(self, attr, stage)

        if not requires_grad:
            # Freeze every parameter of the extractor.
            self.requires_grad_(False)

    def forward(self, X):
        """Run ``X`` through the slices in order.

        Returns:
            list of the five intermediate outputs, shallowest first.
        """
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        return activations
|
|
|
|
|
|
class SPADEResnetBlock_non_spade(nn.Module):
    """Variant of ``SPADEResnetBlock`` whose forward pass skips normalization.

    The residual branch is ``LeakyReLU(0.2) -> 3x3 conv`` applied twice. The
    shortcut is the identity when ``fin == fout``, otherwise a 1x1 conv
    without bias. ``seg`` and ``degraded_image`` are accepted so the call
    signature matches ``SPADEResnetBlock``, but they are ignored.

    The SPADE layers ``norm_0``, ``norm_1`` and (for a learned shortcut)
    ``norm_s`` are still constructed, so they appear in ``parameters()`` and
    ``state_dict()``, but ``forward`` never calls them. NOTE(review):
    presumably kept so checkpoint keys match the SPADE variant; confirm
    before removing.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object; ``opt.norm_G`` and ``opt.semantic_nc`` are read.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()

        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        self.opt = opt

        # Build all convolutions before any spectral-norm wrapping
        # (same initialisation order as the original).
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        conv_names = (
            ("conv_0", "conv_1", "conv_s")
            if self.learned_shortcut
            else ("conv_0", "conv_1")
        )
        if "spectral" in opt.norm_G:
            for name in conv_names:
                setattr(self, name, spectral_norm(getattr(self, name)))

        # Normalization layers (constructed but unused in forward; see class doc).
        cfg = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(cfg, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(cfg, fmiddle, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(cfg, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return ``shortcut(x) + conv_1(act(conv_0(act(x))))``."""
        identity = self.shortcut(x, seg, degraded_image)

        h = x
        for conv in (self.conv_0, self.conv_1):
            h = conv(self.actvn(h))

        return identity + h

    def shortcut(self, x, seg, degraded_image):
        """``conv_s(x)`` when the channel count changes, else ``x`` itself."""
        return self.conv_s(x) if self.learned_shortcut else x

    def actvn(self, x):
        """LeakyReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
|