# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_network import BaseNetwork
from .normalization import get_nonspade_norm_layer
from .architecture import ResnetBlock as ResnetBlock
from .architecture import SPADEResnetBlock as SPADEResnetBlock
from .architecture import SPADEResnetBlock_non_spade as SPADEResnetBlock_non_spade


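# SPADEGenerator: a coarse-to-fine decoder whose SPADE residual blocks are
# conditioned on the parsing/segmentation map ("seg") and the degraded input
# image.  opt.injection_layer selects which stages receive SPADE conditioning;
# the remaining stages fall back to the non-SPADE block variants.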
class SPADEGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.set_defaults(norm_G="spectralspadesyncbatch3x3")
        parser.add_argument(
            "--num_upsampling_layers",
            choices=("normal", "more", "most"),
            default="normal",
            help="If 'more', adds upsampling layer between the two middle resnet blocks. If 'most', also add one more upsampling + resnet layer at the end of the generator",
        )

        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        nf = opt.ngf

        self.sw, self.sh = self.compute_latent_vector_size(opt)

print("The size of the latent vector size is [%d,%d]" % (self.sw, self.sh))
|
|
|
|
        if opt.use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(opt.z_dim, 16 * nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            if self.opt.no_parsing_map:
                self.fc = nn.Conv2d(3, 16 * nf, 3, padding=1)
            else:
                self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * nf, 3, padding=1)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "1":
            self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
        else:
            self.head_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "2":
            self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, opt)

        else:
            self.G_middle_0 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)
            self.G_middle_1 = SPADEResnetBlock_non_spade(16 * nf, 16 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "3":
            self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, opt)
        else:
            self.up_0 = SPADEResnetBlock_non_spade(16 * nf, 8 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "4":
            self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
        else:
            self.up_1 = SPADEResnetBlock_non_spade(8 * nf, 4 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "5":
            self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
        else:
            self.up_2 = SPADEResnetBlock_non_spade(4 * nf, 2 * nf, opt)

        if self.opt.injection_layer == "all" or self.opt.injection_layer == "6":
            self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
        else:
            self.up_3 = SPADEResnetBlock_non_spade(2 * nf, 1 * nf, opt)

        final_nc = nf

        if opt.num_upsampling_layers == "most":
            self.up_4 = SPADEResnetBlock(1 * nf, nf // 2, opt)
            final_nc = nf // 2

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

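    # The generator starts from a latent feature map whose width is the output
    # width divided by 2**num_up_layers; e.g. (illustrative values) load_size=256,
    # aspect_ratio=1.0 and "normal" (5 up layers) give sw = 256 // 32 = 8 and
    # sh = round(8 / 1.0) = 8.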
    def compute_latent_vector_size(self, opt):
        if opt.num_upsampling_layers == "normal":
            num_up_layers = 5
        elif opt.num_upsampling_layers == "more":
            num_up_layers = 6
        elif opt.num_upsampling_layers == "most":
            num_up_layers = 7
        else:
            raise ValueError("opt.num_upsampling_layers [%s] not recognized" % opt.num_upsampling_layers)

        sw = opt.load_size // (2 ** num_up_layers)
        sh = round(sw / opt.aspect_ratio)

        return sw, sh

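    # Forward pass: build a 16*ngf feature map at (sh, sw) -- projected from a
    # sampled z in VAE mode, otherwise convolved from the downsampled parsing map
    # or degraded image -- then alternate 2x upsampling with the (SPADE) resnet
    # blocks defined above until the output resolution is reached.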
    def forward(self, input, degraded_image, z=None):
        seg = input

        if self.opt.use_vae:
            # we sample z from unit normal and reshape the tensor
            if z is None:
                z = torch.randn(input.size(0), self.opt.z_dim, dtype=torch.float32, device=input.get_device())
            x = self.fc(z)
            x = x.view(-1, 16 * self.opt.ngf, self.sh, self.sw)
        else:
            # we downsample segmap and run convolution
            if self.opt.no_parsing_map:
                x = F.interpolate(degraded_image, size=(self.sh, self.sw), mode="bilinear")
            else:
                x = F.interpolate(seg, size=(self.sh, self.sw), mode="nearest")
            x = self.fc(x)

        x = self.head_0(x, seg, degraded_image)

        x = self.up(x)
        x = self.G_middle_0(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "more" or self.opt.num_upsampling_layers == "most":
            x = self.up(x)

        x = self.G_middle_1(x, seg, degraded_image)

        x = self.up(x)
        x = self.up_0(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_1(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_2(x, seg, degraded_image)
        x = self.up(x)
        x = self.up_3(x, seg, degraded_image)

        if self.opt.num_upsampling_layers == "most":
            x = self.up(x)
            x = self.up_4(x, seg, degraded_image)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        # torch.tanh replaces the deprecated F.tanh
        x = torch.tanh(x)

        return x


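# Pix2PixHDGenerator: a plain encoder / resnet / decoder generator that only
# consumes the degraded image; the parsing map and z arguments are accepted for
# interface compatibility but ignored.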
class Pix2PixHDGenerator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument(
            "--resnet_n_downsample", type=int, default=4, help="number of downsampling layers in netG"
        )
        parser.add_argument(
            "--resnet_n_blocks",
            type=int,
            default=9,
            help="number of residual blocks in the global generator network",
        )
        parser.add_argument(
            "--resnet_kernel_size", type=int, default=3, help="kernel size of the resnet block"
        )
        parser.add_argument(
            "--resnet_initial_kernel_size", type=int, default=7, help="kernel size of the first convolution"
        )
        # parser.set_defaults(norm_G='instance')
        return parser

    def __init__(self, opt):
        super().__init__()
        input_nc = 3

        # print("xxxxx")
        # print(opt.norm_G)
        norm_layer = get_nonspade_norm_layer(opt, opt.norm_G)
        activation = nn.ReLU(False)

        model = []

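        # Channel plan: each downsampling conv doubles the width, so the resnet
        # bottleneck runs at ngf * 2**resnet_n_downsample channels (e.g. 64 -> 1024
        # with resnet_n_downsample=4, assuming the common default ngf=64).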
        # initial conv
        model += [
            nn.ReflectionPad2d(opt.resnet_initial_kernel_size // 2),
            norm_layer(nn.Conv2d(input_nc, opt.ngf, kernel_size=opt.resnet_initial_kernel_size, padding=0)),
            activation,
        ]

        # downsample
        mult = 1
        for i in range(opt.resnet_n_downsample):
            model += [
                norm_layer(nn.Conv2d(opt.ngf * mult, opt.ngf * mult * 2, kernel_size=3, stride=2, padding=1)),
                activation,
            ]
            mult *= 2

        # resnet blocks
        for i in range(opt.resnet_n_blocks):
            model += [
                ResnetBlock(
                    opt.ngf * mult,
                    norm_layer=norm_layer,
                    activation=activation,
                    kernel_size=opt.resnet_kernel_size,
                )
            ]

        # upsample
        for i in range(opt.resnet_n_downsample):
            nc_in = int(opt.ngf * mult)
            nc_out = int((opt.ngf * mult) / 2)
            model += [
                norm_layer(
                    nn.ConvTranspose2d(nc_in, nc_out, kernel_size=3, stride=2, padding=1, output_padding=1)
                ),
                activation,
            ]
            mult = mult // 2

        # final output conv
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(nc_out, opt.output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, input, degraded_image, z=None):
        return self.model(degraded_image)
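
# Usage sketch (illustrative only, not part of the original module): both
# generators read their hyper-parameters from an `opt` namespace carrying the
# fields referenced above, plus whatever the SPADE blocks in .architecture
# expect.  The option names mirror this file; every concrete value below is an
# assumption.
#
#   from argparse import Namespace
#
#   opt = Namespace(
#       ngf=64, semantic_nc=18, output_nc=3, z_dim=256,
#       load_size=256, aspect_ratio=1.0,
#       num_upsampling_layers="normal", injection_layer="all",
#       use_vae=False, no_parsing_map=False,
#       norm_G="spectralspadesyncbatch3x3",
#   )
#   netG = SPADEGenerator(opt)
#   restored = netG(parsing_map, degraded_image)  # tensors shaped (N, C, load_size, load_size)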