Files
sd-webui-old-photo-restoration/Face_Enhancement/models/networks/architecture.py
Haoming 89a8626838 Squashed commit of the following:
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:33:44 2023 +0800

    add README

commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c
Author: Haoming <hmstudy02@gmail.com>
Date:   Tue Dec 19 11:12:16 2023 +0800

    fix Face Enhancement distortion

commit 6d52de5368c6cfbd9342465b5238725c186e00b9
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 18:27:25 2023 +0800

    better? args handling

commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:40:19 2023 +0800

    bug fix related to Scratch

commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 17:24:52 2023 +0800

    implement step 2 ~ 4

commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39
Author: Haoming <hmstudy02@gmail.com>
Date:   Mon Dec 18 11:55:20 2023 +0800

    process scratch

commit 3b18f7b042
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 11:57:20 2023 +0800

    "init"

commit d0148e0e82
Author: Haoming <hmstudy02@gmail.com>
Date:   Wed Dec 13 10:34:39 2023 +0800

    clone repo
2023-12-19 11:35:38 +08:00

174 lines
6.1 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.nn.utils.spectral_norm as spectral_norm
from .normalization import SPADE
# ResNet block that uses SPADE.
# It differs from the ResNet block of pix2pixHD in that
# it takes in the segmentation map as input, learns the skip connection if necessary,
# and applies normalization first and then convolution.
# This appears to be a standard design for unconditional or
# class-conditional GANs built from residual blocks.
# The code was inspired by https://github.com/LMescheder/GAN_stability.
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE modules.

    The main path is ``norm -> leaky ReLU -> conv``, applied twice. When the
    channel count changes (``fin != fout``), the skip path gets its own SPADE
    layer followed by a bias-free 1x1 convolution (the "learned shortcut").
    Otherwise the input is passed through unchanged.

    Besides the features, the block takes ``seg`` (the semantic segmentation
    map) and ``degraded_image``. Both are forwarded unchanged to every SPADE
    layer.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object. This block reads ``opt.norm_G`` and
            ``opt.semantic_nc`` and passes ``opt`` itself on to ``SPADE``.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        self.learned_shortcut = fin != fout
        hidden_nc = min(fin, fout)
        self.opt = opt

        # Convolutions are created in the order conv_0, conv_1, conv_s so that
        # weight initialisation draws from the RNG in a fixed sequence.
        self.conv_0 = nn.Conv2d(fin, hidden_nc, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(hidden_nc, fout, kernel_size=3, padding=1)
        conv_names = ["conv_0", "conv_1"]
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
            conv_names.append("conv_s")

        # A "spectral" token in norm_G wraps every convolution with spectral
        # norm. The token is stripped before the config string goes to SPADE.
        if "spectral" in opt.norm_G:
            for conv_name in conv_names:
                setattr(self, conv_name, spectral_norm(getattr(self, conv_name)))
        spade_cfg = opt.norm_G.replace("spectral", "")

        self.norm_0 = SPADE(spade_cfg, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(spade_cfg, hidden_nc, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_cfg, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return ``shortcut(x) + conv_1(act(norm_1(conv_0(act(norm_0(x))))))``."""
        skip = self.shortcut(x, seg, degraded_image)
        hidden = self.actvn(self.norm_0(x, seg, degraded_image))
        hidden = self.conv_0(hidden)
        hidden = self.actvn(self.norm_1(hidden, seg, degraded_image))
        hidden = self.conv_1(hidden)
        return skip + hidden

    def shortcut(self, x, seg, degraded_image):
        """Skip path: SPADE + 1x1 conv when channels change, identity otherwise."""
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg, degraded_image))

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)
# ResNet block used in pix2pixHD
# We keep the same architecture as pix2pixHD.
class ResnetBlock(nn.Module):
    """Plain pix2pixHD residual block computing ``x + conv_block(x)``.

    ``conv_block`` is ``pad -> conv -> activation -> pad -> conv``. It uses
    reflection padding of ``(kernel_size - 1) // 2`` on each side. That
    preserves the spatial size only for odd ``kernel_size``, which the
    residual sum requires.

    Args:
        dim: number of input (and output) channels.
        norm_layer: callable that receives each ``nn.Conv2d`` and returns the
            module placed in the stack (pass ``lambda m: m`` for none).
        activation: module applied between the two convolutions. When omitted,
            a new ``nn.ReLU(False)`` is created for this block.
        kernel_size: convolution kernel size (odd values expected).
    """

    def __init__(self, dim, norm_layer, activation=None, kernel_size=3):
        super().__init__()
        # The previous default, ``activation=nn.ReLU(False)``, was evaluated
        # once when the function was defined. Every block built with the
        # default therefore registered the *same* module instance, so hooks
        # or state on one leaked into all. Build one per instance instead.
        if activation is None:
            activation = nn.ReLU(False)
        pw = (kernel_size - 1) // 2
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
            activation,
            nn.ReflectionPad2d(pw),
            norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)),
        )

    def forward(self, x):
        """Return the residual sum ``x + conv_block(x)``."""
        return x + self.conv_block(x)
# VGG19 feature extractor, used for the perceptual loss with a pretrained VGG network
class VGG19(torch.nn.Module):
    """VGG19 feature extractor used for the perceptual loss.

    The ``features`` stack of an ImageNet-pretrained torchvision VGG19 is cut
    into five consecutive slices: ``features[0:2]``, ``[2:7]``, ``[7:12]``,
    ``[12:21]`` and ``[21:30]``. ``forward`` returns the output of each slice,
    shallowest first.

    Args:
        requires_grad: when False (the default), every parameter is frozen.
    """

    # (start, stop) indices into ``vgg19().features`` for slice1..slice5.
    _SLICE_BOUNDS = ((0, 2), (2, 7), (7, 12), (12, 21), (21, 30))

    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = self._load_pretrained_features()
        for slice_no, (start, stop) in enumerate(self._SLICE_BOUNDS, start=1):
            block = torch.nn.Sequential()
            # Sub-module names keep the original feature indices, so
            # state_dict keys (e.g. "slice2.5.weight") are unchanged.
            for idx in range(start, stop):
                block.add_module(str(idx), vgg_pretrained_features[idx])
            setattr(self, f"slice{slice_no}", block)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    @staticmethod
    def _load_pretrained_features():
        """Return the ``features`` stack of an ImageNet-pretrained VGG19."""
        weights_enum = getattr(torchvision.models, "VGG19_Weights", None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy boolean flag.
            model = torchvision.models.vgg19(pretrained=True)
        else:
            # ``pretrained=True`` is deprecated since torchvision 0.13;
            # IMAGENET1K_V1 is the weight set that flag maps to.
            model = torchvision.models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        return model.features

    def forward(self, X):
        """Run ``X`` through the slices in turn and return all five outputs."""
        outputs = []
        h = X
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = block(h)
            outputs.append(h)
        return outputs
class SPADEResnetBlock_non_spade(nn.Module):
    """Residual block with the SPADE block's layout but no normalization in forward.

    The constructor builds the same sub-modules as the SPADE residual block,
    including ``norm_0``, ``norm_1`` and, when channels change, ``norm_s``.
    However, ``forward`` and ``shortcut`` never call those SPADE layers.
    ``seg`` and ``degraded_image`` are accepted but unused, which keeps the
    call signature interchangeable with the SPADE variant.

    NOTE(review): the unused SPADE layers still own parameters. Presumably
    they are kept so checkpoints containing those keys load cleanly; confirm
    before removing them.

    Args:
        fin: number of input channels.
        fout: number of output channels.
        opt: options object. This block reads ``opt.norm_G`` and
            ``opt.semantic_nc`` and passes ``opt`` itself on to ``SPADE``.
    """

    def __init__(self, fin, fout, opt):
        super().__init__()
        self.learned_shortcut = fin != fout
        mid_nc = min(fin, fout)
        self.opt = opt

        # All convolutions are created before any spectral-norm wrapping, so
        # the RNG is consumed in the same order (inits first, then the
        # spectral-norm vectors).
        conv_0 = nn.Conv2d(fin, mid_nc, kernel_size=3, padding=1)
        conv_1 = nn.Conv2d(mid_nc, fout, kernel_size=3, padding=1)
        conv_s = None
        if self.learned_shortcut:
            conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        if "spectral" in opt.norm_G:
            conv_0 = spectral_norm(conv_0)
            conv_1 = spectral_norm(conv_1)
            if conv_s is not None:
                conv_s = spectral_norm(conv_s)

        self.conv_0 = conv_0
        self.conv_1 = conv_1
        if conv_s is not None:
            self.conv_s = conv_s

        # SPADE layers (unused by forward; see class docstring).
        cfg = opt.norm_G.replace("spectral", "")
        self.norm_0 = SPADE(cfg, fin, opt.semantic_nc, opt)
        self.norm_1 = SPADE(cfg, mid_nc, opt.semantic_nc, opt)
        if self.learned_shortcut:
            self.norm_s = SPADE(cfg, fin, opt.semantic_nc, opt)

    def forward(self, x, seg, degraded_image):
        """Return ``shortcut(x) + conv_1(act(conv_0(act(x))))``."""
        skip = self.shortcut(x, seg, degraded_image)
        h = self.actvn(x)
        h = self.conv_0(h)
        h = self.actvn(h)
        h = self.conv_1(h)
        return skip + h

    def shortcut(self, x, seg, degraded_image):
        """Skip path: 1x1 conv when channels change, identity otherwise."""
        return self.conv_s(x) if self.learned_shortcut else x

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, negative_slope=0.2)