diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index a1939890..fc594edf 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -20,6 +20,7 @@ process_dict = { 'slider_old': 'TrainSliderProcessOld', 'lora_hack': 'TrainLoRAHack', 'rescale_sd': 'TrainSDRescaleProcess', + 'esrgan': 'TrainESRGANProcess', } diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py new file mode 100644 index 00000000..7cc08b41 --- /dev/null +++ b/jobs/process/TrainESRGANProcess.py @@ -0,0 +1,575 @@ +import copy +import glob +import os +import time +from collections import OrderedDict + +from PIL import Image +from PIL.ImageOps import exif_transpose +# from basicsr.archs.rrdbnet_arch import RRDBNet +from toolkit.models.RRDB import RRDBNet as ESRGAN +from safetensors.torch import save_file, load_file +from torch.utils.data import DataLoader, ConcatDataset +import torch +from torch import nn +from torchvision.transforms import transforms + +from jobs.process import BaseTrainProcess +from toolkit.data_loader import AugmentedImageDataset +from toolkit.esrgan_utils import convert_state_dict_to_basicsr, convert_basicsr_state_dict_to_save_format +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.style import get_style_model_and_losses +from toolkit.train_tools import get_torch_dtype +from diffusers import AutoencoderKL +from tqdm import tqdm +import time +import numpy as np +from .models.vgg19_critic import Critic + +IMAGE_TRANSFORMS = transforms.Compose( + [ + transforms.ToTensor(), + # transforms.Normalize([0.5], [0.5]), + ] +) + + +class TrainESRGANProcess(BaseTrainProcess): + def __init__(self, process_id: int, job, config: OrderedDict): + super().__init__(process_id, job, config) + self.data_loader = None + self.model = None + self.device = self.get_conf('device', self.job.device) + self.pretrained_path = self.get_conf('pretrained_path', 'None') + self.datasets_objects = self.get_conf('datasets', required=True) + self.batch_size = self.get_conf('batch_size', 1, as_type=int) + self.resolution = self.get_conf('resolution', 256, as_type=int) + self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float) + self.sample_every = self.get_conf('sample_every', None) + self.optimizer_type = self.get_conf('optimizer', 'adam') + self.epochs = self.get_conf('epochs', None, as_type=int) + self.max_steps = self.get_conf('max_steps', None, as_type=int) + self.save_every = self.get_conf('save_every', None) + self.upscale_sample = self.get_conf('upscale_sample', 4) + self.dtype = self.get_conf('dtype', 'float32') + self.sample_sources = self.get_conf('sample_sources', None) + self.log_every = self.get_conf('log_every', 100, as_type=int) + self.style_weight = self.get_conf('style_weight', 0, as_type=float) + self.content_weight = self.get_conf('content_weight', 0, as_type=float) + self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) + self.zoom = self.get_conf('zoom', 4, as_type=int) + self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) + self.optimizer_params = self.get_conf('optimizer_params', {}) + self.augmentations = self.get_conf('augmentations', {}) + self.torch_dtype = get_torch_dtype(self.dtype) + if self.torch_dtype == torch.bfloat16: + self.esrgan_dtype = 
torch.float16 + else: + self.esrgan_dtype = torch.float32 + self.vgg_19 = None + self.style_weight_scalers = [] + self.content_weight_scalers = [] + + # throw error if zoom if not divisible by 2 + if self.zoom % 2 != 0: + raise ValueError('zoom must be divisible by 2') + + self.step_num = 0 + self.epoch_num = 0 + + self.use_critic = self.get_conf('use_critic', False, as_type=bool) + self.critic = None + + if self.use_critic: + self.critic = Critic( + device=self.device, + dtype=self.dtype, + process=self, + **self.get_conf('critic', {}) # pass any other params + ) + + if self.sample_every is not None and self.sample_sources is None: + raise ValueError('sample_every is specified but sample_sources is not') + + if self.epochs is None and self.max_steps is None: + raise ValueError('epochs or max_steps must be specified') + + self.data_loaders = [] + # check datasets + assert isinstance(self.datasets_objects, list) + for dataset in self.datasets_objects: + if 'path' not in dataset: + raise ValueError('dataset must have a path') + # check if is dir + if not os.path.isdir(dataset['path']): + raise ValueError(f"dataset path does is not a directory: {dataset['path']}") + + # make training folder + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + self._pattern_loss = None + + # build augmentation transforms + aug_transforms = [] + + def update_training_metadata(self): + self.add_meta(OrderedDict({"training_info": self.get_training_info()})) + + def get_training_info(self): + info = OrderedDict({ + 'step': self.step_num, + 'epoch': self.epoch_num, + }) + return info + + def load_datasets(self): + if self.data_loader is None: + print(f"Loading datasets") + datasets = [] + for dataset in self.datasets_objects: + print(f" - Dataset: {dataset['path']}") + ds = copy.copy(dataset) + ds['resolution'] = self.resolution + + if 'augmentations' not in ds: + ds['augmentations'] = self.augmentations + + # add the resize down augmentation + ds['augmentations'] = [{ + 'method': 'Resize', + 'params': { + 'width': int(self.resolution // self.zoom), + 'height': int(self.resolution // self.zoom), + # downscale interpolation, string will be evaluated + 'interpolation': 'cv2.INTER_AREA' + } + }] + ds['augmentations'] + + image_dataset = AugmentedImageDataset(ds) + datasets.append(image_dataset) + + concatenated_dataset = ConcatDataset(datasets) + self.data_loader = DataLoader( + concatenated_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=6 + ) + + def setup_vgg19(self): + if self.vgg_19 is None: + self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses( + single_target=True, + device=self.device, + output_layer_name='pool_4', + dtype=self.torch_dtype + ) + self.vgg_19.to(self.device, dtype=self.torch_dtype) + self.vgg_19.requires_grad_(False) + + # we run random noise through first to get layer scalers to normalize the loss per layer + # bs of 2 because we run pred and target through stacked + noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device, dtype=self.torch_dtype) + self.vgg_19(noise) + for style_loss in self.style_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(style_loss.loss).item() + self.style_weight_scalers.append(scaler) + for content_loss in self.content_losses: + # get a scaler to normalize to 1 + scaler = 1 / torch.mean(content_loss.loss).item() + # if is nan, set to 1 + if scaler != scaler: + scaler = 1 + print(f"Warning: content loss scaler is nan, 
setting to 1") + self.content_weight_scalers.append(scaler) + + self.print(f"Style weight scalers: {self.style_weight_scalers}") + self.print(f"Content weight scalers: {self.content_weight_scalers}") + + def get_style_loss(self): + if self.style_weight > 0: + # scale all losses with loss scalers + loss = torch.sum( + torch.stack([loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_content_loss(self): + if self.content_weight > 0: + # scale all losses with loss scalers + loss = torch.sum(torch.stack( + [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)])) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_mse_loss(self, pred, target): + if self.mse_weight > 0: + loss_fn = nn.MSELoss() + loss = loss_fn(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_tv_loss(self, pred, target): + if self.tv_weight > 0: + get_tv_loss = ComparativeTotalVariation() + loss = get_tv_loss(pred, target) + return loss + else: + return torch.tensor(0.0, device=self.device) + + def get_pattern_loss(self, pred, target): + if self._pattern_loss is None: + self._pattern_loss = PatternLoss( + pattern_size=self.zoom, + dtype=self.torch_dtype + ).to(self.device, dtype=self.torch_dtype) + loss = torch.mean(self._pattern_loss(pred, target)) + return loss + + def save(self, step=None): + if not os.path.exists(self.save_root): + os.makedirs(self.save_root, exist_ok=True) + + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + + self.update_training_metadata() + filename = f'{self.job.name}{step_num}.safetensors' + # prepare meta + save_meta = get_meta_for_safetensors(self.meta, self.job.name) + + # state_dict = self.model.state_dict() + + # state has the original state dict keys so we can save what we started from + save_state_dict = self.model.state + + for key in list(save_state_dict.keys()): + v = save_state_dict[key] + v = v.detach().clone().to("cpu").to(torch.float32) + save_state_dict[key] = v + + # having issues with meta + save_file(save_state_dict, os.path.join(self.save_root, filename), save_meta) + + self.print(f"Saved to {os.path.join(self.save_root, filename)}") + + if self.use_critic: + self.critic.save(step) + + def sample(self, step=None): + sample_folder = os.path.join(self.save_root, 'samples') + if not os.path.exists(sample_folder): + os.makedirs(sample_folder, exist_ok=True) + + self.model.eval() + + with torch.no_grad(): + for i, img_url in enumerate(self.sample_sources): + img = exif_transpose(Image.open(img_url)) + img = img.convert('RGB') + # crop if not square + if img.width != img.height: + min_dim = min(img.width, img.height) + img = img.crop((0, 0, min_dim, min_dim)) + # resize + img = img.resize((self.resolution * self.zoom, self.resolution * self.zoom), resample=Image.BICUBIC) + + target_image = img + # downscale the image input + img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC) + + # downscale the image input + + img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype) + img = img + output = self.model(img) + # output = (output / 2 + 0.5).clamp(0, 1) + output = output.clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + output = output.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() + + # 
convert to pillow image + output = Image.fromarray((output * 255).astype(np.uint8)) + + # upscale to size * self.upscale_sample while maintaining pixels + output = output.resize( + (self.resolution * self.upscale_sample, self.resolution * self.upscale_sample), + resample=Image.NEAREST + ) + + width, height = output.size + + # stack input image and decoded image + target_image = target_image.resize((width, height)) + output = output.resize((width, height)) + + output_img = Image.new('RGB', (width * 2, height)) + output_img.paste(target_image, (0, 0)) + output_img.paste(output, (width, 0)) + + step_num = '' + if step is not None: + # zero-pad 9 digits + step_num = f"_{str(step).zfill(9)}" + seconds_since_epoch = int(time.time()) + # zero-pad 2 digits + i_str = str(i).zfill(2) + filename = f"{seconds_since_epoch}{step_num}_{i_str}.png" + output_img.save(os.path.join(sample_folder, filename)) + + self.model.train() + + def load_model(self): + state_dict = None + path_to_load = self.pretrained_path + # see if we have a checkpoint in out output to resume from + self.print(f"Looking for latest checkpoint in {self.save_root}") + files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + # todo update step and epoch count + elif self.pretrained_path is None: + self.print(f" - No checkpoint found, starting from scratch") + else: + self.print(f" - No checkpoint found, loading pretrained model") + self.print(f" - path: {path_to_load}") + + if path_to_load is not None: + self.print(f" - Loading pretrained checkpoint: {self.pretrained_path}") + # if ends with pth then assume pytorch checkpoint + if path_to_load.endswith('.pth') or path_to_load.endswith('.pt'): + state_dict = torch.load(path_to_load, map_location=self.device) + elif path_to_load.endswith('.safetensors'): + state_dict = load_file(path_to_load) + else: + raise Exception(f"Unknown file extension for checkpoint: {path_to_load}") + + # todo determine architecture from checkpoint + self.model = ESRGAN( + state_dict + ).to(self.device, dtype=self.esrgan_dtype) + + # set the model to training mode + self.model.train() + self.model.requires_grad_(True) + + def run(self): + super().run() + self.load_datasets() + + max_step_epochs = self.max_steps // len(self.data_loader) + num_epochs = self.epochs + if num_epochs is None or num_epochs > max_step_epochs: + num_epochs = max_step_epochs + + max_epoch_steps = len(self.data_loader) * num_epochs + num_steps = self.max_steps + if num_steps is None or num_steps > max_epoch_steps: + num_steps = max_epoch_steps + self.max_steps = num_steps + self.epochs = num_epochs + start_step = self.step_num + self.first_step = start_step + + self.print(f"Training ESRGAN model:") + self.print(f" - Training folder: {self.training_folder}") + self.print(f" - Batch size: {self.batch_size}") + self.print(f" - Learning rate: {self.learning_rate}") + self.print(f" - Epochs: {num_epochs}") + self.print(f" - Max steps: {self.max_steps}") + + # load model + self.load_model() + + params = self.model.parameters() + + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + self.setup_vgg19() + self.vgg_19.requires_grad_(False) + self.vgg_19.eval() + if self.use_critic: + self.critic.setup() + + optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + + # setup scheduler + # 
todo allow other schedulers + scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, + total_iters=num_steps, + factor=1, + verbose=False + ) + + # setup tqdm progress bar + self.progress_bar = tqdm( + total=num_steps, + desc='Training ESRGAN', + leave=True + ) + + blank_losses = OrderedDict({ + "total": [], + "style": [], + "content": [], + "mse": [], + "kl": [], + "tv": [], + "ptn": [], + "crD": [], + "crG": [], + }) + epoch_losses = copy.deepcopy(blank_losses) + log_losses = copy.deepcopy(blank_losses) + print("Generating baseline samples") + self.sample(step=0) + # range start at self.epoch_num go to self.epochs + for epoch in range(self.epoch_num, self.epochs, 1): + if self.step_num >= self.max_steps: + break + for targets, inputs in self.data_loader: + if self.step_num >= self.max_steps: + break + with torch.no_grad(): + targets = targets.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1) + inputs = inputs.to(self.device, dtype=self.esrgan_dtype).clamp(0, 1) + + pred = self.model(inputs) + + pred = pred.to(self.device, dtype=self.torch_dtype).clamp(0, 1) + targets = targets.to(self.device, dtype=self.torch_dtype).clamp(0, 1) + + # Run through VGG19 + if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + stacked = torch.cat([pred, targets], dim=0) + # stacked = (stacked / 2 + 0.5).clamp(0, 1) + stacked = stacked.clamp(0, 1) + self.vgg_19(stacked) + + if self.use_critic: + critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach()) + else: + critic_d_loss = 0.0 + + style_loss = self.get_style_loss() * self.style_weight + content_loss = self.get_content_loss() * self.content_weight + mse_loss = self.get_mse_loss(pred, targets) * self.mse_weight + tv_loss = self.get_tv_loss(pred, targets) * self.tv_weight + pattern_loss = self.get_pattern_loss(pred, targets) * self.pattern_weight + if self.use_critic: + critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + else: + critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + loss = style_loss + content_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + scheduler.step() + + # update progress bar + loss_value = loss.item() + # get exponent like 3.54e-4 + loss_string = f"loss: {loss_value:.2e}" + if self.content_weight > 0: + loss_string += f" cnt: {content_loss.item():.2e}" + if self.style_weight > 0: + loss_string += f" sty: {style_loss.item():.2e}" + if self.mse_weight > 0: + loss_string += f" mse: {mse_loss.item():.2e}" + if self.tv_weight > 0: + loss_string += f" tv: {tv_loss.item():.2e}" + if self.pattern_weight > 0: + loss_string += f" ptn: {pattern_loss.item():.2e}" + if self.use_critic and self.critic_weight > 0: + loss_string += f" crG: {critic_gen_loss.item():.2e}" + if self.use_critic: + loss_string += f" crD: {critic_d_loss:.2e}" + + if self.optimizer_type.startswith('dadaptation') or self.optimizer_type.startswith('prodigy'): + learning_rate = ( + optimizer.param_groups[0]["d"] * + optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = optimizer.param_groups[0]['lr'] + + lr_critic_string = '' + if self.use_critic: + lr_critic = self.critic.get_lr() + lr_critic_string = f" lrC: {lr_critic:.1e}" + + self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}") + self.progress_bar.set_description(f"E: {epoch}") + self.progress_bar.update(1) + + epoch_losses["total"].append(loss_value) + 
epoch_losses["style"].append(style_loss.item()) + epoch_losses["content"].append(content_loss.item()) + epoch_losses["mse"].append(mse_loss.item()) + epoch_losses["tv"].append(tv_loss.item()) + epoch_losses["ptn"].append(pattern_loss.item()) + epoch_losses["crG"].append(critic_gen_loss.item()) + epoch_losses["crD"].append(critic_d_loss) + + log_losses["total"].append(loss_value) + log_losses["style"].append(style_loss.item()) + log_losses["content"].append(content_loss.item()) + log_losses["mse"].append(mse_loss.item()) + log_losses["tv"].append(tv_loss.item()) + log_losses["ptn"].append(pattern_loss.item()) + log_losses["crG"].append(critic_gen_loss.item()) + log_losses["crD"].append(critic_d_loss) + + # don't do on first step + if self.step_num != start_step: + if self.sample_every and self.step_num % self.sample_every == 0: + # print above the progress bar + self.print(f"Sampling at step {self.step_num}") + self.sample(self.step_num) + + if self.save_every and self.step_num % self.save_every == 0: + # print above the progress bar + self.print(f"Saving at step {self.step_num}") + self.save(self.step_num) + + if self.log_every and self.step_num % self.log_every == 0: + # log to tensorboard + if self.writer is not None: + # get avg loss + for key in log_losses: + log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) + # if log_losses[key] > 0: + self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) + # reset log losses + log_losses = copy.deepcopy(blank_losses) + + self.step_num += 1 + # end epoch + if self.writer is not None: + eps = 1e-6 + # get avg loss + for key in epoch_losses: + epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps) + if epoch_losses[key] > 0: + self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) + # reset epoch losses + epoch_losses = copy.deepcopy(blank_losses) + + self.save() diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index 5eba8174..a7da8160 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -24,6 +24,7 @@ from diffusers import AutoencoderKL from tqdm import tqdm import time import numpy as np +from .models.vgg19_critic import Critic IMAGE_TRANSFORMS = transforms.Compose( [ @@ -37,145 +38,6 @@ def unnormalize(tensor): return (tensor / 2 + 0.5).clamp(0, 1) -class Critic: - process: 'TrainVAEProcess' - - def __init__( - self, - learning_rate=1e-5, - device='cpu', - optimizer='adam', - num_critic_per_gen=1, - dtype='float32', - lambda_gp=10, - start_step=0, - warmup_steps=1000, - process=None, - optimizer_params=None, - ): - self.learning_rate = learning_rate - self.device = device - self.optimizer_type = optimizer - self.num_critic_per_gen = num_critic_per_gen - self.dtype = dtype - self.torch_dtype = get_torch_dtype(self.dtype) - self.process = process - self.model = None - self.optimizer = None - self.scheduler = None - self.warmup_steps = warmup_steps - self.start_step = start_step - self.lambda_gp = lambda_gp - - if optimizer_params is None: - optimizer_params = {} - self.optimizer_params = optimizer_params - self.print = self.process.print - print(f" Critic config: {self.__dict__}") - - def setup(self): - from .models.vgg19_critic import Vgg19Critic - self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype) - self.load_weights() - self.model.train() - self.model.requires_grad_(True) - params = self.model.parameters() - self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, - 
optimizer_params=self.optimizer_params) - self.scheduler = torch.optim.lr_scheduler.ConstantLR( - self.optimizer, - total_iters=self.process.max_steps * self.num_critic_per_gen, - factor=1, - verbose=False - ) - - def load_weights(self): - path_to_load = None - self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") - files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) - if files and len(files) > 0: - latest_file = max(files, key=os.path.getmtime) - print(f" - Latest checkpoint is: {latest_file}") - path_to_load = latest_file - else: - self.print(f" - No checkpoint found, starting from scratch") - if path_to_load: - self.model.load_state_dict(load_file(path_to_load)) - - def save(self, step=None): - self.process.update_training_metadata() - save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) - step_num = '' - if step is not None: - # zeropad 9 digits - step_num = f"_{str(step).zfill(9)}" - save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors") - save_file(self.model.state_dict(), save_path, save_meta) - self.print(f"Saved critic to {save_path}") - - def get_critic_loss(self, vgg_output): - if self.start_step > self.process.step_num: - return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) - - warmup_scaler = 1.0 - # we need a warmup when we come on of 1000 steps - # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps - if self.process.step_num < self.start_step + self.warmup_steps: - warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps - # set model to not train for generator loss - self.model.eval() - self.model.requires_grad_(False) - vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) - - # run model - stacked_output = self.model(vgg_pred) - - return (-torch.mean(stacked_output)) * warmup_scaler - - def step(self, vgg_output): - - # train critic here - self.model.train() - self.model.requires_grad_(True) - - critic_losses = [] - for i in range(self.num_critic_per_gen): - inputs = vgg_output.detach() - inputs = inputs.to(self.device, dtype=self.torch_dtype) - self.optimizer.zero_grad() - - vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) - - stacked_output = self.model(inputs) - out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) - - # Compute gradient penalty - gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) - - # Compute WGAN-GP critic loss - critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty - critic_loss.backward() - self.optimizer.zero_grad() - self.optimizer.step() - self.scheduler.step() - critic_losses.append(critic_loss.item()) - - # avg loss - loss = np.mean(critic_losses) - return loss - - def get_lr(self): - if self.optimizer_type.startswith('dadaptation'): - learning_rate = ( - self.optimizer.param_groups[0]["d"] * - self.optimizer.param_groups[0]["lr"] - ) - else: - learning_rate = self.optimizer.param_groups[0]['lr'] - - return learning_rate - - class TrainVAEProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) diff --git a/jobs/process/__init__.py b/jobs/process/__init__.py index 4c7c7660..f731db1f 100644 --- a/jobs/process/__init__.py +++ b/jobs/process/__init__.py @@ -12,3 +12,4 @@ from .TrainSDRescaleProcess import TrainSDRescaleProcess from 
.ModRescaleLoraProcess import ModRescaleLoraProcess from .GenerateProcess import GenerateProcess from .BaseExtensionProcess import BaseExtensionProcess +from .TrainESRGANProcess import TrainESRGANProcess diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py index 6fd2f9c2..a5ef92be 100644 --- a/jobs/process/models/vgg19_critic.py +++ b/jobs/process/models/vgg19_critic.py @@ -1,5 +1,17 @@ +import glob +import os + +import numpy as np import torch import torch.nn as nn +from safetensors.torch import load_file, save_file + +from toolkit.losses import get_gradient_penalty +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.train_tools import get_torch_dtype + +from typing import TYPE_CHECKING, Union class MeanReduce(nn.Module): @@ -36,3 +48,147 @@ class Vgg19Critic(nn.Module): def forward(self, inputs): return self.main(inputs) + + +if TYPE_CHECKING: + from jobs.process.TrainVAEProcess import TrainVAEProcess + from jobs.process.TrainESRGANProcess import TrainESRGANProcess + + +class Critic: + process: Union['TrainVAEProcess', 'TrainESRGANProcess'] + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None, + optimizer_params=None, + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + + if optimizer_params is None: + optimizer_params = {} + self.optimizer_params = optimizer_params + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype) + self.load_weights() + self.model.train() + self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate, + optimizer_params=self.optimizer_params) + self.scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + verbose=False + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files and len(files) > 0: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(f" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = '' + if step is not None: + # zeropad 9 digits + step_num = f"_{str(step).zfill(9)}" + save_path = os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors") + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + if self.start_step > self.process.step_num: + 
return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
+
+        warmup_scaler = 1.0
+        # we need a warmup once the critic comes online at start_step
+        # we want to scale the loss by 0.0 at self.start_step steps and 1.0 at self.start_step + warmup_steps
+        if self.process.step_num < self.start_step + self.warmup_steps:
+            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
+        # set model to not train for generator loss
+        self.model.eval()
+        self.model.requires_grad_(False)
+        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+
+        # run model
+        stacked_output = self.model(vgg_pred)
+
+        return (-torch.mean(stacked_output)) * warmup_scaler
+
+    def step(self, vgg_output):
+
+        # train critic here
+        self.model.train()
+        self.model.requires_grad_(True)
+
+        critic_losses = []
+        for i in range(self.num_critic_per_gen):
+            inputs = vgg_output.detach()
+            inputs = inputs.to(self.device, dtype=self.torch_dtype)
+            self.optimizer.zero_grad()
+
+            vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+
+            stacked_output = self.model(inputs)
+            out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+            # Compute gradient penalty
+            gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+
+            # Compute WGAN-GP critic loss
+            critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+            critic_loss.backward()
+            # gradients were already zeroed at the top of this loop; zeroing them again here would wipe them before the step
+            self.optimizer.step()
+            self.scheduler.step()
+            critic_losses.append(critic_loss.item())
+
+        # avg loss
+        loss = np.mean(critic_losses)
+        return loss
+
+    def get_lr(self):
+        if self.optimizer_type.startswith('dadaptation'):
+            learning_rate = (
+                self.optimizer.param_groups[0]["d"] *
+                self.optimizer.param_groups[0]["lr"]
+            )
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+
+        return learning_rate
+
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index e72b3098..75ece7c9 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -1,10 +1,14 @@
 import os
 import random
+
+import cv2
+import numpy as np
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torchvision import transforms
 from torch.utils.data import Dataset
 from tqdm import tqdm
+import albumentations as A
 
 
 class ImageDataset(Dataset):
@@ -38,7 +42,7 @@ class ImageDataset(Dataset):
 
         self.transform = transforms.Compose([
             transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
+            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
         ])
 
     def get_config(self, key, default=None, required=False):
@@ -65,7 +69,7 @@ class ImageDataset(Dataset):
         if self.random_scale and min_img_size > self.resolution:
             if min_img_size < self.resolution:
                 print(
-                    f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file}")
+                    f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
                 scale_size = self.resolution
             else:
                 scale_size = random.randint(self.resolution, int(min_img_size))
@@ -78,3 +82,61 @@ class ImageDataset(Dataset):
         img = self.transform(img)
 
         return img
+
+
+class Augments:
+    def __init__(self, **kwargs):
+        self.method_name = kwargs.get('method', None)
+        self.params = kwargs.get('params', {})
+
+        # convert kwargs enums for cv2
+        for key, value in self.params.items():
+            if isinstance(value, str):
+                # split the string
+                split_string = value.split('.')
+                if len(split_string) == 2 and split_string[0] == 'cv2':
+                    if hasattr(cv2, split_string[1]):
+                        self.params[key] = getattr(cv2, 
split_string[1].upper()) + else: + raise ValueError(f"invalid cv2 enum: {split_string[1]}") + + +class AugmentedImageDataset(ImageDataset): + def __init__(self, config): + super().__init__(config) + self.augmentations = self.get_config('augmentations', []) + self.augmentations = [Augments(**aug) for aug in self.augmentations] + + augmentation_list = [] + for aug in self.augmentations: + # make sure method name is valid + assert hasattr(A, aug.method_name), f"invalid augmentation method: {aug.method_name}" + # get the method + method = getattr(A, aug.method_name) + # add the method to the list + augmentation_list.append(method(**aug.params)) + + self.aug_transform = A.Compose(augmentation_list) + self.original_transform = self.transform + # replace transform so we get raw pil image + self.transform = transforms.Compose([]) + + def __getitem__(self, index): + # get the original image + # image is a PIL image, convert to bgr + pil_image = super().__getitem__(index) + open_cv_image = np.array(pil_image) + # Convert RGB to BGR + open_cv_image = open_cv_image[:, :, ::-1].copy() + + # apply augmentations + augmented = self.aug_transform(image=open_cv_image)["image"] + + # convert back to RGB tensor + augmented = cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB) + + # convert to PIL image + augmented = Image.fromarray(augmented) + + # return both # return image as 0 - 1 tensor + return transforms.ToTensor()(pil_image), transforms.ToTensor()(augmented) diff --git a/toolkit/esrgan_utils.py b/toolkit/esrgan_utils.py new file mode 100644 index 00000000..25a8bfba --- /dev/null +++ b/toolkit/esrgan_utils.py @@ -0,0 +1,51 @@ + +to_basicsr_dict = { + 'model.0.weight': 'conv_first.weight', + 'model.0.bias': 'conv_first.bias', + 'model.1.sub.23.weight': 'conv_body.weight', + 'model.1.sub.23.bias': 'conv_body.bias', + 'model.3.weight': 'conv_up1.weight', + 'model.3.bias': 'conv_up1.bias', + 'model.6.weight': 'conv_up2.weight', + 'model.6.bias': 'conv_up2.bias', + 'model.8.weight': 'conv_hr.weight', + 'model.8.bias': 'conv_hr.bias', + 'model.10.bias': 'conv_last.bias', + 'model.10.weight': 'conv_last.weight', + # 'model.1.sub.0.RDB1.conv1.0.weight': 'body.0.rdb1.conv1.weight' +} + +def convert_state_dict_to_basicsr(state_dict): + new_state_dict = {} + for k, v in state_dict.items(): + if k in to_basicsr_dict: + new_state_dict[to_basicsr_dict[k]] = v + elif k.startswith('model.1.sub.'): + bsr_name = k.replace('model.1.sub.', 'body.').lower() + bsr_name = bsr_name.replace('.0.weight', '.weight') + bsr_name = bsr_name.replace('.0.bias', '.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict + + +# just matching a commonly used format +def convert_basicsr_state_dict_to_save_format(state_dict): + new_state_dict = {} + to_basicsr_dict_values = list(to_basicsr_dict.values()) + for k, v in state_dict.items(): + if k in to_basicsr_dict_values: + for key, value in to_basicsr_dict.items(): + if value == k: + new_state_dict[key] = v + + elif k.startswith('body.'): + bsr_name = k.replace('body.', 'model.1.sub.').lower() + bsr_name = bsr_name.replace('rdb', 'RDB') + bsr_name = bsr_name.replace('.weight', '.0.weight') + bsr_name = bsr_name.replace('.bias', '.0.bias') + new_state_dict[bsr_name] = v + else: + new_state_dict[k] = v + return new_state_dict diff --git a/toolkit/models/RRDB.py b/toolkit/models/RRDB.py new file mode 100644 index 00000000..c847e720 --- /dev/null +++ b/toolkit/models/RRDB.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import functools 
+import math +import re +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import block as B + + +# Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py +# Which enhanced stuff that was already here +class RRDBNet(nn.Module): + def __init__( + self, + state_dict, + norm=None, + act: str = "leakyrelu", + upsampler: str = "upconv", + mode: B.ConvMode = "CNA", + ) -> None: + """ + ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks. + By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao, + and Chen Change Loy. + This is old-arch Residual in Residual Dense Block Network and is not + the newest revision that's available at github.com/xinntao/ESRGAN. + This is on purpose, the newest Network has severely limited the + potential use of the Network with no benefits. + This network supports model files from both new and old-arch. + Args: + norm: Normalization layer + act: Activation layer + upsampler: Upsample layer. upconv, pixel_shuffle + mode: Convolution mode + """ + super(RRDBNet, self).__init__() + self.model_arch = "ESRGAN" + self.sub_type = "SR" + + self.state = state_dict + self.norm = norm + self.act = act + self.upsampler = upsampler + self.mode = mode + + self.state_map = { + # currently supports old, new, and newer RRDBNet arch models + # ESRGAN, BSRGAN/RealSR, Real-ESRGAN + "model.0.weight": ("conv_first.weight",), + "model.0.bias": ("conv_first.bias",), + "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"), + "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"), + r"model.1.sub.\1.RDB\2.conv\3.0.\4": ( + r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)", + r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)", + ), + } + if "params_ema" in self.state: + self.state = self.state["params_ema"] + # self.model_arch = "RealESRGAN" + self.num_blocks = self.get_num_blocks() + self.plus = any("conv1x1" in k for k in self.state.keys()) + if self.plus: + self.model_arch = "ESRGAN+" + + self.state = self.new_to_old_arch(self.state) + + self.key_arr = list(self.state.keys()) + + self.in_nc: int = self.state[self.key_arr[0]].shape[1] + self.out_nc: int = self.state[self.key_arr[-1]].shape[0] + + self.scale: int = self.get_scale() + self.num_filters: int = self.state[self.key_arr[0]].shape[0] + + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + + self.supports_fp16 = True + self.supports_bfp16 = True + self.min_size_restriction = None + + # Detect if pixelunshuffle was used (Real-ESRGAN) + if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in ( + self.in_nc / 4, + self.in_nc / 16, + ): + self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc)) + else: + self.shuffle_factor = None + + upsample_block = { + "upconv": B.upconv_block, + "pixel_shuffle": B.pixelshuffle_block, + }.get(self.upsampler) + if upsample_block is None: + raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found") + + if self.scale == 3: + upsample_blocks = upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + upscale_factor=3, + act_type=self.act, + c2x2=c2x2, + ) + else: + upsample_blocks = [ + upsample_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, + ) + for _ in range(int(math.log(self.scale, 2))) + ] + + self.model = B.sequential( + # fea conv + 
B.conv_block( + in_nc=self.in_nc, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + B.ShortcutBlock( + B.sequential( + # rrdb blocks + *[ + B.RRDB( + nf=self.num_filters, + kernel_size=3, + gc=32, + stride=1, + bias=True, + pad_type="zero", + norm_type=self.norm, + act_type=self.act, + mode="CNA", + plus=self.plus, + c2x2=c2x2, + ) + for _ in range(self.num_blocks) + ], + # lr conv + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=self.norm, + act_type=None, + mode=self.mode, + c2x2=c2x2, + ), + ) + ), + *upsample_blocks, + # hr_conv0 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.num_filters, + kernel_size=3, + norm_type=None, + act_type=self.act, + c2x2=c2x2, + ), + # hr_conv1 + B.conv_block( + in_nc=self.num_filters, + out_nc=self.out_nc, + kernel_size=3, + norm_type=None, + act_type=None, + c2x2=c2x2, + ), + ) + + # Adjust these properties for calculations outside of the model + if self.shuffle_factor: + self.in_nc //= self.shuffle_factor ** 2 + self.scale //= self.shuffle_factor + + self.load_state_dict(self.state, strict=False) + + def new_to_old_arch(self, state): + """Convert a new-arch model state dictionary to an old-arch dictionary.""" + if "params_ema" in state: + state = state["params_ema"] + + if "conv_first.weight" not in state: + # model is already old arch, this is a loose check, but should be sufficient + return state + + # add nb to state keys + for kind in ("weight", "bias"): + self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[ + f"model.1.sub./NB/.{kind}" + ] + del self.state_map[f"model.1.sub./NB/.{kind}"] + + old_state = OrderedDict() + for old_key, new_keys in self.state_map.items(): + for new_key in new_keys: + if r"\1" in old_key: + for k, v in state.items(): + sub = re.sub(new_key, old_key, k) + if sub != k: + old_state[sub] = v + else: + if new_key in state: + old_state[old_key] = state[new_key] + + # upconv layers + max_upconv = 0 + for key in state.keys(): + match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key) + if match is not None: + _, key_num, key_type = match.groups() + old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key] + max_upconv = max(max_upconv, int(key_num) * 3) + + # final layers + for key in state.keys(): + if key in ("HRconv.weight", "conv_hr.weight"): + old_state[f"model.{max_upconv + 2}.weight"] = state[key] + elif key in ("HRconv.bias", "conv_hr.bias"): + old_state[f"model.{max_upconv + 2}.bias"] = state[key] + elif key in ("conv_last.weight",): + old_state[f"model.{max_upconv + 4}.weight"] = state[key] + elif key in ("conv_last.bias",): + old_state[f"model.{max_upconv + 4}.bias"] = state[key] + + # Sort by first numeric value of each layer + def compare(item1, item2): + parts1 = item1.split(".") + parts2 = item2.split(".") + int1 = int(parts1[1]) + int2 = int(parts2[1]) + return int1 - int2 + + sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare)) + + # Rebuild the output dict in the right order + out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys) + + return out_dict + + def get_scale(self, min_part: int = 6) -> int: + n = 0 + for part in list(self.state): + parts = part.split(".")[1:] + if len(parts) == 2: + part_num = int(parts[0]) + if part_num > min_part and parts[1] == "weight": + n += 1 + return 2 ** n + + def get_num_blocks(self) -> int: + nbs = [] + state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + ( + 
r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)", + ) + for state_key in state_keys: + for k in self.state: + m = re.search(state_key, k) + if m: + nbs.append(int(m.group(1))) + if nbs: + break + return max(*nbs) + 1 + + def forward(self, x): + if self.shuffle_factor: + _, _, h, w = x.size() + mod_pad_h = ( + self.shuffle_factor - h % self.shuffle_factor + ) % self.shuffle_factor + mod_pad_w = ( + self.shuffle_factor - w % self.shuffle_factor + ) % self.shuffle_factor + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect") + x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor) + x = self.model(x) + return x[:, :, : h * self.scale, : w * self.scale] + return self.model(x) diff --git a/toolkit/models/block.py b/toolkit/models/block.py new file mode 100644 index 00000000..76356b5e --- /dev/null +++ b/toolkit/models/block.py @@ -0,0 +1,549 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from collections import OrderedDict + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn + + +#################### +# Basic blocks +#################### + + +def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1): + # helper selecting activation + # neg_slope: for leakyrelu and init of prelu + # n_prelu: for p_relu num_parameters + act_type = act_type.lower() + if act_type == "relu": + layer = nn.ReLU(inplace) + elif act_type == "leakyrelu": + layer = nn.LeakyReLU(neg_slope, inplace) + elif act_type == "prelu": + layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) + else: + raise NotImplementedError( + "activation layer [{:s}] is not found".format(act_type) + ) + return layer + + +def norm(norm_type: str, nc: int): + # helper selecting normalization layer + norm_type = norm_type.lower() + if norm_type == "batch": + layer = nn.BatchNorm2d(nc, affine=True) + elif norm_type == "instance": + layer = nn.InstanceNorm2d(nc, affine=False) + else: + raise NotImplementedError( + "normalization layer [{:s}] is not found".format(norm_type) + ) + return layer + + +def pad(pad_type: str, padding): + # helper selecting padding layer + # if padding is 'zero', do by conv layers + pad_type = pad_type.lower() + if padding == 0: + return None + if pad_type == "reflect": + layer = nn.ReflectionPad2d(padding) + elif pad_type == "replicate": + layer = nn.ReplicationPad2d(padding) + else: + raise NotImplementedError( + "padding layer [{:s}] is not implemented".format(pad_type) + ) + return layer + + +def get_valid_padding(kernel_size, dilation): + kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) + padding = (kernel_size - 1) // 2 + return padding + + +class ConcatBlock(nn.Module): + # Concat the output of a submodule to its input + def __init__(self, submodule): + super(ConcatBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = torch.cat((x, self.sub(x)), dim=1) + return output + + def __repr__(self): + tmpstr = "Identity .. 
\n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlock(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlock, self).__init__() + self.sub = submodule + + def forward(self, x): + output = x + self.sub(x) + return output + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +class ShortcutBlockSPSR(nn.Module): + # Elementwise sum the output of a submodule to its input + def __init__(self, submodule): + super(ShortcutBlockSPSR, self).__init__() + self.sub = submodule + + def forward(self, x): + return x, self.sub + + def __repr__(self): + tmpstr = "Identity + \n|" + modstr = self.sub.__repr__().replace("\n", "\n|") + tmpstr = tmpstr + modstr + return tmpstr + + +def sequential(*args): + # Flatten Sequential. It unwraps nn.Sequential. + if len(args) == 1: + if isinstance(args[0], OrderedDict): + raise NotImplementedError("sequential does not support OrderedDict input.") + return args[0] # No sequential is needed. + modules = [] + for module in args: + if isinstance(module, nn.Sequential): + for submodule in module.children(): + modules.append(submodule) + elif isinstance(module, nn.Module): + modules.append(module) + return nn.Sequential(*modules) + + +ConvMode = Literal["CNA", "NAC", "CNAC"] + + +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + +def conv_block( + in_nc: int, + out_nc: int, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type: str | None = "relu", + mode: ConvMode = "CNA", + c2x2=False, +): + """ + Conv layer with padding, normalization, activation + mode: CNA --> Conv -> Norm -> Act + NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) + """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) + padding = get_valid_padding(kernel_size, dilation) + p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None + padding = padding if pad_type == "zero" else 0 + + c = nn.Conv2d( + in_nc, + out_nc, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + bias=bias, + groups=groups, + ) + a = act(act_type) if act_type else None + if mode in ("CNA", "CNAC"): + n = norm(norm_type, out_nc) if norm_type else None + return sequential(p, c, n, a) + elif mode == "NAC": + if norm_type is None and act_type is not None: + a = act(act_type, inplace=False) + # Important! 
+ # input----ReLU(inplace)----Conv--+----output + # |________________________| + # inplace ReLU will modify the input, therefore wrong output + n = norm(norm_type, in_nc) if norm_type else None + return sequential(n, a, p, c) + else: + assert False, f"Invalid conv mode {mode}" + + +#################### +# Useful blocks +#################### + + +class ResNetBlock(nn.Module): + """ + ResNet Block, 3-3 style + with extra residual scaling used in EDSR + (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17) + """ + + def __init__( + self, + in_nc, + mid_nc, + out_nc, + kernel_size=3, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_type="zero", + norm_type=None, + act_type="relu", + mode: ConvMode = "CNA", + res_scale=1, + ): + super(ResNetBlock, self).__init__() + conv0 = conv_block( + in_nc, + mid_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + if mode == "CNA": + act_type = None + if mode == "CNAC": # Residual path: |-CNAC-| + act_type = None + norm_type = None + conv1 = conv_block( + mid_nc, + out_nc, + kernel_size, + stride, + dilation, + groups, + bias, + pad_type, + norm_type, + act_type, + mode, + ) + # if in_nc != out_nc: + # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \ + # None, None) + # print('Need a projecter in ResNetBlock.') + # else: + # self.project = lambda x:x + self.res = sequential(conv0, conv1) + self.res_scale = res_scale + + def forward(self, x): + res = self.res(x).mul(self.res_scale) + return x + res + + +class RRDB(nn.Module): + """ + Residual in Residual Dense Block + (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) + """ + + def __init__( + self, + nf, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + _convtype="Conv2D", + _spectral_norm=False, + plus=False, + c2x2=False, + ): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB2 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + self.RDB3 = ResidualDenseBlock_5C( + nf, + kernel_size, + gc, + stride, + bias, + pad_type, + norm_type, + act_type, + mode, + plus=plus, + c2x2=c2x2, + ) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class ResidualDenseBlock_5C(nn.Module): + """ + Residual Dense Block + style: 5 convs + The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) + Modified options that can be used: + - "Partial Convolution based Padding" arXiv:1811.11718 + - "Spectral normalization" arXiv:1802.05957 + - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. + {Rakotonirina} and A. {Rasoanaivo} + + Args: + nf (int): Channel number of intermediate features (num_feat). + gc (int): Channels for each growth (num_grow_ch: growth channel, + i.e. intermediate channels). + convtype (str): the type of convolution to use. 
Default: 'Conv2D' + gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new + trainable parameters) + plus (bool): enable the additional residual paths from ESRGAN+ + (adds trainable parameters) + """ + + def __init__( + self, + nf=64, + kernel_size=3, + gc=32, + stride=1, + bias: bool = True, + pad_type="zero", + norm_type=None, + act_type="leakyrelu", + mode: ConvMode = "CNA", + plus=False, + c2x2=False, + ): + super(ResidualDenseBlock_5C, self).__init__() + + ## + + self.conv1x1 = conv1x1(nf, gc) if plus else None + ## + + + self.conv1 = conv_block( + nf, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv2 = conv_block( + nf + gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv3 = conv_block( + nf + 2 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + self.conv4 = conv_block( + nf + 3 * gc, + gc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + mode=mode, + c2x2=c2x2, + ) + if mode == "CNA": + last_act = None + else: + last_act = act_type + self.conv5 = conv_block( + nf + 4 * gc, + nf, + 3, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=last_act, + mode=mode, + c2x2=c2x2, + ) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(torch.cat((x, x1), 1)) + if self.conv1x1: + # pylint: disable=not-callable + x2 = x2 + self.conv1x1(x) # + + x3 = self.conv3(torch.cat((x, x1, x2), 1)) + x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) + if self.conv1x1: + x4 = x4 + x2 # + + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +def conv1x1(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +#################### +# Upsampler +#################### + + +def pixelshuffle_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", +): + """ + Pixel shuffle layer + (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional + Neural Network, CVPR17) + """ + conv = conv_block( + in_nc, + out_nc * (upscale_factor ** 2), + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=None, + act_type=None, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + + n = norm(norm_type, out_nc) if norm_type else None + a = act(act_type) if act_type else None + return sequential(conv, pixel_shuffle, n, a) + + +def upconv_block( + in_nc: int, + out_nc: int, + upscale_factor=2, + kernel_size=3, + stride=1, + bias=True, + pad_type="zero", + norm_type: str | None = None, + act_type="relu", + mode="nearest", + c2x2=False, +): + # Up conv + # described in https://distill.pub/2016/deconv-checkerboard/ + # convert to float 16 if is bfloat16 + upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) + conv = conv_block( + in_nc, + out_nc, + kernel_size, + stride, + bias=bias, + pad_type=pad_type, + norm_type=norm_type, + act_type=act_type, + c2x2=c2x2, + ) + return sequential(upsample, conv)
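
Usage note (not part of the patch): the sketch below shows how a checkpoint written by TrainESRGANProcess.save() might be loaded back through the new toolkit.models.RRDB.RRDBNet wrapper for inference. The checkpoint and image paths are hypothetical; the RRDBNet(state_dict) constructor, the 0-1 input/output range, and the clamp mirror the code added in this diff.

import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

from toolkit.models.RRDB import RRDBNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load trained weights; RRDBNet infers scale and block count from the state dict keys
state_dict = load_file('output/my_esrgan_job_000010000.safetensors')  # hypothetical path
model = RRDBNet(state_dict).to(device).eval()

# the trainer feeds images as plain 0-1 tensors (ToTensor, no Normalize)
img = Image.open('low_res_input.png').convert('RGB')  # hypothetical input
x = transforms.ToTensor()(img).unsqueeze(0).to(device)

with torch.no_grad():
    y = model(x).clamp(0, 1)  # sampling in the trainer also clamps to [0, 1]

transforms.ToPILImage()(y.squeeze(0).cpu()).save('upscaled_output.png')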