diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index fb6536cd..64566db9 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -7,6 +7,7 @@ from collections import OrderedDict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
+from einops import rearrange
 from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
 import torch
@@ -17,18 +18,22 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL
 from tqdm import tqdm
+import math
+import torchvision.utils
 import time
 import numpy as np
-from .models.vgg19_critic import Critic
+from .models.critic import Critic
 from torchvision.transforms import Resize
 import lpips
+import random
+import traceback

 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -42,13 +47,21 @@ def unnormalize(tensor):
     return (tensor / 2 + 0.5).clamp(0, 1)


+def channel_dropout(x, p=0.5):
+    keep_prob = 1 - p
+    mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
+    mask = mask / keep_prob  # scale
+    return x * mask
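+
+# e.g. channel_dropout(torch.ones(2, 4, 8, 8), p=0.5) zeroes roughly half of the
+# 4 channels per sample and scales the survivors to 2.0, keeping the expected
+# value of each activation unchanged (standard inverted dropout)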
+
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
         self.device = self.get_conf('device', self.job.device)
-        self.vae_path = self.get_conf('vae_path', required=True)
+        self.vae_path = self.get_conf('vae_path', None)
+        self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
@@ -65,11 +78,24 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
-        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
+        self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
+        self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
+        self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float)  # latent pixel matching
         self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
-        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
+        self.vae_config = self.get_conf('vae_config', None)
+        self.dropout = self.get_conf('dropout', 0.0, as_type=float)
+        self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+
+        if not self.train_encoder:
+            # remove losses that only target the encoder
+            self.kld_weight = 0
+            self.mv_loss_weight = 0
+            self.ltv_weight = 0
+            self.lpm_weight = 0

         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
@@ -142,7 +168,7 @@ class TrainVAEProcess(BaseTrainProcess):
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=6
+            num_workers=8
         )

     def remove_oldest_checkpoint(self):
@@ -153,6 +179,13 @@ class TrainVAEProcess(BaseTrainProcess):
         for folder in folders[:-max_to_keep]:
             print(f"Removing {folder}")
             shutil.rmtree(folder)
+        # also handle the CRITIC_vae_42_000000500.safetensors naming format used for critic checkpoints
+        critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors"))
+        if len(critic_files) > max_to_keep:
+            critic_files.sort(key=os.path.getmtime)
+            for file in critic_files[:-max_to_keep]:
+                print(f"Removing {file}")
+                os.remove(file)

     def setup_vgg19(self):
         if self.vgg_19 is None:
@@ -218,6 +251,62 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_mean_variance_loss(self, latents: torch.Tensor):
+        if self.mv_loss_weight > 0:
+            # collapse each column into its own channel
+            latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
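+            # e.g. a (b, 4, h, 32) latent becomes (b, 128, h, 1): every column is
+            # treated as its own channel and pushed toward mean 0 / std 1 below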
+            mean_col = latents_col.mean(dim=(2, 3), keepdim=True)
+            std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_col = (mean_col ** 2).mean()
+            std_loss_col = ((std_col - 1) ** 2).mean()
+
+            # collapse each row into its own channel
+            latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
+            mean_row = latents_row.mean(dim=(2, 3), keepdim=True)
+            std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_row = (mean_row ** 2).mean()
+            std_loss_row = ((std_row - 1) ** 2).mean()
+
+            # do a global one
+            mean = latents.mean(dim=(2, 3), keepdim=True)
+            std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_global = (mean ** 2).mean()
+            std_loss_global = ((std - 1) ** 2).mean()
+
+            # average of (mean loss + std loss) over the three groupings
+            return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_ltv_loss(self, latent):
+        # total variation loss on the latent, encourages spatial smoothness
+        if self.ltv_weight > 0:
+            return total_variation(latent).mean()
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_latent_pixel_matching_loss(self, latent, pixels):
+        if self.lpm_weight > 0:
+            with torch.no_grad():
+                pixels = pixels.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                pixels = torch.nn.functional.interpolate(pixels, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+
+                # mean the color channels and then expand to the latent channel count
+                pixels = pixels.mean(dim=1, keepdim=True)
+                pixels = pixels.repeat(1, latent.shape[1], 1, 1)
+                # match the mean/std of the latent
+                latent_mean = latent.mean(dim=(2, 3), keepdim=True)
+                latent_std = latent.std(dim=(2, 3), keepdim=True)
+                pixels_mean = pixels.mean(dim=(2, 3), keepdim=True)
+                pixels_std = pixels.std(dim=(2, 3), keepdim=True)
+                pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean
+
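+            # the target was built under no_grad, so this MSE only pushes the latent
+            # (and therefore the encoder) toward the image's luminance structure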
+            return torch.nn.functional.mse_loss(latent.float(), pixels.float())
+        else:
+            return torch.tensor(0.0, device=self.device)
+
     def get_tv_loss(self, pred, target):
         if self.tv_weight > 0:
             get_tv_loss = ComparativeTotalVariation()
@@ -277,7 +366,39 @@ class TrainVAEProcess(BaseTrainProcess):
                 input_img = img
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
                 img = img
-                decoded = self.vae(img).sample
+                latent = self.vae.encode(img).latent_dist.sample()
+
+                latent_img = latent.clone()
+                bs, ch, h, w = latent_img.shape
+                grid_size = math.ceil(math.sqrt(ch))
+                pad = grid_size * grid_size - ch
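+                # e.g. 4 latent channels -> a 2x2 grid, 16 -> 4x4; `pad` zero-fills
+                # the unused cells when the channel count is not a perfect square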
+
+                # take first item in batch
+                latent_img = latent_img[0]  # shape: (ch, h, w)
+
+                if pad > 0:
+                    padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device)
+                    latent_img = torch.cat([latent_img, padding], dim=0)
+
+                # make grid
+                new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device)
+                for x in range(grid_size):
+                    for y in range(grid_size):
+                        if x * grid_size + y < ch:
+                            new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y]
+                latent_img = new_img
+                # make rgb
+                latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0)
+                latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
+
+                # resize to the sample resolution
+                latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+                latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
+                latent_img = (latent_img * 255).astype(np.uint8)
+                # convert to pillow image
+                latent_img = Image.fromarray(latent_img)
+
+                decoded = self.vae.decode(latent).sample
                 decoded = (decoded / 2 + 0.5).clamp(0, 1)
                 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                 decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
@@ -289,9 +410,10 @@ class TrainVAEProcess(BaseTrainProcess):
                 input_img = input_img.resize((self.resolution, self.resolution))
                 decoded = decoded.resize((self.resolution, self.resolution))

-                output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
+                output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
                 output_img.paste(input_img, (0, 0))
                 output_img.paste(decoded, (self.resolution, 0))
+                output_img.paste(latent_img, (self.resolution * 2, 0))

                 scale_up = 2
                 if output_img.height <= 300:
@@ -326,12 +448,20 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f"Loading VAE")
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
-            self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            if path_to_load is not None:
+                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            elif self.vae_config is not None:
+                self.vae = AutoencoderKL(**self.vae_config)
+            else:
+                raise ValueError('vae_path or vae_config must be specified')

         # set decoder to train
         self.vae.to(self.device, dtype=self.torch_dtype)
-        self.vae.requires_grad_(False)
-        self.vae.eval()
+        if self.eq_vae:
+            self.vae.encoder.train()
+        else:
+            self.vae.requires_grad_(False)
+            self.vae.eval()
         self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
@@ -374,6 +504,10 @@ class TrainVAEProcess(BaseTrainProcess):
         if train_all:
             params = list(self.vae.decoder.parameters())
             self.vae.decoder.requires_grad_(True)
+            if self.train_encoder:
+                # encoder
+                params += list(self.vae.encoder.parameters())
+                self.vae.encoder.requires_grad_(True)
         else:
             # mid_block
             if train_all or 'mid_block' in self.blocks_to_train:
@@ -388,12 +522,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 params += list(self.vae.decoder.conv_out.parameters())
                 self.vae.decoder.conv_out.requires_grad_(True)

-        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+        if self.style_weight > 0 or self.content_weight > 0:
             self.setup_vgg19()
-            self.vgg_19.requires_grad_(False)
+            # self.vgg_19.requires_grad_(False)
             self.vgg_19.eval()
-        if self.use_critic:
-            self.critic.setup()
+
+        if self.use_critic:
+            self.critic.setup()

         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
@@ -426,6 +561,9 @@ class TrainVAEProcess(BaseTrainProcess):
             "style": [],
             "content": [],
             "mse": [],
+            "mvl": [],
+            "ltv": [],
+            "lpm": [],
             "kl": [],
             "tv": [],
             "ptn": [],
@@ -451,27 +589,83 @@ class TrainVAEProcess(BaseTrainProcess):
                     batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

                 # forward pass
+                # grad only if training the encoder
+                with torch.set_grad_enabled(self.train_encoder):
                     dgd = self.vae.encode(batch).latent_dist
                     mu, logvar = dgd.mean, dgd.logvar
                     latents = dgd.sample()
-                latents.detach().requires_grad_(True)
+
+                if self.eq_vae:
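+                    # EQ-VAE-style equivariance: the same flip/rotation is applied to both
+                    # the latent and the pixel target, so the decoder learns to map
+                    # transformed latents to correspondingly transformed images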
+                    # process flips, rotate, scale
+                    latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
+                    batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
+                    out_chunks = []
+                    for i in range(len(latent_chunks)):
+                        try:
+                            do_rotate = random.randint(0, 3)
+                            do_flip_x = random.randint(0, 1)
+                            do_flip_y = random.randint(0, 1)
+                            do_scale = random.randint(0, 1)  # currently unused; the scale/crop branch below is disabled
+                            if do_rotate > 0:
+                                latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3))
+                                batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3))
+                            if do_flip_x > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [2])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [2])
+                            if do_flip_y > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [3])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+                            # if do_scale > 0:
+                            #     scale = 2
+                            #     start_latent_h = latent_chunks[i].shape[2]
+                            #     start_latent_w = latent_chunks[i].shape[3]
+                            #     start_batch_h = batch_chunks[i].shape[2]
+                            #     start_batch_w = batch_chunks[i].shape[3]
+                            #     latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     # random crop. latent is smaller than batch but crops need to match
+                            #     latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h)
+                            #     latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w)
+                            #     batch_x = latent_x * self.vae_scale_factor
+                            #     batch_y = latent_y * self.vae_scale_factor

+                            #     # crop
+                            #     latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w]
+                            #     batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w]
+                        except Exception as e:
+                            print(f"Error processing image {i}: {e}")
+                            traceback.print_exc()
+                            raise e
+                        out_chunks.append(latent_chunks[i])
+                    latents = torch.cat(out_chunks, dim=0)
+                    # do dropout
+                    if self.dropout > 0:
+                        forward_latents = channel_dropout(latents, self.dropout)
+                    else:
+                        forward_latents = latents
+                    batch = torch.cat(batch_chunks, dim=0)
+
+                else:
+                    latents.detach().requires_grad_(True)
+                    forward_latents = latents
+
+                forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
+
+                if not self.train_encoder:
+                    # detach latents if not training the encoder
+                    forward_latents = forward_latents.detach()

-                pred = self.vae.decode(latents).sample
-
-                with torch.no_grad():
-                    show_tensors(
-                        pred.clamp(-1, 1).clone(),
-                        "combined tensor"
-                    )
+                pred = self.vae.decode(forward_latents).sample

                 # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+                if self.style_weight > 0 or self.content_weight > 0:
                     stacked = torch.cat([pred, batch], dim=0)
                     stacked = (stacked / 2 + 0.5).clamp(0, 1)
                     self.vgg_19(stacked)

                 if self.use_critic:
-                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_d_loss = self.critic.step(stacked.detach())
                 else:
                     critic_d_loss = 0.0
@@ -489,7 +683,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
                 pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
                 if self.use_critic:
-                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight

                     # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
                     if self.lpips_weight > 0:
@@ -502,8 +697,42 @@ class TrainVAEProcess(BaseTrainProcess):
                         critic_gen_loss *= crit_g_scaler
                 else:
                     critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+
+                if self.mv_loss_weight > 0:
+                    mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight
+                else:
+                    mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+
+                if self.ltv_weight > 0:
+                    ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+                else:
+                    ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
+
+                if self.lpm_weight > 0:
+                    lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight
+                else:
+                    lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
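+
+                # mv/ltv/lpm act on the augmented, pre-dropout latents and shape the
+                # encoder's output distribution; they are zeroed in __init__ unless
+                # train_encoder is set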

-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + lpm_loss
+
+                # check if loss is NaN or Inf
+                if torch.isnan(loss) or torch.isinf(loss):
+                    self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}")
+                    self.print(f" - Style loss: {style_loss.item()}")
+                    self.print(f" - Content loss: {content_loss.item()}")
+                    self.print(f" - KLD loss: {kld_loss.item()}")
+                    self.print(f" - MSE loss: {mse_loss.item()}")
+                    self.print(f" - LPIPS loss: {lpips_loss.item()}")
+                    self.print(f" - TV loss: {tv_loss.item()}")
+                    self.print(f" - Pattern loss: {pattern_loss.item()}")
+                    self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
+                    self.print(f" - Critic D loss: {critic_d_loss}")
+                    self.print(f" - Mean variance loss: {mv_loss.item()}")
+                    self.print(f" - Latent TV loss: {ltv_loss.item()}")
+                    self.print(f" - Latent pixel matching loss: {lpm_loss.item()}")
+                    self.print(f" - Total loss: {loss.item()}")
+                    self.print(f" - Stopping training")
+                    exit(1)

                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -533,8 +762,17 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
                     loss_string += f" crD: {critic_d_loss:.2e}"
+                if self.mv_loss_weight > 0:
+                    loss_string += f" mvl: {mv_loss.item():.2e}"
+                if self.ltv_weight > 0:
+                    loss_string += f" ltv: {ltv_loss.item():.2e}"
+                if self.lpm_weight > 0:
+                    loss_string += f" lpm: {lpm_loss.item():.2e}"

-                if self.optimizer_type.startswith('dadaptation') or \
+                if hasattr(optimizer, 'get_avg_learning_rate'):
+                    learning_rate = optimizer.get_avg_learning_rate()
+                elif self.optimizer_type.startswith('dadaptation') or \
                         self.optimizer_type.lower().startswith('prodigy'):
                     learning_rate = (
                         optimizer.param_groups[0]["d"] *
@@ -562,6 +800,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 epoch_losses["ptn"].append(pattern_loss.item())
                 epoch_losses["crG"].append(critic_gen_loss.item())
                 epoch_losses["crD"].append(critic_d_loss)
+                epoch_losses["mvl"].append(mv_loss.item())
+                epoch_losses["ltv"].append(ltv_loss.item())
+                epoch_losses["lpm"].append(lpm_loss.item())

                 log_losses["total"].append(loss_value)
                 log_losses["lpips"].append(lpips_loss.item())
@@ -573,6 +814,9 @@ class TrainVAEProcess(BaseTrainProcess):
                 log_losses["ptn"].append(pattern_loss.item())
                 log_losses["crG"].append(critic_gen_loss.item())
                 log_losses["crD"].append(critic_d_loss)
+                log_losses["mvl"].append(mv_loss.item())
+                log_losses["ltv"].append(ltv_loss.item())
+                log_losses["lpm"].append(lpm_loss.item())

                 # don't do on first step
                 if self.step_num != start_step:
diff --git a/jobs/process/models/critic.py b/jobs/process/models/critic.py
new file mode 100644
index 00000000..42bdb637
--- /dev/null
+++ b/jobs/process/models/critic.py
@@ -0,0 +1,229 @@
+import glob
+import os
+from typing import TYPE_CHECKING, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from safetensors.torch import load_file, save_file
+
+from toolkit.losses import get_gradient_penalty
+from toolkit.metadata import get_meta_for_safetensors
+from toolkit.optimizer import get_optimizer
+from toolkit.train_tools import get_torch_dtype
+
+
+class MeanReduce(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs):
+        # global mean over spatial dims (keeps channel/batch)
+        return torch.mean(inputs, dim=(2, 3), keepdim=True)
+
+
+class SelfAttention2d(nn.Module):
+    """
+    Lightweight self-attention layer (SAGAN-style) that keeps spatial
+    resolution unchanged. Adds minimal params / compute but improves
+    long-range modelling – helpful for variable-sized inputs.
+    """
+    def __init__(self, in_channels: int):
+        super().__init__()
+        self.query = nn.Conv1d(in_channels, in_channels // 8, 1)
+        self.key = nn.Conv1d(in_channels, in_channels // 8, 1)
+        self.value = nn.Conv1d(in_channels, in_channels, 1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        flat = x.view(B, C, H * W)                 # (B, C, N)
+        q = self.query(flat).permute(0, 2, 1)      # (B, N, C//8)
+        k = self.key(flat)                         # (B, C//8, N)
+        attn = torch.bmm(q, k)                     # (B, N, N)
+        attn = attn.softmax(dim=-1)                # softmax along last dim
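+        # attention is O(N^2) in memory: e.g. in CriticModel a 256x256 image
+        # reaches this block at 32x32 after three stride-2 convs, so N = 1024
+        # and attn is (B, 1024, 1024)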
+        v = self.value(flat)                       # (B, C, N)
+        out = torch.bmm(v, attn.permute(0, 2, 1))  # (B, C, N)
+        out = out.view(B, C, H, W)                 # restore spatial dims
+        return self.gamma * out + x                # residual
+
+
+class CriticModel(nn.Module):
+    def __init__(self, base_channels: int = 64):
+        super().__init__()
+
+        def sn_conv(in_c, out_c, k, s, p):
+            return nn.utils.spectral_norm(
+                nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p)
+            )
+
+        layers = [
+            # initial down-sample
+            sn_conv(3, base_channels, 3, 2, 1),
+            nn.LeakyReLU(0.2, inplace=True),
+        ]
+
+        in_c = base_channels
+        # progressive downsamples ×3 (64→128→256→512)
+        for _ in range(3):
+            out_c = min(in_c * 2, 1024)
+            layers += [
+                sn_conv(in_c, out_c, 3, 2, 1),
+                nn.LeakyReLU(0.2, inplace=True),
+            ]
+            # single attention block after reaching 256 channels
+            if out_c == 256:
+                layers += [SelfAttention2d(out_c)]
+            in_c = out_c
+
+        # extra depth (keeps spatial size)
+        layers += [
+            sn_conv(in_c, 1024, 3, 1, 1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            # final 1-channel prediction map
+            sn_conv(1024, 1, 3, 1, 1),
+            MeanReduce(),  # → (B, 1, 1, 1)
+            nn.Flatten(),  # → (B, 1)
+        ]
+
+        self.main = nn.Sequential(*layers)
+
+    def forward(self, inputs):
+        # force full precision inside an AMP context for stability
+        with torch.cuda.amp.autocast(False):
+            return self.main(inputs.float())
+
+
+if TYPE_CHECKING:
+    from jobs.process.TrainVAEProcess import TrainVAEProcess
+    from jobs.process.TrainESRGANProcess import TrainESRGANProcess
+
+
+class Critic:
+    process: Union['TrainVAEProcess', 'TrainESRGANProcess']
+
+    def __init__(
+        self,
+        learning_rate=1e-5,
+        device='cpu',
+        optimizer='adam',
+        num_critic_per_gen=1,
+        dtype='float32',
+        lambda_gp=10,
+        start_step=0,
+        warmup_steps=1000,
+        process=None,
+        optimizer_params=None,
+    ):
+        self.learning_rate = learning_rate
+        self.device = device
+        self.optimizer_type = optimizer
+        self.num_critic_per_gen = num_critic_per_gen
+        self.dtype = dtype
+        self.torch_dtype = get_torch_dtype(self.dtype)
+        self.process = process
+        self.model = None
+        self.optimizer = None
+        self.scheduler = None
+        self.warmup_steps = warmup_steps
+        self.start_step = start_step
+        self.lambda_gp = lambda_gp
+
+        if optimizer_params is None:
+            optimizer_params = {}
+        self.optimizer_params = optimizer_params
+        self.print = self.process.print
+        print(f" Critic config: {self.__dict__}")
+
+    def setup(self):
+        self.model = CriticModel().to(self.device)
+        self.load_weights()
+        self.model.train()
+        self.model.requires_grad_(True)
+        params = self.model.parameters()
+        self.optimizer = get_optimizer(
+            params,
+            self.optimizer_type,
+            self.learning_rate,
+            optimizer_params=self.optimizer_params,
+        )
+        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
+            self.optimizer,
+            total_iters=self.process.max_steps * self.num_critic_per_gen,
+            factor=1,
+            verbose=False,
+        )
+
+    def load_weights(self):
+        path_to_load = None
+        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
+        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
+        if files:
+            latest_file = max(files, key=os.path.getmtime)
+            print(f" - Latest checkpoint is: {latest_file}")
+            path_to_load = latest_file
+        else:
+            self.print(" - No checkpoint found, starting from scratch")
+        if path_to_load:
+            self.model.load_state_dict(load_file(path_to_load))
+
+    def save(self, step=None):
+        self.process.update_training_metadata()
+        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
+        step_num = f"_{str(step).zfill(9)}" if step is not None else ''
+        save_path = os.path.join(
+            self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors"
+        )
+        save_file(self.model.state_dict(), save_path, save_meta)
+        self.print(f"Saved critic to {save_path}")
+
+    def get_critic_loss(self, vgg_output):
+        # (the caller still passes combined [pred|target] images)
+        if self.start_step > self.process.step_num:
+            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
+
+        warmup_scaler = 1.0
+        if self.process.step_num < self.start_step + self.warmup_steps:
+            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
+
+        self.model.eval()
+        self.model.requires_grad_(False)
+
+        vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0)
+        stacked_output = self.model(vgg_pred)
+        return (-torch.mean(stacked_output)) * warmup_scaler
+
+    def step(self, vgg_output):
+        self.model.train()
+        self.model.requires_grad_(True)
+        self.optimizer.zero_grad()
+
+        critic_losses = []
+        inputs = vgg_output.detach().to(self.device, dtype=torch.float32)
+
+        vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+        stacked_output = self.model(inputs).float()
+        out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+        # hinge loss + gradient penalty
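+        # L_D = E[relu(1 - D(real))] + E[relu(1 + D(fake))]; the generator side
+        # in get_critic_loss above minimizes -E[D(fake)]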
+        loss_real = torch.relu(1.0 - out_target).mean()
+        loss_fake = torch.relu(1.0 + out_pred).mean()
+        gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+        critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty
+
+        critic_loss.backward()
+        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+        self.optimizer.step()
+        self.scheduler.step()
+        critic_losses.append(critic_loss.item())
+
+        return float(np.mean(critic_losses))
+
+    def get_lr(self):
+        if self.optimizer_type.startswith('dadaptation'):
+            return (
+                self.optimizer.param_groups[0]["d"]
+                * self.optimizer.param_groups[0]["lr"]
+            )
+        return self.optimizer.param_groups[0]["lr"]
diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py
index 8cf438bf..4d7f74f9 100644
--- a/jobs/process/models/vgg19_critic.py
+++ b/jobs/process/models/vgg19_critic.py
@@ -33,11 +33,20 @@ class Vgg19Critic(nn.Module):
         super(Vgg19Critic, self).__init__()
         self.main = nn.Sequential(
             # input (bs, 512, 32, 32)
-            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(  # SN keeps D's scale in check
+                nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             nn.LeakyReLU(0.2),
             # (bs, 512, 16, 16)
-            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(
+                nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             nn.LeakyReLU(0.2),
             # (bs, 512, 8, 8)
-            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(
+                nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             # (bs, 1, 4, 4)
             MeanReduce(),  # (bs, 1, 1, 1)
             nn.Flatten(),  # (bs, 1)
@@ -47,7 +56,9 @@ class Vgg19Critic(nn.Module):
         )

     def forward(self, inputs):
-        return self.main(inputs)
+        # return self.main(inputs)
+        with torch.cuda.amp.autocast(False):
+            return self.main(inputs.float())


 if TYPE_CHECKING:
@@ -92,7 +103,7 @@ class Critic:
         print(f" Critic config: {self.__dict__}")

     def setup(self):
-        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
+        self.model = Vgg19Critic().to(self.device)
         self.load_weights()
         self.model.train()
         self.model.requires_grad_(True)
@@ -142,7 +153,8 @@ class Critic:
         # set model to not train for generator loss
         self.model.eval()
         self.model.requires_grad_(False)
-        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+        # vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+        vgg_pred, vgg_target = torch.chunk(vgg_output.float(), 2, dim=0)

         # run model
         stacked_output = self.model(vgg_pred)
@@ -157,20 +169,34 @@ class Critic:
         self.optimizer.zero_grad()

         critic_losses = []
-        inputs = vgg_output.detach()
-        inputs = inputs.to(self.device, dtype=self.torch_dtype)
+        # inputs = vgg_output.detach()
+        # inputs = inputs.to(self.device, dtype=self.torch_dtype)
+        inputs = vgg_output.detach().to(self.device, dtype=torch.float32)

         self.optimizer.zero_grad()

         vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
+        # stacked_output = self.model(inputs).float()
+        # out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+
+        # # Compute gradient penalty
+        # gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+
+        # # Compute WGAN-GP critic loss
+        # critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+
         stacked_output = self.model(inputs).float()
         out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

-        # Compute gradient penalty
+        # ── hinge loss ──
+        loss_real = torch.relu(1.0 - out_target).mean()
+        loss_fake = torch.relu(1.0 + out_pred).mean()
+
+        # gradient penalty (unchanged helper)
         gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)

-        # Compute WGAN-GP critic loss
-        critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+        critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty
+
         critic_loss.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
         self.optimizer.step()