Files
sd-webui-old-photo-restoration/Global/models/NonLocal_feature_mapping_model.py
2024-03-25 12:37:54 +08:00

205 lines
6.1 KiB
Python

# Copyright (c) Microsoft Corporation
import torch.nn as nn
from . import networks
class Mapping_Model_with_mask(nn.Module):
    """Feature-mapping network with a mask-guided non-local block in the middle.

    Pipeline (see ``forward``)::

        before_NL -> NL(x, mask) -> after_NL

    Args:
        nc: Not used by this class; kept for interface compatibility. The
            first conv of ``before_NL`` expects ``min(64, mc)`` input channels.
        mc: Upper bound on the channel width inside the network (every conv
            width is clamped with ``min(..., mc)``).
        n_blocks: Number of ``networks.ResnetBlock`` modules at the start of
            ``after_NL``.
        norm: Norm type forwarded to ``networks.get_norm_layer``.
        padding_type: Padding type forwarded to each ``networks.ResnetBlock``.
        opt: Options object; required despite the ``None`` default. Reads
            ``NL_res``, ``NL_fusion_method``, ``correlation_renormalize``,
            ``softmax_temperature``, ``use_self``, ``cosin_similarity``,
            ``mapping_net_dilation`` and ``feat_dim``, and is passed on to
            every ``networks.ResnetBlock``.

    Note:
        ``self.NL`` is only created when ``opt.NL_res`` is truthy, and
        ``forward`` needs it.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super(Mapping_Model_with_mask, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        tmp_nc = 64
        n_up = 4

        # Conv stack with widths 64 -> 128 -> 256 -> 512 -> 1024, each clamped to mc.
        model = []
        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        self.before_NL = nn.Sequential(*model)

        if opt.NL_res:
            self.NL = networks.NonLocalBlock2D_with_mask_Res(
                mc,
                mc,
                opt.NL_fusion_method,
                opt.correlation_renormalize,
                opt.softmax_temperature,
                opt.use_self,
                opt.cosin_similarity,
            )
            print("using NL + Res...")

        # n_blocks ResNet blocks at width mc ...
        model = []
        for i in range(n_blocks):
            model += [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
            ]

        # ... then the mirrored conv stack 1024 -> 512 -> 256 -> 128 (clamped to mc).
        for i in range(n_up - 1):
            ic = min(tmp_nc * (2 ** (n_up - i)), mc)
            oc = min(tmp_nc * (2 ** (n_up - 1 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        # Take the loop's actual output width as input: a fixed 128 mismatches
        # when mc < 128 (e.g. the default mc=64). Unchanged for mc >= 128.
        model += [nn.Conv2d(min(tmp_nc * 2, mc), tmp_nc, 3, 1, 1)]
        if 0 < opt.feat_dim < tmp_nc:
            # Optional 1x1 projection down to a narrower feature dimension.
            model += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        self.after_NL = nn.Sequential(*model)

    def forward(self, input, mask):
        """Run ``before_NL`` -> ``NL(., mask)`` -> ``after_NL`` and return the result.

        Local references to intermediates are dropped as soon as they are
        consumed, presumably to lower peak memory during inference.
        """
        x1 = self.before_NL(input)
        del input
        x2 = self.NL(x1, mask)
        del x1, mask
        x3 = self.after_NL(x2)
        del x2
        return x3
class Mapping_Model_with_mask_2(nn.Module):
    """Mask-guided feature-mapping network with multi-scale patch attention.

    Pipeline (see ``forward``)::

        before_NL -> NL_scale_1 -> res_block_1 -> NL_scale_2
                  -> res_block_2 -> NL_scale_3 -> after_NL

    The three ``networks.Patch_Attention_4`` modules are built with a third
    argument of 8, 4 and 2 respectively (presumably the patch size at each
    scale -- confirm against ``networks``); each one also receives the mask.

    Args:
        nc: Not used by this class; kept for interface compatibility. The
            first conv of ``before_NL`` expects ``min(64, mc)`` input channels.
        mc: Upper bound on the channel width inside the network.
        n_blocks: Not used; every ResNet stage here has exactly two blocks.
        norm: Norm type forwarded to ``networks.get_norm_layer``.
        padding_type: Padding type forwarded to each ``networks.ResnetBlock``.
        opt: Options object; required despite the ``None`` default. Reads
            ``mapping_exp``, ``mapping_net_dilation`` and ``feat_dim``, and
            is passed on to every ``networks.ResnetBlock``.

    Note:
        ``NL_scale_1``..``NL_scale_3`` are only created when
        ``opt.mapping_exp == 1``; ``forward`` and ``inference_forward`` need them.
    """

    def __init__(
        self, nc, mc=64, n_blocks=3, norm="instance", padding_type="reflect", opt=None
    ):
        super(Mapping_Model_with_mask_2, self).__init__()

        norm_layer = networks.get_norm_layer(norm_type=norm)
        activation = nn.ReLU(True)
        tmp_nc = 64
        n_up = 4

        def res_blocks(count=2):
            # ``count`` freshly built ResnetBlocks of width mc; they all share
            # the single ``activation`` module instance.
            return [
                networks.ResnetBlock(
                    mc,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                    opt=opt,
                    dilation=opt.mapping_net_dilation,
                )
                for _ in range(count)
            ]

        # Conv stack 64 -> 128 -> 256 -> 512 -> 1024 (each clamped to mc),
        # followed by two ResNet blocks.
        model = []
        for i in range(n_up):
            ic = min(tmp_nc * (2**i), mc)
            oc = min(tmp_nc * (2 ** (i + 1)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        model += res_blocks()
        print("using multi-scale patch attention, conv combine + mask input...")
        self.before_NL = nn.Sequential(*model)

        # Attention and ResNet stages alternate; modules are created in the
        # same order in which ``forward`` applies them.
        if opt.mapping_exp == 1:
            self.NL_scale_1 = networks.Patch_Attention_4(mc, mc, 8)
        self.res_block_1 = nn.Sequential(*res_blocks())

        if opt.mapping_exp == 1:
            self.NL_scale_2 = networks.Patch_Attention_4(mc, mc, 4)
        self.res_block_2 = nn.Sequential(*res_blocks())

        if opt.mapping_exp == 1:
            self.NL_scale_3 = networks.Patch_Attention_4(mc, mc, 2)

        # Two ResNet blocks, then the mirrored conv stack
        # 1024 -> 512 -> 256 -> 128 (clamped to mc).
        model = res_blocks()
        for i in range(n_up - 1):
            ic = min(tmp_nc * (2 ** (n_up - i)), mc)
            oc = min(tmp_nc * (2 ** (n_up - 1 - i)), mc)
            model += [nn.Conv2d(ic, oc, 3, 1, 1), norm_layer(oc), activation]
        # Take the loop's actual output width as input: a fixed 128 mismatches
        # when mc < 128 (e.g. the default mc=64). Unchanged for mc >= 128.
        model += [nn.Conv2d(min(tmp_nc * 2, mc), tmp_nc, 3, 1, 1)]
        if 0 < opt.feat_dim < tmp_nc:
            # Optional 1x1 projection down to a narrower feature dimension.
            model += [
                norm_layer(tmp_nc),
                activation,
                nn.Conv2d(tmp_nc, opt.feat_dim, 1, 1),
            ]
        self.after_NL = nn.Sequential(*model)

    def forward(self, input, mask):
        """Apply the full pipeline, calling each attention module with ``mask``."""
        x1 = self.before_NL(input)
        x2 = self.NL_scale_1(x1, mask)
        x3 = self.res_block_1(x2)
        x4 = self.NL_scale_2(x3, mask)
        x5 = self.res_block_2(x4)
        x6 = self.NL_scale_3(x5, mask)
        x7 = self.after_NL(x6)
        return x7

    def inference_forward(self, input, mask):
        """Same pipeline as ``forward``, using each attention module's
        ``inference_forward``.

        Local references to intermediates are dropped as soon as they are
        consumed, presumably to lower peak memory during inference.
        """
        x1 = self.before_NL(input)
        del input
        x2 = self.NL_scale_1.inference_forward(x1, mask)
        del x1
        x3 = self.res_block_1(x2)
        del x2
        x4 = self.NL_scale_2.inference_forward(x3, mask)
        del x3
        x5 = self.res_block_2(x4)
        del x4
        x6 = self.NL_scale_3.inference_forward(x5, mask)
        del x5
        x7 = self.after_NL(x6)
        del x6
        return x7