mirror of
https://github.com/Haoming02/sd-webui-old-photo-restoration.git
synced 2026-01-26 11:19:51 +00:00
205 lines
6.1 KiB
Python
205 lines
6.1 KiB
Python
# Copyright (c) Microsoft Corporation
|
|
|
|
import torch.nn as nn
|
|
from . import networks
|
|
|
|
|
|
class Mapping_Model_with_mask(nn.Module):
    """Feature-space mapping network built around a masked non-local block.

    Pipeline: ``before_NL`` (channel-widening 3x3 convs) ->
    ``NL`` (``networks.NonLocalBlock2D_with_mask_Res``, which also receives the
    mask) -> ``after_NL`` (``n_blocks`` ResnetBlocks, then channel-narrowing
    convs). The convolutions constructed directly in this class (3x3 with
    padding 1, or 1x1 with no padding, all stride 1) keep the spatial size.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        """Build the network.

        Args:
            nc: Accepted for interface compatibility; not used by this class.
            mc: Maximum channel width inside the mapping network.
            n_blocks: Number of ``networks.ResnetBlock`` layers at the start of
                ``after_NL``.
            norm: Norm type name passed to ``networks.get_norm_layer``.
            padding_type: Padding mode forwarded to every ResnetBlock.
            opt: Options object. Required despite the ``None`` default: this
                method reads ``NL_res``, ``NL_fusion_method``,
                ``correlation_renormalize``, ``softmax_temperature``,
                ``use_self``, ``cosin_similarity``, ``mapping_net_dilation``
                and ``feat_dim`` from it.
        """
        super().__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        # A single ReLU instance is reused by every stage (it has no parameters).
        activation = nn.ReLU(True)

        tmp_nc = 64
        n_up = 4

        # Encoder: start at tmp_nc channels and double per layer, capped at mc.
        model = []
        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        self.before_NL = nn.Sequential(*model)

        # NOTE(review): when opt.NL_res is falsy no ``self.NL`` is created, so
        # forward() will raise AttributeError.
        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )

            print("using NL + Res...")

        model = [
            networks.ResnetBlock(
                mc,
                padding_type=padding_type,
                activation=activation,
                norm_layer=norm_layer,
                opt=opt,
                dilation=opt.mapping_net_dilation,
            )
            for _ in range(n_blocks)
        ]
        model += self._decoder_layers(
            mc, tmp_nc, n_up, norm_layer, activation, opt.feat_dim
        )
        self.after_NL = nn.Sequential(*model)

    @staticmethod
    def _decoder_layers(mc, tmp_nc, n_up, norm_layer, activation, feat_dim):
        """Return the layers that narrow the mapped features to the output width.

        Emits ``n_up - 1`` (conv, norm, activation) triples whose widths halve
        (each capped at ``mc``), then a 3x3 conv to ``tmp_nc`` channels and,
        when ``0 < feat_dim < tmp_nc``, a norm/activation/1x1 conv projecting to
        ``feat_dim`` channels.

        Args:
            mc: Maximum channel width (same cap the encoder uses).
            tmp_nc: Base/output channel width.
            n_up: Number of encoder widening layers.
            norm_layer: Callable ``norm_layer(channels) -> nn.Module``.
            activation: Activation module inserted after each norm.
            feat_dim: Optional final projection width.

        Returns:
            list[nn.Module]: Layers in application order.
        """
        layers = []
        # Width produced by the encoder, which is this stage's input width.
        ch = min(tmp_nc * (2**n_up), mc)
        for i in range(n_up - 1):
            oc = min(tmp_nc * (2 ** (n_up - 1 - i)), mc)
            layers += [nn.Conv2d(ch, oc, 3, 1, 1), norm_layer(oc), activation]
            ch = oc
        # Use the tracked width rather than a hard-coded 2 * tmp_nc: the two
        # agree for mc >= 128, and only the tracked width is correct for a
        # smaller mc.
        layers += [nn.Conv2d(ch, tmp_nc, 3, 1, 1)]
        if 0 < feat_dim < tmp_nc:
            layers += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, feat_dim, 1, 1),
            ]
        return layers

    def forward(self, input, mask):
        """Apply before_NL, then NL(features, mask), then after_NL.

        Each intermediate is ``del``-ed once consumed so this frame drops its
        reference early (presumably to lower peak memory).
        """
        x1 = self.before_NL(input)
        del input
        x2 = self.NL(x1, mask)
        del x1, mask
        x3 = self.after_NL(x2)
        del x2

        return x3
|
|
|
|
|
|
class Mapping_Model_with_mask_2(nn.Module):  ## Multi-Scale Patch Attention
    """Mapping network with three masked patch-attention stages.

    Stages, applied in this order:
      ``before_NL``   channel-widening 3x3 convs + 2 ResnetBlocks
      ``NL_scale_1``  ``networks.Patch_Attention_4(mc, mc, 8)``
      ``res_block_1`` 2 ResnetBlocks
      ``NL_scale_2``  ``networks.Patch_Attention_4(mc, mc, 4)``
      ``res_block_2`` 2 ResnetBlocks
      ``NL_scale_3``  ``networks.Patch_Attention_4(mc, mc, 2)``
      ``after_NL``    2 ResnetBlocks + channel-narrowing convs
    The third Patch_Attention_4 argument (8/4/2) is presumably a patch size;
    confirm against ``networks``.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        """Build the network.

        Args:
            nc: Accepted for interface compatibility; not used by this class.
            mc: Maximum channel width inside the mapping network.
            n_blocks: Accepted for interface compatibility; every stage here
                uses exactly 2 ResnetBlocks.
            norm: Norm type name passed to ``networks.get_norm_layer``.
            padding_type: Padding mode forwarded to every ResnetBlock.
            opt: Options object. Required despite the ``None`` default: this
                method reads ``mapping_net_dilation``, ``mapping_exp`` and
                ``feat_dim`` from it, and passes it to each ResnetBlock.
        """
        super().__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        # A single ReLU instance is reused by every stage (it has no parameters).
        activation = nn.ReLU(True)

        tmp_nc = 64
        n_up = 4

        def res_blocks(count):
            """Return ``count`` newly constructed ResnetBlocks of width ``mc``."""
            return [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
                for _ in range(count)
            ]

        # Encoder: start at tmp_nc channels and double per layer, capped at mc.
        model = []
        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += res_blocks(2)

        print("using multi-scale patch attention, conv combine + mask input...")

        self.before_NL = nn.Sequential(*model)

        # NOTE(review): the NL_scale_* modules exist only when
        # opt.mapping_exp == 1; otherwise forward()/inference_forward() raise
        # AttributeError. Submodules are created in the original interleaved
        # order so registration order and state-dict keys are unchanged.
        if opt.mapping_exp == 1:
            self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8)

        self.res_block_1 = nn.Sequential(*res_blocks(2))

        if opt.mapping_exp == 1:
            self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4)

        self.res_block_2 = nn.Sequential(*res_blocks(2))

        if opt.mapping_exp == 1:
            self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2)

        model = res_blocks(2)
        model += self._decoder_layers(
            mc, tmp_nc, n_up, norm_layer, activation, opt.feat_dim
        )
        self.after_NL = nn.Sequential(*model)

    @staticmethod
    def _decoder_layers(mc, tmp_nc, n_up, norm_layer, activation, feat_dim):
        """Return the layers that narrow the mapped features to the output width.

        Emits ``n_up - 1`` (conv, norm, activation) triples whose widths halve
        (each capped at ``mc``), then a 3x3 conv to ``tmp_nc`` channels and,
        when ``0 < feat_dim < tmp_nc``, a norm/activation/1x1 conv projecting to
        ``feat_dim`` channels.

        Args:
            mc: Maximum channel width (same cap the encoder uses).
            tmp_nc: Base/output channel width.
            n_up: Number of encoder widening layers.
            norm_layer: Callable ``norm_layer(channels) -> nn.Module``.
            activation: Activation module inserted after each norm.
            feat_dim: Optional final projection width.

        Returns:
            list[nn.Module]: Layers in application order.
        """
        layers = []
        # Width produced by the encoder, which is this stage's input width.
        ch = min(tmp_nc * (2**n_up), mc)
        for i in range(n_up - 1):
            oc = min(tmp_nc * (2 ** (n_up - 1 - i)), mc)
            layers += [nn.Conv2d(ch, oc, 3, 1, 1), norm_layer(oc), activation]
            ch = oc
        # Use the tracked width rather than a hard-coded 2 * tmp_nc: the two
        # agree for mc >= 128, and only the tracked width is correct for a
        # smaller mc.
        layers += [nn.Conv2d(ch, tmp_nc, 3, 1, 1)]
        if 0 < feat_dim < tmp_nc:
            layers += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, feat_dim, 1, 1),
            ]
        return layers

    def forward(self, input, mask):
        """Run every stage in order; each attention stage also receives ``mask``."""
        x1 = self.before_NL(input)
        x2 = self.NL_scale_1(x1, mask)
        x3 = self.res_block_1(x2)
        x4 = self.NL_scale_2(x3, mask)
        x5 = self.res_block_2(x4)
        x6 = self.NL_scale_3(x5, mask)
        x7 = self.after_NL(x6)
        return x7

    def inference_forward(self, input, mask):
        """Same stage order as forward(), using each attention module's
        ``inference_forward``.

        Each intermediate is ``del``-ed once consumed so this frame drops its
        reference early (presumably to lower peak memory).
        """
        x1 = self.before_NL(input)
        del input
        x2 = self.NL_scale_1.inference_forward(x1, mask)
        del x1
        x3 = self.res_block_1(x2)
        del x2
        x4 = self.NL_scale_2.inference_forward(x3, mask)
        del x3
        x5 = self.res_block_2(x4)
        del x4
        x6 = self.NL_scale_3.inference_forward(x5, mask)
        del x5
        x7 = self.after_NL(x6)
        del x6
        return x7
|