mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00

Added training for a custom version of the ESRGAN architecture. Testing training now.
@@ -20,6 +20,7 @@ process_dict = {
    'slider_old': 'TrainSliderProcessOld',
    'lora_hack': 'TrainLoRAHack',
    'rescale_sd': 'TrainSDRescaleProcess',
    'esrgan': 'TrainESRGANProcess',
}
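As context for the new mapping, here is a minimal sketch of how a process_dict key such as 'esrgan' could be resolved to its class at job-build time. resolve_process is hypothetical and not part of this commit; the lookup only works because jobs/process/__init__.py re-exports the class (see the @@ -12,3 +12,4 @@ hunk later in this commit).

import importlib

def resolve_process(name: str, process_dict: dict):
    # map the config key (e.g. 'esrgan') to a class name, then fetch the
    # class from the jobs.process package, which re-exports it
    class_name = process_dict[name]
    module = importlib.import_module('jobs.process')
    return getattr(module, class_name)

process_class = resolve_process('esrgan', {'esrgan': 'TrainESRGANProcess'})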
575 jobs/process/TrainESRGANProcess.py Normal file
@@ -0,0 +1,575 @@
import copy
import glob
import os
import time
from collections import OrderedDict

from PIL import Image
from PIL.ImageOps import exif_transpose
# from basicsr.archs.rrdbnet_arch import RRDBNet
from toolkit.models.RRDB import RRDBNet as ESRGAN
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
from torchvision.transforms import transforms

from jobs.process import BaseTrainProcess
from toolkit.data_loader import AugmentedImageDataset
from toolkit.esrgan_utils import convert_state_dict_to_basicsr, convert_basicsr_state_dict_to_save_format
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
from tqdm import tqdm
import numpy as np
from .models.vgg19_critic import Critic

IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        # transforms.Normalize([0.5], [0.5]),
    ]
)

class TrainESRGANProcess(BaseTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.data_loader = None
        self.model = None
        self.device = self.get_conf('device', self.job.device)
        # no pretrained checkpoint by default
        self.pretrained_path = self.get_conf('pretrained_path', None)
        self.datasets_objects = self.get_conf('datasets', required=True)
        self.batch_size = self.get_conf('batch_size', 1, as_type=int)
        self.resolution = self.get_conf('resolution', 256, as_type=int)
        self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
        self.sample_every = self.get_conf('sample_every', None)
        self.optimizer_type = self.get_conf('optimizer', 'adam')
        self.epochs = self.get_conf('epochs', None, as_type=int)
        self.max_steps = self.get_conf('max_steps', None, as_type=int)
        self.save_every = self.get_conf('save_every', None)
        self.upscale_sample = self.get_conf('upscale_sample', 4)
        self.dtype = self.get_conf('dtype', 'float32')
        self.sample_sources = self.get_conf('sample_sources', None)
        self.log_every = self.get_conf('log_every', 100, as_type=int)
        self.style_weight = self.get_conf('style_weight', 0, as_type=float)
        self.content_weight = self.get_conf('content_weight', 0, as_type=float)
        self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
        self.zoom = self.get_conf('zoom', 4, as_type=int)
        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
        self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
        self.optimizer_params = self.get_conf('optimizer_params', {})
        self.augmentations = self.get_conf('augmentations', {})
        self.torch_dtype = get_torch_dtype(self.dtype)
        if self.torch_dtype == torch.bfloat16:
            self.esrgan_dtype = torch.float16
        else:
            self.esrgan_dtype = torch.float32
        self.vgg_19 = None
        self.style_weight_scalers = []
        self.content_weight_scalers = []

        # throw an error if zoom is not divisible by 2
        if self.zoom % 2 != 0:
            raise ValueError('zoom must be divisible by 2')

        self.step_num = 0
        self.epoch_num = 0

        self.use_critic = self.get_conf('use_critic', False, as_type=bool)
        self.critic = None

        if self.use_critic:
            self.critic = Critic(
                device=self.device,
                dtype=self.dtype,
                process=self,
                **self.get_conf('critic', {})  # pass any other params
            )

        if self.sample_every is not None and self.sample_sources is None:
            raise ValueError('sample_every is specified but sample_sources is not')

        if self.epochs is None and self.max_steps is None:
            raise ValueError('epochs or max_steps must be specified')

        self.data_loaders = []
        # check the datasets
        assert isinstance(self.datasets_objects, list)
        for dataset in self.datasets_objects:
            if 'path' not in dataset:
                raise ValueError('dataset must have a path')
            # check that it is a directory
            if not os.path.isdir(dataset['path']):
                raise ValueError(f"dataset path is not a directory: {dataset['path']}")

        # make the training folder
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

        self._pattern_loss = None

        # build augmentation transforms
        aug_transforms = []

    def update_training_metadata(self):
        self.add_meta(OrderedDict({"training_info": self.get_training_info()}))

    def get_training_info(self):
        info = OrderedDict({
            'step': self.step_num,
            'epoch': self.epoch_num,
        })
        return info

    def load_datasets(self):
        if self.data_loader is None:
            print(f"Loading datasets")
            datasets = []
            for dataset in self.datasets_objects:
                print(f" - Dataset: {dataset['path']}")
                ds = copy.copy(dataset)
                ds['resolution'] = self.resolution

                if 'augmentations' not in ds:
                    ds['augmentations'] = self.augmentations

                # add the resize-down augmentation
                ds['augmentations'] = [{
                    'method': 'Resize',
                    'params': {
                        'width': int(self.resolution // self.zoom),
                        'height': int(self.resolution // self.zoom),
                        # downscale interpolation; the string will be evaluated
                        'interpolation': 'cv2.INTER_AREA'
                    }
                }] + ds['augmentations']

                image_dataset = AugmentedImageDataset(ds)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=6
            )

    def setup_vgg19(self):
        if self.vgg_19 is None:
            self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
                single_target=True,
                device=self.device,
                output_layer_name='pool_4',
                dtype=self.torch_dtype
            )
            self.vgg_19.to(self.device, dtype=self.torch_dtype)
            self.vgg_19.requires_grad_(False)

            # we run random noise through first to get layer scalers to normalize the loss per layer
            # bs of 2 because we run pred and target through stacked
            noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype)
            self.vgg_19(noise)
            for style_loss in self.style_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(style_loss.loss).item()
                self.style_weight_scalers.append(scaler)
            for content_loss in self.content_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(content_loss.loss).item()
                # if is nan, set to 1
                if scaler != scaler:
                    scaler = 1
                    print(f"Warning: content loss scaler is nan, setting to 1")
                self.content_weight_scalers.append(scaler)

            self.print(f"Style weight scalers: {self.style_weight_scalers}")
            self.print(f"Content weight scalers: {self.content_weight_scalers}")

    def get_style_loss(self):
        if self.style_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(
                torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_content_loss(self):
        if self.content_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(torch.stack(
                [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_mse_loss(self, pred, target):
        if self.mse_weight > 0:
            loss_fn = nn.MSELoss()
            loss = loss_fn(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_tv_loss(self, pred, target):
        if self.tv_weight > 0:
            get_tv_loss = ComparativeTotalVariation()
            loss = get_tv_loss(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_pattern_loss(self, pred, target):
        if self._pattern_loss is None:
            self._pattern_loss = PatternLoss(
                pattern_size=self.zoom,
                dtype=self.torch_dtype
            ).to(self.device, dtype=self.torch_dtype)
        loss = torch.mean(self._pattern_loss(pred, target))
        return loss

    def save(self, step=None):
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

        step_num = ''
        if step is not None:
            # zero-pad to 9 digits
            step_num = f"_{str(step).zfill(9)}"

        self.update_training_metadata()
        filename = f'{self.job.name}{step_num}.safetensors'
        # prepare meta
        save_meta = get_meta_for_safetensors(self.meta, self.job.name)

        # state_dict = self.model.state_dict()

        # state has the original state dict keys so we can save what we started from
        save_state_dict = self.model.state

        for key in list(save_state_dict.keys()):
            v = save_state_dict[key]
            v = v.detach().clone().to("cpu").to(torch.float32)
            save_state_dict[key] = v

        # having issues with meta
        save_file(save_state_dict, os.path.join(self.save_root, filename), save_meta)

        self.print(f"Saved to {os.path.join(self.save_root, filename)}")

        if self.use_critic:
            self.critic.save(step)

    def sample(self, step=None):
        sample_folder = os.path.join(self.save_root, 'samples')
        if not os.path.exists(sample_folder):
            os.makedirs(sample_folder, exist_ok=True)

        self.model.eval()

        with torch.no_grad():
            for i, img_url in enumerate(self.sample_sources):
                img = exif_transpose(Image.open(img_url))
                img = img.convert('RGB')
                # crop if not square
                if img.width != img.height:
                    min_dim = min(img.width, img.height)
                    img = img.crop((0, 0, min_dim, min_dim))
                # resize
                img = img.resize((self.resolution * self.zoom, self.resolution * self.zoom), resample=Image.BICUBIC)

                target_image = img
                # downscale the image input
                img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC)

                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
                output = self.model(img)
                # output = (output / 2 + 0.5).clamp(0, 1)
                output = output.clamp(0, 1)
                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()

                # convert to a Pillow image
                output = Image.fromarray((output * 255).astype(np.uint8))

                # upscale to size * self.upscale_sample while maintaining pixels
                output = output.resize(
                    (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample),
                    resample=Image.NEAREST
                )

                width, height = output.size

                # stack the input image and the decoded image
                target_image = target_image.resize((width, height))
                output = output.resize((width, height))

                output_img = Image.new('RGB', (width * 2, height))
                output_img.paste(target_image, (0, 0))
                output_img.paste(output, (width, 0))

                step_num = ''
                if step is not None:
                    # zero-pad to 9 digits
                    step_num = f"_{str(step).zfill(9)}"
                seconds_since_epoch = int(time.time())
                # zero-pad to 2 digits
                i_str = str(i).zfill(2)
                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
                output_img.save(os.path.join(sample_folder, filename))

        self.model.train()

    def load_model(self):
        state_dict = None
        path_to_load = self.pretrained_path
        # see if we have a checkpoint in our output to resume from
        self.print(f"Looking for latest checkpoint in {self.save_root}")
        files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors"))
        if files and len(files) > 0:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
            # todo update step and epoch count
        elif self.pretrained_path is None:
            self.print(f" - No checkpoint found, starting from scratch")
        else:
            self.print(f" - No checkpoint found, loading pretrained model")
            self.print(f" - path: {path_to_load}")

        if path_to_load is not None:
            self.print(f" - Loading checkpoint: {path_to_load}")
            # if it ends with .pth or .pt, assume a PyTorch checkpoint
            if path_to_load.endswith('.pth') or path_to_load.endswith('.pt'):
                state_dict = torch.load(path_to_load, map_location=self.device)
            elif path_to_load.endswith('.safetensors'):
                state_dict = load_file(path_to_load)
            else:
                raise Exception(f"Unknown file extension for checkpoint: {path_to_load}")

        # todo determine architecture from checkpoint
        self.model = ESRGAN(
            state_dict
        ).to(self.device, dtype=self.esrgan_dtype)

        # set the model to training mode
        self.model.train()
        self.model.requires_grad_(True)

    def run(self):
        super().run()
        self.load_datasets()

        # reconcile epochs and max_steps; either may be None (but not both,
        # per the check in __init__)
        num_epochs = self.epochs
        if self.max_steps is not None:
            max_step_epochs = self.max_steps // len(self.data_loader)
            if num_epochs is None or num_epochs > max_step_epochs:
                num_epochs = max_step_epochs

        max_epoch_steps = len(self.data_loader) * num_epochs
        num_steps = self.max_steps
        if num_steps is None or num_steps > max_epoch_steps:
            num_steps = max_epoch_steps
        self.max_steps = num_steps
        self.epochs = num_epochs
        start_step = self.step_num
        self.first_step = start_step

        self.print(f"Training ESRGAN model:")
        self.print(f" - Training folder: {self.training_folder}")
        self.print(f" - Batch size: {self.batch_size}")
        self.print(f" - Learning rate: {self.learning_rate}")
        self.print(f" - Epochs: {num_epochs}")
        self.print(f" - Max steps: {self.max_steps}")

        # load the model
        self.load_model()

        params = self.model.parameters()

        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
            self.setup_vgg19()
            self.vgg_19.requires_grad_(False)
            self.vgg_19.eval()
        if self.use_critic:
            self.critic.setup()

        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                  optimizer_params=self.optimizer_params)

        # set up the scheduler
        # todo allow other schedulers
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            total_iters=num_steps,
            factor=1,
            verbose=False
        )

        # set up the tqdm progress bar
        self.progress_bar = tqdm(
            total=num_steps,
            desc='Training ESRGAN',
            leave=True
        )

        blank_losses = OrderedDict({
            "total": [],
            "style": [],
            "content": [],
            "mse": [],
            "kl": [],
            "tv": [],
            "ptn": [],
            "crD": [],
            "crG": [],
        })
        epoch_losses = copy.deepcopy(blank_losses)
        log_losses = copy.deepcopy(blank_losses)
        print("Generating baseline samples")
        self.sample(step=0)
        # range starts at self.epoch_num and goes to self.epochs
        for epoch in range(self.epoch_num, self.epochs, 1):
            if self.step_num >= self.max_steps:
                break
            for targets, inputs in self.data_loader:
                if self.step_num >= self.max_steps:
                    break
                with torch.no_grad():
                    targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)
                    inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1)

                pred = self.model(inputs)

                pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1)
                targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1)

                # run through VGG19
                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
                    stacked = torch.cat([pred, targets], dim=0)
                    # stacked = (stacked / 2 + 0.5).clamp(0, 1)
                    stacked = stacked.clamp(0, 1)
                    self.vgg_19(stacked)

                if self.use_critic:
                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
                else:
                    critic_d_loss = 0.0

                style_loss = self.get_style_loss() * self.style_weight
                content_loss = self.get_content_loss() * self.content_weight
                mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight
                tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight
                pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight
                if self.use_critic:
                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
                else:
                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss

                # backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # update the progress bar
                loss_value = loss.item()
                # format with an exponent, like 3.54e-4
                loss_string = f"loss: {loss_value:.2e}"
                if self.content_weight > 0:
                    loss_string += f" cnt: {content_loss.item():.2e}"
                if self.style_weight > 0:
                    loss_string += f" sty: {style_loss.item():.2e}"
                if self.mse_weight > 0:
                    loss_string += f" mse: {mse_loss.item():.2e}"
                if self.tv_weight > 0:
                    loss_string += f" tv: {tv_loss.item():.2e}"
                if self.pattern_weight > 0:
                    loss_string += f" ptn: {pattern_loss.item():.2e}"
                if self.use_critic and self.critic_weight > 0:
                    loss_string += f" crG: {critic_gen_loss.item():.2e}"
                if self.use_critic:
                    loss_string += f" crD: {critic_d_loss:.2e}"

                if self.optimizer_type.startswith('dadaptation') or self.optimizer_type.startswith('prodigy'):
                    learning_rate = (
                        optimizer.param_groups[0]["d"] *
                        optimizer.param_groups[0]["lr"]
                    )
                else:
                    learning_rate = optimizer.param_groups[0]['lr']

                lr_critic_string = ''
                if self.use_critic:
                    lr_critic = self.critic.get_lr()
                    lr_critic_string = f" lrC: {lr_critic:.1e}"

                self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
                self.progress_bar.set_description(f"E: {epoch}")
                self.progress_bar.update(1)

                epoch_losses["total"].append(loss_value)
                epoch_losses["style"].append(style_loss.item())
                epoch_losses["content"].append(content_loss.item())
                epoch_losses["mse"].append(mse_loss.item())
                epoch_losses["tv"].append(tv_loss.item())
                epoch_losses["ptn"].append(pattern_loss.item())
                epoch_losses["crG"].append(critic_gen_loss.item())
                epoch_losses["crD"].append(critic_d_loss)

                log_losses["total"].append(loss_value)
                log_losses["style"].append(style_loss.item())
                log_losses["content"].append(content_loss.item())
                log_losses["mse"].append(mse_loss.item())
                log_losses["tv"].append(tv_loss.item())
                log_losses["ptn"].append(pattern_loss.item())
                log_losses["crG"].append(critic_gen_loss.item())
                log_losses["crD"].append(critic_d_loss)

                # don't do this on the first step
                if self.step_num != start_step:
                    if self.sample_every and self.step_num % self.sample_every == 0:
                        # print above the progress bar
                        self.print(f"Sampling at step {self.step_num}")
                        self.sample(self.step_num)

                    if self.save_every and self.step_num % self.save_every == 0:
                        # print above the progress bar
                        self.print(f"Saving at step {self.step_num}")
                        self.save(self.step_num)

                    if self.log_every and self.step_num % self.log_every == 0:
                        # log to tensorboard
                        if self.writer is not None:
                            # get the average loss
                            for key in log_losses:
                                log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
                                # if log_losses[key] > 0:
                                self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
                        # reset the log losses
                        log_losses = copy.deepcopy(blank_losses)

                self.step_num += 1
            # end of epoch
            if self.writer is not None:
                eps = 1e-6
                # get the average loss over the epoch
                for key in epoch_losses:
                    epoch_losses[key] = sum(epoch_losses[key]) / (len(epoch_losses[key]) + eps)
                    if epoch_losses[key] > 0:
                        self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
            # reset the epoch losses
            epoch_losses = copy.deepcopy(blank_losses)

        self.save()
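A hypothetical config sketch (values and paths are placeholders, not from this commit); the keys mirror the get_conf(...) calls in TrainESRGANProcess.__init__ above:

from collections import OrderedDict

config = OrderedDict({
    'device': 'cuda:0',
    'datasets': [{'path': '/path/to/hires/images'}],  # placeholder path
    'batch_size': 4,
    'resolution': 256,  # HR target size; LR inputs are resolution // zoom
    'zoom': 4,          # must be divisible by 2
    'learning_rate': 1e-4,
    'max_steps': 10000,
    'save_every': 1000,
    'sample_every': 500,
    'sample_sources': ['/path/to/sample.jpg'],  # placeholder path
    'use_critic': True,
    'critic': {'learning_rate': 1e-5},  # forwarded to Critic(**...)
})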
jobs/process/TrainVAEProcess.py
@@ -24,6 +24,7 @@ from diffusers import AutoencoderKL
from tqdm import tqdm
import time
import numpy as np
from .models.vgg19_critic import Critic

IMAGE_TRANSFORMS = transforms.Compose(
    [
@@ -37,145 +38,6 @@ def unnormalize(tensor):
    return (tensor / 2 + 0.5).clamp(0, 1)


class Critic:
    process: 'TrainVAEProcess'

    def __init__(
            self,
            learning_rate=1e-5,
            device='cpu',
            optimizer='adam',
            num_critic_per_gen=1,
            dtype='float32',
            lambda_gp=10,
            start_step=0,
            warmup_steps=1000,
            process=None,
            optimizer_params=None,
    ):
        self.learning_rate = learning_rate
        self.device = device
        self.optimizer_type = optimizer
        self.num_critic_per_gen = num_critic_per_gen
        self.dtype = dtype
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.process = process
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.warmup_steps = warmup_steps
        self.start_step = start_step
        self.lambda_gp = lambda_gp

        if optimizer_params is None:
            optimizer_params = {}
        self.optimizer_params = optimizer_params
        self.print = self.process.print
        print(f" Critic config: {self.__dict__}")

    def setup(self):
        from .models.vgg19_critic import Vgg19Critic
        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
        self.load_weights()
        self.model.train()
        self.model.requires_grad_(True)
        params = self.model.parameters()
        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                       optimizer_params=self.optimizer_params)
        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
            self.optimizer,
            total_iters=self.process.max_steps * self.num_critic_per_gen,
            factor=1,
            verbose=False
        )

    def load_weights(self):
        path_to_load = None
        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
        if files and len(files) > 0:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
        else:
            self.print(f" - No checkpoint found, starting from scratch")
        if path_to_load:
            self.model.load_state_dict(load_file(path_to_load))

    def save(self, step=None):
        self.process.update_training_metadata()
        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
        step_num = ''
        if step is not None:
            # zeropad 9 digits
            step_num = f"_{str(step).zfill(9)}"
        save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
        save_file(self.model.state_dict(), save_path, save_meta)
        self.print(f"Saved critic to {save_path}")

    def get_critic_loss(self, vgg_output):
        if self.start_step > self.process.step_num:
            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)

        warmup_scaler = 1.0
        # we need a warmup when we come on of 1000 steps
        # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps
        if self.process.step_num < self.start_step + self.warmup_steps:
            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
        # set model to not train for generator loss
        self.model.eval()
        self.model.requires_grad_(False)
        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)

        # run model
        stacked_output = self.model(vgg_pred)

        return (-torch.mean(stacked_output)) * warmup_scaler

    def step(self, vgg_output):

        # train critic here
        self.model.train()
        self.model.requires_grad_(True)

        critic_losses = []
        for i in range(self.num_critic_per_gen):
            inputs = vgg_output.detach()
            inputs = inputs.to(self.device, dtype=self.torch_dtype)
            self.optimizer.zero_grad()

            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

            stacked_output = self.model(inputs)
            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

            # Compute gradient penalty
            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)

            # Compute WGAN-GP critic loss
            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
            critic_loss.backward()
            self.optimizer.zero_grad()
            self.optimizer.step()
            self.scheduler.step()
            critic_losses.append(critic_loss.item())

        # avg loss
        loss = np.mean(critic_losses)
        return loss

    def get_lr(self):
        if self.optimizer_type.startswith('dadaptation'):
            learning_rate = (
                self.optimizer.param_groups[0]["d"] *
                self.optimizer.param_groups[0]["lr"]
            )
        else:
            learning_rate = self.optimizer.param_groups[0]['lr']

        return learning_rate


class TrainVAEProcess(BaseTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
jobs/process/__init__.py
@@ -12,3 +12,4 @@ from .TrainSDRescaleProcess import TrainSDRescaleProcess
from .ModRescaleLoraProcess import ModRescaleLoraProcess
from .GenerateProcess import GenerateProcess
from .BaseExtensionProcess import BaseExtensionProcess
from .TrainESRGANProcess import TrainESRGANProcess
jobs/process/models/vgg19_critic.py
@@ -1,5 +1,17 @@
import glob
import os

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file

from toolkit.losses import get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.train_tools import get_torch_dtype

from typing import TYPE_CHECKING, Union


class MeanReduce(nn.Module):
@@ -36,3 +48,147 @@ class Vgg19Critic(nn.Module):

    def forward(self, inputs):
        return self.main(inputs)


if TYPE_CHECKING:
    from jobs.process.TrainVAEProcess import TrainVAEProcess
    from jobs.process.TrainESRGANProcess import TrainESRGANProcess


class Critic:
    process: Union['TrainVAEProcess', 'TrainESRGANProcess']

    def __init__(
            self,
            learning_rate=1e-5,
            device='cpu',
            optimizer='adam',
            num_critic_per_gen=1,
            dtype='float32',
            lambda_gp=10,
            start_step=0,
            warmup_steps=1000,
            process=None,
            optimizer_params=None,
    ):
        self.learning_rate = learning_rate
        self.device = device
        self.optimizer_type = optimizer
        self.num_critic_per_gen = num_critic_per_gen
        self.dtype = dtype
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.process = process
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.warmup_steps = warmup_steps
        self.start_step = start_step
        self.lambda_gp = lambda_gp

        if optimizer_params is None:
            optimizer_params = {}
        self.optimizer_params = optimizer_params
        self.print = self.process.print
        print(f" Critic config: {self.__dict__}")

    def setup(self):
        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
        self.load_weights()
        self.model.train()
        self.model.requires_grad_(True)
        params = self.model.parameters()
        self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                       optimizer_params=self.optimizer_params)
        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
            self.optimizer,
            total_iters=self.process.max_steps * self.num_critic_per_gen,
            factor=1,
            verbose=False
        )

    def load_weights(self):
        path_to_load = None
        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
        if files and len(files) > 0:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
        else:
            self.print(f" - No checkpoint found, starting from scratch")
        if path_to_load:
            self.model.load_state_dict(load_file(path_to_load))

    def save(self, step=None):
        self.process.update_training_metadata()
        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
        step_num = ''
        if step is not None:
            # zero-pad to 9 digits
            step_num = f"_{str(step).zfill(9)}"
        save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors")
        save_file(self.model.state_dict(), save_path, save_meta)
        self.print(f"Saved critic to {save_path}")

    def get_critic_loss(self, vgg_output):
        if self.start_step > self.process.step_num:
            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)

        warmup_scaler = 1.0
        # warm the critic term in over warmup_steps once it comes online:
        # scale the loss by 0.0 at self.start_step and 1.0 at self.start_step + warmup_steps
        if self.process.step_num < self.start_step + self.warmup_steps:
            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
        # set model to not train for generator loss
        self.model.eval()
        self.model.requires_grad_(False)
        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)

        # run model
        stacked_output = self.model(vgg_pred)

        return (-torch.mean(stacked_output)) * warmup_scaler

    def step(self, vgg_output):

        # train critic here
        self.model.train()
        self.model.requires_grad_(True)

        critic_losses = []
        for i in range(self.num_critic_per_gen):
            inputs = vgg_output.detach()
            inputs = inputs.to(self.device, dtype=self.torch_dtype)
            self.optimizer.zero_grad()

            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

            stacked_output = self.model(inputs)
            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

            # compute the gradient penalty
            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)

            # compute the WGAN-GP critic loss, then step on the fresh gradients
            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
            critic_loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            critic_losses.append(critic_loss.item())

        # avg loss
        loss = np.mean(critic_losses)
        return loss

    def get_lr(self):
        if self.optimizer_type.startswith('dadaptation'):
            learning_rate = (
                self.optimizer.param_groups[0]["d"] *
                self.optimizer.param_groups[0]["lr"]
            )
        else:
            learning_rate = self.optimizer.param_groups[0]['lr']

        return learning_rate
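For reference, the critic update in Critic.step above is the standard WGAN-GP objective, here applied to VGG19 pool_4 features rather than raw pixels. Written out with the usual symbols (these are not names from the code):

\mathcal{L}_D = -\big(\mathbb{E}[D(x_{\mathrm{real}})] - \mathbb{E}[D(x_{\mathrm{fake}})]\big) + \lambda_{\mathrm{gp}}\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]

\mathcal{L}_G = -\, s_{\mathrm{warmup}} \cdot \mathbb{E}[D(x_{\mathrm{fake}})]

where \hat{x} is the interpolate formed inside get_gradient_penalty and s_{\mathrm{warmup}} ramps linearly from 0 to 1 over warmup_steps after start_step, matching get_critic_loss.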
toolkit/data_loader.py
@@ -1,10 +1,14 @@
import os
import random

import cv2
import numpy as np
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset
from tqdm import tqdm
import albumentations as A


class ImageDataset(Dataset):
@@ -38,7 +42,7 @@ class ImageDataset(Dataset):

        self.transform = transforms.Compose([
            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
+            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])

    def get_config(self, key, default=None, required=False):
@@ -65,7 +69,7 @@ class ImageDataset(Dataset):
        if self.random_scale and min_img_size > self.resolution:
            if min_img_size < self.resolution:
                print(
-                    f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
+                    f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
                scale_size = self.resolution
            else:
                scale_size = random.randint(self.resolution, int(min_img_size))
@@ -78,3 +82,61 @@ class ImageDataset(Dataset):
        img = self.transform(img)

        return img


class Augments:
    def __init__(self, **kwargs):
        self.method_name = kwargs.get('method', None)
        self.params = kwargs.get('params', {})

        # convert string kwargs to cv2 enums
        for key, value in self.params.items():
            if isinstance(value, str):
                # split the string
                split_string = value.split('.')
                if len(split_string) == 2 and split_string[0] == 'cv2':
                    if hasattr(cv2, split_string[1]):
                        self.params[key] = getattr(cv2, split_string[1].upper())
                    else:
                        raise ValueError(f"invalid cv2 enum: {split_string[1]}")


class AugmentedImageDataset(ImageDataset):
    def __init__(self, config):
        super().__init__(config)
        self.augmentations = self.get_config('augmentations', [])
        self.augmentations = [Augments(**aug) for aug in self.augmentations]

        augmentation_list = []
        for aug in self.augmentations:
            # make sure the method name is valid
            assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}"
            # get the method
            method = getattr(A, aug.method_name)
            # add the method to the list
            augmentation_list.append(method(**aug.params))

        self.aug_transform = A.Compose(augmentation_list)
        self.original_transform = self.transform
        # replace the transform so we get the raw PIL image
        self.transform = transforms.Compose([])

    def __getitem__(self, index):
        # get the original image
        # image is a PIL image; convert to BGR
        pil_image = super().__getitem__(index)
        open_cv_image = np.array(pil_image)
        # convert RGB to BGR
        open_cv_image = open_cv_image[:, :, ::-1].copy()

        # apply augmentations
        augmented = self.aug_transform(image=open_cv_image)["image"]

        # convert back to RGB
        augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB)

        # convert to a PIL image
        augmented = Image.fromarray(augmented)

        # return both images as 0-1 tensors
        return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented)
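A usage sketch under assumptions (the path is hypothetical): AugmentedImageDataset returns a (clean target, augmented input) pair of 0-1 tensors, which is why the ESRGAN training loop above unpacks for targets, inputs in self.data_loader.

ds = AugmentedImageDataset({
    'path': '/path/to/images',  # hypothetical directory of HR images
    'resolution': 256,
    'augmentations': [
        # the same resize-down step TrainESRGANProcess prepends for zoom=4
        {'method': 'Resize',
         'params': {'width': 64, 'height': 64, 'interpolation': 'cv2.INTER_AREA'}},
    ],
})
target, degraded = ds[0]  # target: 3x256x256, degraded: 3x64x64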
51 toolkit/esrgan_utils.py Normal file
@@ -0,0 +1,51 @@
to_basicsr_dict = {
    'model.0.weight': 'conv_first.weight',
    'model.0.bias': 'conv_first.bias',
    'model.1.sub.23.weight': 'conv_body.weight',
    'model.1.sub.23.bias': 'conv_body.bias',
    'model.3.weight': 'conv_up1.weight',
    'model.3.bias': 'conv_up1.bias',
    'model.6.weight': 'conv_up2.weight',
    'model.6.bias': 'conv_up2.bias',
    'model.8.weight': 'conv_hr.weight',
    'model.8.bias': 'conv_hr.bias',
    'model.10.bias': 'conv_last.bias',
    'model.10.weight': 'conv_last.weight',
    # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight'
}


def convert_state_dict_to_basicsr(state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        if k in to_basicsr_dict:
            new_state_dict[to_basicsr_dict[k]] = v
        elif k.startswith('model.1.sub.'):
            bsr_name = k.replace('model.1.sub.', 'body.').lower()
            bsr_name = bsr_name.replace('.0.weight', '.weight')
            bsr_name = bsr_name.replace('.0.bias', '.bias')
            new_state_dict[bsr_name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict


# just matching a commonly used format
def convert_basicsr_state_dict_to_save_format(state_dict):
    new_state_dict = {}
    to_basicsr_dict_values = list(to_basicsr_dict.values())
    for k, v in state_dict.items():
        if k in to_basicsr_dict_values:
            for key, value in to_basicsr_dict.items():
                if value == k:
                    new_state_dict[key] = v
        elif k.startswith('body.'):
            bsr_name = k.replace('body.', 'model.1.sub.').lower()
            bsr_name = bsr_name.replace('rdb', 'RDB')
            bsr_name = bsr_name.replace('.weight', '.0.weight')
            bsr_name = bsr_name.replace('.bias', '.0.bias')
            new_state_dict[bsr_name] = v
        else:
            new_state_dict[k] = v
    return new_state_dict
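A quick round-trip sketch (the checkpoint path is hypothetical): for old-arch ESRGAN key sets, the two helpers above should invert each other.

from safetensors.torch import load_file
from toolkit.esrgan_utils import (
    convert_state_dict_to_basicsr,
    convert_basicsr_state_dict_to_save_format,
)

sd = load_file('esrgan_x4.safetensors')    # old-arch keys, e.g. model.0.weight
bsr = convert_state_dict_to_basicsr(sd)    # basicsr keys, e.g. conv_first.weight
back = convert_basicsr_state_dict_to_save_format(bsr)
assert set(back.keys()) == set(sd.keys())  # should hold for standard key sets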
296 toolkit/models/RRDB.py Normal file
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import functools
import math
import re
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import block as B


# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
# Which enhanced stuff that was already here
class RRDBNet(nn.Module):
    def __init__(
        self,
        state_dict,
        norm=None,
        act: str = "leakyrelu",
        upsampler: str = "upconv",
        mode: B.ConvMode = "CNA",
    ) -> None:
        """
        ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
        By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
        and Chen Change Loy.
        This is old-arch Residual in Residual Dense Block Network and is not
        the newest revision that's available at github.com/xinntao/ESRGAN.
        This is on purpose, the newest Network has severely limited the
        potential use of the Network with no benefits.
        This network supports model files from both new and old-arch.
        Args:
            norm: Normalization layer
            act: Activation layer
            upsampler: Upsample layer. upconv, pixel_shuffle
            mode: Convolution mode
        """
        super(RRDBNet, self).__init__()
        self.model_arch = "ESRGAN"
        self.sub_type = "SR"

        self.state = state_dict
        self.norm = norm
        self.act = act
        self.upsampler = upsampler
        self.mode = mode

        self.state_map = {
            # currently supports old, new, and newer RRDBNet arch models
            # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
            "model.0.weight": ("conv_first.weight",),
            "model.0.bias": ("conv_first.bias",),
            "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
            "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
            r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
                r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
                r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
            ),
        }
        if "params_ema" in self.state:
            self.state = self.state["params_ema"]
            # self.model_arch = "RealESRGAN"
        self.num_blocks = self.get_num_blocks()
        self.plus = any("conv1x1" in k for k in self.state.keys())
        if self.plus:
            self.model_arch = "ESRGAN+"

        self.state = self.new_to_old_arch(self.state)

        self.key_arr = list(self.state.keys())

        self.in_nc: int = self.state[self.key_arr[0]].shape[1]
        self.out_nc: int = self.state[self.key_arr[-1]].shape[0]

        self.scale: int = self.get_scale()
        self.num_filters: int = self.state[self.key_arr[0]].shape[0]

        c2x2 = False
        if self.state["model.0.weight"].shape[-2] == 2:
            c2x2 = True
            self.scale = round(math.sqrt(self.scale / 4))
            self.model_arch = "ESRGAN-2c2"

        self.supports_fp16 = True
        self.supports_bfp16 = True
        self.min_size_restriction = None

        # Detect if pixelunshuffle was used (Real-ESRGAN)
        if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
            self.in_nc / 4,
            self.in_nc / 16,
        ):
            self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
        else:
            self.shuffle_factor = None

        upsample_block = {
            "upconv": B.upconv_block,
            "pixel_shuffle": B.pixelshuffle_block,
        }.get(self.upsampler)
        if upsample_block is None:
            raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")

        if self.scale == 3:
            upsample_blocks = upsample_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                upscale_factor=3,
                act_type=self.act,
                c2x2=c2x2,
            )
        else:
            upsample_blocks = [
                upsample_block(
                    in_nc=self.num_filters,
                    out_nc=self.num_filters,
                    act_type=self.act,
                    c2x2=c2x2,
                )
                for _ in range(int(math.log(self.scale, 2)))
            ]

        self.model = B.sequential(
            # fea conv
            B.conv_block(
                in_nc=self.in_nc,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
            B.ShortcutBlock(
                B.sequential(
                    # rrdb blocks
                    *[
                        B.RRDB(
                            nf=self.num_filters,
                            kernel_size=3,
                            gc=32,
                            stride=1,
                            bias=True,
                            pad_type="zero",
                            norm_type=self.norm,
                            act_type=self.act,
                            mode="CNA",
                            plus=self.plus,
                            c2x2=c2x2,
                        )
                        for _ in range(self.num_blocks)
                    ],
                    # lr conv
                    B.conv_block(
                        in_nc=self.num_filters,
                        out_nc=self.num_filters,
                        kernel_size=3,
                        norm_type=self.norm,
                        act_type=None,
                        mode=self.mode,
                        c2x2=c2x2,
                    ),
                )
            ),
            *upsample_blocks,
            # hr_conv0
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.num_filters,
                kernel_size=3,
                norm_type=None,
                act_type=self.act,
                c2x2=c2x2,
            ),
            # hr_conv1
            B.conv_block(
                in_nc=self.num_filters,
                out_nc=self.out_nc,
                kernel_size=3,
                norm_type=None,
                act_type=None,
                c2x2=c2x2,
            ),
        )

        # Adjust these properties for calculations outside of the model
        if self.shuffle_factor:
            self.in_nc //= self.shuffle_factor ** 2
            self.scale //= self.shuffle_factor

        self.load_state_dict(self.state, strict=False)

    def new_to_old_arch(self, state):
        """Convert a new-arch model state dictionary to an old-arch dictionary."""
        if "params_ema" in state:
            state = state["params_ema"]

        if "conv_first.weight" not in state:
            # model is already old arch, this is a loose check, but should be sufficient
            return state

        # add nb to state keys
        for kind in ("weight", "bias"):
            self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
                f"model.1.sub./NB/.{kind}"
            ]
            del self.state_map[f"model.1.sub./NB/.{kind}"]

        old_state = OrderedDict()
        for old_key, new_keys in self.state_map.items():
            for new_key in new_keys:
                if r"\1" in old_key:
                    for k, v in state.items():
                        sub = re.sub(new_key, old_key, k)
                        if sub != k:
                            old_state[sub] = v
                else:
                    if new_key in state:
                        old_state[old_key] = state[new_key]

        # upconv layers
        max_upconv = 0
        for key in state.keys():
            match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
            if match is not None:
                _, key_num, key_type = match.groups()
                old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
                max_upconv = max(max_upconv, int(key_num) * 3)

        # final layers
        for key in state.keys():
            if key in ("HRconv.weight", "conv_hr.weight"):
                old_state[f"model.{max_upconv + 2}.weight"] = state[key]
            elif key in ("HRconv.bias", "conv_hr.bias"):
                old_state[f"model.{max_upconv + 2}.bias"] = state[key]
            elif key in ("conv_last.weight",):
                old_state[f"model.{max_upconv + 4}.weight"] = state[key]
            elif key in ("conv_last.bias",):
                old_state[f"model.{max_upconv + 4}.bias"] = state[key]

        # Sort by first numeric value of each layer
        def compare(item1, item2):
            parts1 = item1.split(".")
            parts2 = item2.split(".")
            int1 = int(parts1[1])
            int2 = int(parts2[1])
            return int1 - int2

        sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))

        # Rebuild the output dict in the right order
        out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)

        return out_dict

    def get_scale(self, min_part: int = 6) -> int:
        n = 0
        for part in list(self.state):
            parts = part.split(".")[1:]
            if len(parts) == 2:
                part_num = int(parts[0])
                if part_num > min_part and parts[1] == "weight":
                    n += 1
        return 2 ** n

    def get_num_blocks(self) -> int:
        nbs = []
        state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
            r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
        )
        for state_key in state_keys:
            for k in self.state:
                m = re.search(state_key, k)
                if m:
                    nbs.append(int(m.group(1)))
            if nbs:
                break
        return max(*nbs) + 1

    def forward(self, x):
        if self.shuffle_factor:
            _, _, h, w = x.size()
            mod_pad_h = (
                self.shuffle_factor - h % self.shuffle_factor
            ) % self.shuffle_factor
            mod_pad_w = (
                self.shuffle_factor - w % self.shuffle_factor
            ) % self.shuffle_factor
            x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
            x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
            x = self.model(x)
            return x[:, :, : h * self.scale, : w * self.scale]
        return self.model(x)
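A load-and-run sketch mirroring load_model() in TrainESRGANProcess above (the checkpoint path is hypothetical): the constructor takes a raw state dict and infers scale, filter count, and block count from the keys.

import torch
from safetensors.torch import load_file
from toolkit.models.RRDB import RRDBNet

sd = load_file('esrgan_x4.safetensors')  # hypothetical 4x ESRGAN checkpoint
model = RRDBNet(sd).eval()
lr = torch.rand(1, model.in_nc, 64, 64)  # a random low-res batch
with torch.no_grad():
    sr = model(lr)                       # -> (1, out_nc, 64 * scale, 64 * scale)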
549 toolkit/models/block.py Normal file
@@ -0,0 +1,549 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from collections import OrderedDict

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn


####################
# Basic blocks
####################


def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    act_type = act_type.lower()
    if act_type == "relu":
        layer = nn.ReLU(inplace)
    elif act_type == "leakyrelu":
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == "prelu":
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            "activation layer [{:s}] is not found".format(act_type)
        )
    return layer


def norm(norm_type: str, nc: int):
    # helper selecting normalization layer
    norm_type = norm_type.lower()
    if norm_type == "batch":
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == "instance":
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError(
            "normalization layer [{:s}] is not found".format(norm_type)
        )
    return layer


def pad(pad_type: str, padding):
    # helper selecting padding layer
    # if padding is 'zero', do by conv layers
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == "reflect":
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == "replicate":
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError(
            "padding layer [{:s}] is not implemented".format(pad_type)
        )
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = "Identity .. \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlockSPSR(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlockSPSR, self).__init__()
        self.sub = submodule

    def forward(self, x):
        return x, self.sub

    def __repr__(self):
        tmpstr = "Identity + \n|"
        modstr = self.sub.__repr__().replace("\n", "\n|")
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


ConvMode = Literal["CNA", "NAC", "CNAC"]


# 2x2x2 Conv Block
def conv_block_2c2(
    in_nc,
    out_nc,
    act_type="relu",
):
    return sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
        nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
        act(act_type) if act_type else None,
    )


def conv_block(
    in_nc: int,
    out_nc: int,
    kernel_size,
    stride=1,
    dilation=1,
    groups=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type: str | None = "relu",
    mode: ConvMode = "CNA",
    c2x2=False,
):
    """
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
          NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    """

    if c2x2:
        return conv_block_2c2(in_nc, out_nc, act_type=act_type)

    assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
    padding = padding if pad_type == "zero" else 0

    c = nn.Conv2d(
        in_nc,
        out_nc,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        groups=groups,
    )
    a = act(act_type) if act_type else None
    if mode in ("CNA", "CNAC"):
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == "NAC":
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)
    else:
        assert False, f"Invalid conv mode {mode}"

####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    """
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    """

    def __init__(
        self,
        in_nc,
        mid_nc,
        out_nc,
        kernel_size=3,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_type="zero",
        norm_type=None,
        act_type="relu",
        mode: ConvMode = "CNA",
        res_scale=1,
    ):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(
            in_nc,
            mid_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        if mode == "CNA":
            act_type = None
        if mode == "CNAC":  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(
            mid_nc,
            out_nc,
            kernel_size,
            stride,
            dilation,
            groups,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
        )
        # if in_nc != out_nc:
        #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
        #         None, None)
        #     print('Need a projecter in ResNetBlock.')
        # else:
        #     self.project = lambda x: x
        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(
        self,
        nf,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        _convtype="Conv2D",
        _spectral_norm=False,
        plus=False,
        c2x2=False,
    ):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB2 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )
        self.RDB3 = ResidualDenseBlock_5C(
            nf,
            kernel_size,
            gc,
            stride,
            bias,
            pad_type,
            norm_type,
            act_type,
            mode,
            plus=plus,
            c2x2=c2x2,
        )

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
          {Rakotonirina} and A. {Rasoanaivo}

    Args:
        nf (int): Channel number of intermediate features (num_feat).
        gc (int): Channels for each growth (num_grow_ch: growth channel,
            i.e. intermediate channels).
        convtype (str): the type of convolution to use. Default: 'Conv2D'
        gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
            trainable parameters)
        plus (bool): enable the additional residual paths from ESRGAN+
            (adds trainable parameters)
    """

    def __init__(
        self,
        nf=64,
        kernel_size=3,
        gc=32,
        stride=1,
        bias: bool = True,
        pad_type="zero",
        norm_type=None,
        act_type="leakyrelu",
        mode: ConvMode = "CNA",
        plus=False,
        c2x2=False,
    ):
        super(ResidualDenseBlock_5C, self).__init__()

        ## +
        self.conv1x1 = conv1x1(nf, gc) if plus else None
        ## +

        self.conv1 = conv_block(
            nf,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv2 = conv_block(
            nf + gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv3 = conv_block(
            nf + 2 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        self.conv4 = conv_block(
            nf + 3 * gc,
            gc,
            kernel_size,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=act_type,
            mode=mode,
            c2x2=c2x2,
        )
        if mode == "CNA":
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(
            nf + 4 * gc,
            nf,
            3,
            stride,
            bias=bias,
            pad_type=pad_type,
            norm_type=norm_type,
            act_type=last_act,
            mode=mode,
            c2x2=c2x2,
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        if self.conv1x1:
            # pylint: disable=not-callable
            x2 = x2 + self.conv1x1(x)  # +
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        if self.conv1x1:
            x4 = x4 + x2  # +
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

####################
# Upsampler
####################


def pixelshuffle_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
):
    """
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    """
    conv = conv_block(
        in_nc,
        out_nc * (upscale_factor ** 2),
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=None,
        act_type=None,
    )
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_block(
    in_nc: int,
    out_nc: int,
    upscale_factor=2,
    kernel_size=3,
    stride=1,
    bias=True,
    pad_type="zero",
    norm_type: str | None = None,
    act_type="relu",
    mode="nearest",
    c2x2=False,
):
    # up conv
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(
        in_nc,
        out_nc,
        kernel_size,
        stride,
        bias=bias,
        pad_type=pad_type,
        norm_type=norm_type,
        act_type=act_type,
        c2x2=c2x2,
    )
    return sequential(upsample, conv)
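A small sketch contrasting the two upsamplers defined above (assuming toolkit.models is importable as a package): upconv_block resizes and then convolves, which avoids the checkerboard artifacts described in the linked distill.pub article, while pixelshuffle_block convolves to upscale_factor**2 times the channels and rearranges them.

import torch
from toolkit.models import block as B

x = torch.rand(1, 64, 32, 32)
up1 = B.upconv_block(64, 64)        # Upsample(nearest) -> Conv -> ReLU
up2 = B.pixelshuffle_block(64, 64)  # Conv(64 -> 256) -> PixelShuffle(2) -> ReLU
print(up1(x).shape, up2(x).shape)   # both: torch.Size([1, 64, 64, 64])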