mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-01-26 19:29:52 +00:00
commit cd7a9c103d1ea981ecd236d4e9111fd3c1cd6c2b Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:33:44 2023 +0800 add README commit 30127cbb2a8e5f461c540729dc7ad457f66eb94c Author: Haoming <hmstudy02@gmail.com> Date: Tue Dec 19 11:12:16 2023 +0800 fix Face Enhancement distortion commit 6d52de5368c6cfbd9342465b5238725c186e00b9 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 18:27:25 2023 +0800 better? args handling commit 0d1938b59eb77a038ee0a91a66b07fb9d7b3d6d4 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:40:19 2023 +0800 bug fix related to Scratch commit 8315cd05ffeb2d651b4c57d70bf04b413ca8901d Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 17:24:52 2023 +0800 implement step 2 ~ 4 commit a5feb04b3980bdd80c6b012a94c743ba48cdfe39 Author: Haoming <hmstudy02@gmail.com> Date: Mon Dec 18 11:55:20 2023 +0800 process scratch commit 3b18f7b042 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 11:57:20 2023 +0800 "init" commit d0148e0e82 Author: Haoming <hmstudy02@gmail.com> Date: Wed Dec 13 10:34:39 2023 +0800 clone repo
333 lines
11 KiB
Python
333 lines
11 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from .sync_batchnorm import DataParallelWithCallback
|
|
from .antialiasing import Downsample
|
|
|
|
|
|
class UNet(nn.Module):
    """U-Net encoder/decoder with reflection padding and optional anti-aliased downsampling.

    Implementation of
    U-Net: Convolutional Networks for Biomedical Image Segmentation
    (Ronneberger et al., 2015)
    https://arxiv.org/abs/1505.04597

    Using the default arguments will yield the exact version used
    in the original paper.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        depth=5,
        conv_num=2,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=True,
        antialiasing=True,
    ):
        """
        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            depth (int): depth of the network (number of down/up stages)
            conv_num (int): convolutions per UNetConvBlock
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                is the same as the output.
                This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                'upconv' will use transposed convolutions for
                learned upsampling.
                'upsample' will use bilinear upsampling.
            with_tanh (bool): if True, squash the output with Tanh
            sync_bn (bool): see NOTE(review) at the end of __init__
            antialiasing (bool): use a blur-pool Downsample instead of a
                strided convolution in the encoder
        """
        super().__init__()
        assert up_mode in ("upconv", "upsample")
        self.padding = padding
        self.depth = depth - 1
        prev_channels = in_channels

        # Stem: 7x7 conv with reflection padding keeps spatial size unchanged.
        self.first = nn.Sequential(
            *[nn.ReflectionPad2d(3), nn.Conv2d(in_channels, 2 ** wf, kernel_size=7), nn.LeakyReLU(0.2, True)]
        )
        prev_channels = 2 ** wf

        # Encoder: each stage halves the spatial resolution (via down_sample)
        # and doubles the channel count (via down_path).
        self.down_path = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(depth):
            # NOTE(review): `depth > 0` is always true inside this loop, so the
            # anti-aliased branch is taken for every stage whenever
            # `antialiasing` is set. It looks like `i > 0` may have been
            # intended — TODO confirm against upstream before changing.
            if antialiasing and depth > 0:
                # 3x3 conv (stride 1) followed by a blur-pool Downsample.
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=3, stride=1, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                            Downsample(channels=prev_channels, stride=2),
                        ]
                    )
                )
            else:
                # Plain strided 4x4 convolution for downsampling.
                self.down_sample.append(
                    nn.Sequential(
                        *[
                            nn.ReflectionPad2d(1),
                            nn.Conv2d(prev_channels, prev_channels, kernel_size=4, stride=2, padding=0),
                            nn.BatchNorm2d(prev_channels),
                            nn.LeakyReLU(0.2, True),
                        ]
                    )
                )
            self.down_path.append(
                UNetConvBlock(conv_num, prev_channels, 2 ** (wf + i + 1), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i + 1)

        # Decoder: mirrors the encoder, halving channels at every stage.
        self.up_path = nn.ModuleList()
        for i in reversed(range(depth)):
            self.up_path.append(
                UNetUpBlock(conv_num, prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        # Output head: 3x3 conv, optionally followed by Tanh.
        if with_tanh:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3), nn.Tanh()]
            )
        else:
            self.last = nn.Sequential(
                *[nn.ReflectionPad2d(1), nn.Conv2d(prev_channels, out_channels, kernel_size=3)]
            )

        # NOTE(review): this assignment only rebinds the *local* name `self`;
        # it does NOT wrap the module that the caller receives, so `sync_bn`
        # is effectively a no-op here. Kept as-is to match upstream — if sync
        # BatchNorm is actually desired, the caller must wrap the instance:
        # `model = DataParallelWithCallback(UNet(...))`.
        if sync_bn:
            self = DataParallelWithCallback(self)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then the head."""
        x = self.first(x)

        # Record pre-downsampling activations for the skip connections.
        blocks = []
        for i, down_block in enumerate(self.down_path):
            blocks.append(x)
            x = self.down_sample[i](x)
            x = down_block(x)

        # Consume the recorded skips in reverse (deepest skip first).
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        return self.last(x)
|
class UNetConvBlock(nn.Module):
    """A stack of `conv_num` stages: ReflectionPad -> 3x3 Conv [-> BatchNorm] -> LeakyReLU.

    When `padding` is truthy, each stage pads by 1 and preserves spatial size;
    otherwise each 3x3 convolution shrinks height and width by 2.
    """

    def __init__(self, conv_num, in_size, out_size, padding, batch_norm):
        super().__init__()

        layers = []
        channels = in_size
        for _ in range(conv_num):
            # bool -> int: padding=True gives ReflectionPad2d(1), False gives 0.
            layers.append(nn.ReflectionPad2d(padding=int(padding)))
            layers.append(nn.Conv2d(channels, out_size, kernel_size=3, padding=0))
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.LeakyReLU(0.2, True))
            channels = out_size  # after the first stage, channels stay at out_size

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the convolutional stack to `x`."""
        return self.block(x)
|
class UNetUpBlock(nn.Module):
    """Decoder stage: upsample 2x, concatenate with a (center-cropped) skip, convolve.

    `up_mode` selects learned upsampling ('upconv', transposed convolution)
    or bilinear upsampling followed by a 3x3 convolution ('upsample').
    """

    def __init__(self, conv_num, in_size, out_size, up_mode, padding, batch_norm):
        super().__init__()

        if up_mode == "upconv":
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == "upsample":
            self.up = nn.Sequential(
                nn.Upsample(mode="bilinear", scale_factor=2, align_corners=False),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_size, out_size, kernel_size=3, padding=0),
            )

        # After concatenation the channel count is back to in_size
        # (out_size from the upsample path + out_size from the skip).
        self.conv_block = UNetConvBlock(conv_num, in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Crop `layer` symmetrically to spatial size `target_size` (h, w)."""
        _, _, height, width = layer.size()
        top = (height - target_size[0]) // 2
        left = (width - target_size[1]) // 2
        bottom = top + target_size[0]
        right = left + target_size[1]
        return layer[:, :, top:bottom, left:right]

    def forward(self, x, bridge):
        """Upsample `x`, fuse it with the skip tensor `bridge`, and convolve."""
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        fused = torch.cat([upsampled, skip], 1)
        return self.conv_block(fused)
|
class UnetGenerator(nn.Module):
    """Create a Unet-based generator.

    The U-Net is assembled recursively from UnetSkipConnectionBlock, from the
    innermost layer outward.
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_type="BN", use_dropout=False):
        """Construct a Unet generator.

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example,
                               if |num_downs| == 7, an image of size 128x128 will
                               become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_type (str) -- 'BN' for BatchNorm2d, 'IN' for InstanceNorm2d
            use_dropout (bool) -- whether the intermediate ngf*8 blocks use dropout

        Raises:
            NameError: if `norm_type` is neither 'BN' nor 'IN'.
        """
        super().__init__()

        norm_map = {"BN": nn.BatchNorm2d, "IN": nn.InstanceNorm2d}
        if norm_type not in norm_map:
            raise NameError("Unknown norm layer")
        norm_layer = norm_map[norm_type]

        # Innermost (bottleneck) layer.
        block = UnetSkipConnectionBlock(
            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True
        )
        # Intermediate layers, all at ngf * 8 filters.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # Gradually reduce the number of filters from ngf * 8 to ngf.
        for mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(
                ngf * mult, ngf * mult * 2, input_nc=None, submodule=block, norm_layer=norm_layer
            )
        # Outermost layer maps back to output_nc channels.
        self.model = UnetSkipConnectionBlock(
            output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer
        )

    def forward(self, input):
        """Run the full recursive U-Net on `input`."""
        return self.model(input)
|
class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.

    -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|

    The outermost block returns the model output directly; every other block
    concatenates its input onto its output along the channel axis.
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int)  -- the number of filters in the outer conv layer
            inner_nc (int)  -- the number of filters in the inner conv layer
            input_nc (int)  -- the number of channels in input images/features
                               (defaults to outer_nc)
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool) -- if this module is the outermost module
            innermost (bool) -- if this module is the innermost module
            norm_layer      -- normalization layer class
            use_dropout (bool) -- add Dropout(0.5) after non-edge blocks
        """
        super().__init__()
        self.outermost = outermost
        # InstanceNorm has no affine bias by default, so the conv needs one.
        use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.LeakyReLU(0.2, True)

        if outermost:
            # No activation/norm before the first conv; Tanh on the way out.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # No submodule; the up path consumes inner_nc channels directly.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            layers = [downrelu, downconv, uprelu, upconv, norm_layer(outer_nc)]
        else:
            # Submodule doubles the channel count via its skip concatenation.
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            layers = [downrelu, downconv, norm_layer(inner_nc), submodule, uprelu, upconv, norm_layer(outer_nc)]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block; non-outermost blocks append the skip connection."""
        out = self.model(x)
        if self.outermost:
            return out
        return torch.cat([x, out], 1)
|
# ============================================
# Network testing
# ============================================
if __name__ == "__main__":
    from torchsummary import summary

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): the original script referenced `UNet_two_decoders` and
    # `make_dot`, neither of which is defined or imported in this file, so it
    # raised NameError on execution. Build the UNet defined in this module
    # with the equivalent hyper-parameters instead.
    model = UNet(
        in_channels=3,
        out_channels=3,
        depth=4,
        conv_num=1,
        wf=6,
        padding=True,
        batch_norm=True,
        up_mode="upsample",
        with_tanh=False,
        sync_bn=False,  # no DataParallel wrapping needed for a local smoke test
    )
    model.to(device)

    model_pix2pix = UnetGenerator(3, 3, 5, ngf=64, norm_type="BN", use_dropout=False)
    model_pix2pix.to(device)

    print("customized unet:")
    summary(model, (3, 256, 256))

    print("cyclegan unet:")
    summary(model_pix2pix, (3, 256, 256))

    # Optionally render the computation graph; torchviz provides the
    # `make_dot` the original code used without importing it.
    try:
        from torchviz import make_dot
    except ImportError:
        make_dot = None

    if make_dot is not None:
        # Use `device` instead of the original hard-coded `.cuda()`, which
        # crashed on CPU-only machines despite the fallback computed above.
        x = torch.zeros(1, 3, 256, 256, device=device).requires_grad_(True)
        graph = make_dot(model(x))
        graph.render("models/Digraph.gv", view=False)