diff --git a/jobs/TrainJob.py b/jobs/TrainJob.py index 2f5c66ef..c8b4e4b9 100644 --- a/jobs/TrainJob.py +++ b/jobs/TrainJob.py @@ -15,7 +15,7 @@ sys.path.append(REPOS_ROOT) process_dict = { 'vae': 'TrainVAEProcess', - 'finetune': 'TrainFineTuneProcess' + 'slider': 'TrainSliderProcess', } @@ -27,8 +27,8 @@ class TrainJob(BaseJob): self.training_folder = self.get_conf('training_folder', required=True) self.is_v2 = self.get_conf('is_v2', False) self.device = self.get_conf('device', 'cpu') - self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) - self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 + # self.gradient_accumulation_steps = self.get_conf('gradient_accumulation_steps', 1) + # self.mixed_precision = self.get_conf('mixed_precision', False) # fp16 self.log_dir = self.get_conf('log_dir', None) self.writer = None diff --git a/toolkit/llvae.py b/toolkit/llvae.py index 74750894..cfb8ca27 100644 --- a/toolkit/llvae.py +++ b/toolkit/llvae.py @@ -5,13 +5,14 @@ import itertools class LosslessLatentDecoder(nn.Module): - def __init__(self, in_channels, latent_depth): + def __init__(self, in_channels, latent_depth, dtype=torch.float32): super(LosslessLatentDecoder, self).__init__() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.latent_depth = latent_depth self.in_channels = in_channels self.out_channels = int(in_channels // (latent_depth * latent_depth)) numpy_kernel = self.build_kernel(in_channels, latent_depth) - self.kernel = torch.from_numpy(numpy_kernel).float() + self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) def build_kernel(self, in_channels, latent_depth): # my old code from tensorflow. @@ -39,13 +40,15 @@ class LosslessLatentDecoder(nn.Module): class LosslessLatentEncoder(nn.Module): - def __init__(self, in_channels, latent_depth): + def __init__(self, in_channels, latent_depth, dtype=torch.float32): super(LosslessLatentEncoder, self).__init__() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self.latent_depth = latent_depth self.in_channels = in_channels self.out_channels = int(in_channels * (latent_depth * latent_depth)) numpy_kernel = self.build_kernel(in_channels, latent_depth) - self.kernel = torch.from_numpy(numpy_kernel).float() + self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype) + def build_kernel(self, in_channels, latent_depth): # my old code from tensorflow. @@ -72,13 +75,13 @@ class LosslessLatentEncoder(nn.Module): class LosslessLatentVAE(nn.Module): - def __init__(self, in_channels, latent_depth): + def __init__(self, in_channels, latent_depth, dtype=torch.float32): super(LosslessLatentVAE, self).__init__() self.latent_depth = latent_depth self.in_channels = in_channels - self.encoder = LosslessLatentEncoder(in_channels, latent_depth) + self.encoder = LosslessLatentEncoder(in_channels, latent_depth, dtype=dtype) encoder_out_channels = self.encoder.out_channels - self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth) + self.decoder = LosslessLatentDecoder(encoder_out_channels, latent_depth, dtype=dtype) def forward(self, x): latent = self.latent_encoder(x) diff --git a/toolkit/losses.py b/toolkit/losses.py index 9158c505..f8c2855a 100644 --- a/toolkit/losses.py +++ b/toolkit/losses.py @@ -1,4 +1,5 @@ import torch +from .llvae import LosslessLatentEncoder def total_variation(image): @@ -45,3 +46,40 @@ def get_gradient_penalty(critic, real, fake, device): gradient_penalty = ((gradient_norm - 1) ** 2).mean() return gradient_penalty + +class PatternLoss(torch.nn.Module): + def __init__(self, pattern_size=4, dtype=torch.float32): + super().__init__() + self.pattern_size = pattern_size + self.llvae_encoder = LosslessLatentEncoder(3, pattern_size, dtype=dtype) + + def forward(self, pred, target): + pred_latents = self.llvae_encoder(pred) + target_latents = self.llvae_encoder(target) + + matrix_pixels = self.pattern_size * self.pattern_size + + color_chans = pred_latents.shape[1] // 3 + # pytorch + r_chans, g_chans, b_chans = torch.split(pred_latents, [color_chans, color_chans, color_chans], 1) + r_chans_target, g_chans_target, b_chans_target = torch.split(target_latents, [color_chans, color_chans, color_chans], 1) + + def separated_chan_loss(latent_chan): + nonlocal matrix_pixels + chan_mean = torch.mean(latent_chan, dim=[1, 2, 3]) + chan_splits = torch.split(latent_chan, [1 for i in range(matrix_pixels)], 1) + chan_loss = None + for chan in chan_splits: + this_mean = torch.mean(chan, dim=[1, 2, 3]) + this_chan_loss = torch.abs(this_mean - chan_mean) + if chan_loss is None: + chan_loss = this_chan_loss + else: + chan_loss = chan_loss + this_chan_loss + chan_loss = chan_loss * (1 / matrix_pixels) + return chan_loss + + r_chan_loss = torch.abs(separated_chan_loss(r_chans) - separated_chan_loss(r_chans_target)) + g_chan_loss = torch.abs(separated_chan_loss(g_chans) - separated_chan_loss(g_chans_target)) + b_chan_loss = torch.abs(separated_chan_loss(b_chans) - separated_chan_loss(b_chans_target)) + return (r_chan_loss + g_chan_loss + b_chan_loss) * 0.3333 diff --git a/toolkit/optimizer.py b/toolkit/optimizer.py index f58d90ff..bba10cf1 100644 --- a/toolkit/optimizer.py +++ b/toolkit/optimizer.py @@ -6,13 +6,16 @@ def get_optimizer( optimizer_type='adam', learning_rate=1e-6 ): - if optimizer_type == 'dadaptation': + lower_type = optimizer_type.lower() + if lower_type == 'dadaptation': # dadaptation optimizer does not use standard learning rate. 1 is the default value import dadaptation print("Using DAdaptAdam optimizer") optimizer = dadaptation.DAdaptAdam(params, lr=1.0) - elif optimizer_type == 'adam': + elif lower_type == 'adam': optimizer = torch.optim.Adam(params, lr=float(learning_rate)) + elif lower_type == 'adamw': + optimizer = torch.optim.AdamW(params, lr=float(learning_rate)) else: raise ValueError(f'Unknown optimizer type {optimizer_type}') return optimizer