Work on vae trainer

Jaret Burkett
2025-05-28 07:42:48 -06:00
parent 79bb9be92b
commit 34f4c14cd6
3 changed files with 538 additions and 39 deletions

View File

@@ -7,6 +7,7 @@ from collections import OrderedDict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
+from einops import rearrange
 from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
 import torch
@@ -17,18 +18,22 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL
 from tqdm import tqdm
+import math
+import torchvision.utils
 import time
 import numpy as np
-from .models.vgg19_critic import Critic
+from .models.critic import Critic
 from torchvision.transforms import Resize
 import lpips
+import random
+import traceback

 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -42,13 +47,21 @@ def unnormalize(tensor):
     return (tensor / 2 + 0.5).clamp(0, 1)

+def channel_dropout(x, p=0.5):
+    keep_prob = 1 - p
+    mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
+    mask = mask / keep_prob  # scale
+    return x * mask

 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
         self.device = self.get_conf('device', self.job.device)
-        self.vae_path = self.get_conf('vae_path', required=True)
+        self.vae_path = self.get_conf('vae_path', None)
+        self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
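The new channel_dropout helper zeroes whole latent channels at random and rescales the survivors, inverted-dropout style, so the decoder learns to tolerate missing channels. A quick illustration of its behavior (shapes are made up, not from the commit):

x = torch.randn(2, 4, 8, 8)           # hypothetical latent batch
y = channel_dropout(x, p=0.5)
# each (sample, channel) plane of y is either all zeros or the matching plane
# of x scaled by 1 / keep_prob, so the expected activation stays unchanged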
@@ -65,11 +78,24 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
-        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
+        self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
+        self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
+        self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float)  # latent pixel matching
         self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
-        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
+        self.vae_config = self.get_conf('vae_config', None)
+        self.dropout = self.get_conf('dropout', 0.0, as_type=float)
+        self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+        if not self.train_encoder:
+            # remove losses that only target encoder
+            self.kld_weight = 0
+            self.mv_loss_weight = 0
+            self.ltv_weight = 0
+            self.lpm_weight = 0
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
@@ -142,7 +168,7 @@ class TrainVAEProcess(BaseTrainProcess):
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=6
+            num_workers=8
         )

     def remove_oldest_checkpoint(self):
@@ -153,6 +179,13 @@ class TrainVAEProcess(BaseTrainProcess):
         for folder in folders[:-max_to_keep]:
             print(f"Removing {folder}")
             shutil.rmtree(folder)
+        # also handle CRITIC_vae_42_000000500.safetensors format for critic
+        critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors"))
+        if len(critic_files) > max_to_keep:
+            critic_files.sort(key=os.path.getmtime)
+            for file in critic_files[:-max_to_keep]:
+                print(f"Removing {file}")
+                os.remove(file)

     def setup_vgg19(self):
         if self.vgg_19 is None:
@@ -218,6 +251,62 @@ class TrainVAEProcess(BaseTrainProcess):
         else:
             return torch.tensor(0.0, device=self.device)

+    def get_mean_variance_loss(self, latents: torch.Tensor):
+        if self.mv_loss_weight > 0:
+            # collapse each column into its own channel
+            latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
+            mean_col = latents_col.mean(dim=(2, 3), keepdim=True)
+            std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_col = (mean_col ** 2).mean()
+            std_loss_col = ((std_col - 1) ** 2).mean()
+            # collapse each row into its own channel
+            latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
+            mean_row = latents_row.mean(dim=(2, 3), keepdim=True)
+            std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_row = (mean_row ** 2).mean()
+            std_loss_row = ((std_row - 1) ** 2).mean()
+            # do a global one
+            mean = latents.mean(dim=(2, 3), keepdim=True)
+            std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_global = (mean ** 2).mean()
+            std_loss_global = ((std - 1) ** 2).mean()
+            return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3
+        else:
+            return torch.tensor(0.0, device=self.device)

+    def get_ltv_loss(self, latent):
+        # loss to reduce total variation in the latent space
+        if self.ltv_weight > 0:
+            return total_variation(latent).mean()
+        else:
+            return torch.tensor(0.0, device=self.device)

+    def get_latent_pixel_matching_loss(self, latent, pixels):
+        if self.lpm_weight > 0:
+            with torch.no_grad():
+                pixels = pixels.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                pixels = torch.nn.functional.interpolate(pixels, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+                # mean the color channels and then expand to the latent channel count
+                pixels = pixels.mean(dim=1, keepdim=True)
+                pixels = pixels.repeat(1, latent.shape[1], 1, 1)
+                # match the mean/std of the latent
+                latent_mean = latent.mean(dim=(2, 3), keepdim=True)
+                latent_std = latent.std(dim=(2, 3), keepdim=True)
+                pixels_mean = pixels.mean(dim=(2, 3), keepdim=True)
+                pixels_std = pixels.std(dim=(2, 3), keepdim=True)
+                pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean
+            return torch.nn.functional.mse_loss(latent.float(), pixels.float())
+        else:
+            return torch.tensor(0.0, device=self.device)

     def get_tv_loss(self, pred, target):
         if self.tv_weight > 0:
             get_tv_loss = ComparativeTotalVariation()
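The mean-variance loss drives latents toward zero mean and unit variance three ways: globally per channel, and per column and per row, by folding each column (or row) into its own channel with einops.rearrange. A minimal sketch of what that folding does (hypothetical shapes):

import torch
from einops import rearrange

latents = torch.randn(1, 4, 8, 8)  # (b, c, h, w)
cols = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
print(cols.shape)  # torch.Size([1, 32, 8, 1]): one channel per (channel, column) pair
# channel index is c * 8 + col, e.g. channel 5 holds column 5 of input channel 0:
assert torch.equal(cols[0, 5, :, 0], latents[0, 0, :, 5])

get_ltv_loss applies a plain total-variation penalty to the latent itself, and get_latent_pixel_matching_loss regresses the latent toward a mean/std-matched grayscale downsample of the input image; both only matter when gradients reach the encoder, which is why their weights are zeroed when train_encoder is off.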
@@ -277,7 +366,39 @@ class TrainVAEProcess(BaseTrainProcess):
             input_img = img
             img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
             img = img
-            decoded = self.vae(img).sample
+            latent = self.vae.encode(img).latent_dist.sample()

+            latent_img = latent.clone()
+            bs, ch, h, w = latent_img.shape
+            grid_size = math.ceil(math.sqrt(ch))
+            pad = grid_size * grid_size - ch
+            # take first item in batch
+            latent_img = latent_img[0]  # shape: (ch, h, w)
+            if pad > 0:
+                padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device)
+                latent_img = torch.cat([latent_img, padding], dim=0)
+            # make grid
+            new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device)
+            for x in range(grid_size):
+                for y in range(grid_size):
+                    if x * grid_size + y < ch:
+                        new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y]
+            latent_img = new_img
+            # make rgb
+            latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0)
+            latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
+            # resize to the sample resolution
+            latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+            latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
+            latent_img = (latent_img * 255).astype(np.uint8)
+            # convert to pillow image
+            latent_img = Image.fromarray(latent_img)

+            decoded = self.vae.decode(latent).sample
             decoded = (decoded / 2 + 0.5).clamp(0, 1)
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
             decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
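The grid assembly above is hand-rolled; torchvision.utils (newly imported in this commit but otherwise unused) has make_grid, which could build much the same channel mosaic. A sketch of the equivalent call (illustrative, assuming a 16-channel latent):

import torch
import torchvision.utils

latent = torch.randn(1, 16, 32, 32)                    # (bs, ch, h, w)
tiles = latent[0].unsqueeze(1)                         # (ch, 1, h, w): one grayscale tile per channel
grid = torchvision.utils.make_grid(tiles, nrow=4, padding=0)  # (3, 4*h, 4*w), 1 channel repeated to RGB
grid = (grid / 2 + 0.5).clamp(0, 1)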
@@ -289,9 +410,10 @@ class TrainVAEProcess(BaseTrainProcess):
             input_img = input_img.resize((self.resolution, self.resolution))
             decoded = decoded.resize((self.resolution, self.resolution))

-            output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
+            output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
             output_img.paste(input_img, (0, 0))
             output_img.paste(decoded, (self.resolution, 0))
+            output_img.paste(latent_img, (self.resolution * 2, 0))

             scale_up = 2
             if output_img.height <= 300:
@@ -326,12 +448,20 @@ class TrainVAEProcess(BaseTrainProcess):
         self.print(f"Loading VAE")
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
-            self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            if path_to_load is not None:
+                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            elif self.vae_config is not None:
+                self.vae = AutoencoderKL(**self.vae_config)
+            else:
+                raise ValueError('vae_path or vae_config must be specified')

         # set decoder to train
         self.vae.to(self.device, dtype=self.torch_dtype)
-        self.vae.requires_grad_(False)
-        self.vae.eval()
+        if self.eq_vae:
+            self.vae.encoder.train()
+        else:
+            self.vae.requires_grad_(False)
+            self.vae.eval()
         self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
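With vae_path now optional, a VAE can also be built from scratch via vae_config, which is unpacked directly into diffusers' AutoencoderKL constructor. A minimal sketch of such a config (values are illustrative, not defaults from this repo):

from diffusers import AutoencoderKL

vae_config = {
    "in_channels": 3,
    "out_channels": 3,
    "latent_channels": 4,
    "block_out_channels": [128, 256, 512, 512],
    "down_block_types": ["DownEncoderBlock2D"] * 4,
    "up_block_types": ["UpDecoderBlock2D"] * 4,
    "layers_per_block": 2,
}
vae = AutoencoderKL(**vae_config)  # randomly initialized; vae_scale_factor = 2 ** (4 - 1) = 8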
@@ -374,6 +504,10 @@ class TrainVAEProcess(BaseTrainProcess):
         if train_all:
             params = list(self.vae.decoder.parameters())
             self.vae.decoder.requires_grad_(True)
+            if self.train_encoder:
+                # encoder
+                params += list(self.vae.encoder.parameters())
+                self.vae.encoder.requires_grad_(True)
         else:
             # mid_block
             if train_all or 'mid_block' in self.blocks_to_train:
@@ -388,12 +522,13 @@ class TrainVAEProcess(BaseTrainProcess):
                 params += list(self.vae.decoder.conv_out.parameters())
                 self.vae.decoder.conv_out.requires_grad_(True)

-        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+        if self.style_weight > 0 or self.content_weight > 0:
             self.setup_vgg19()
-            self.vgg_19.requires_grad_(False)
+            # self.vgg_19.requires_grad_(False)
             self.vgg_19.eval()
+
         if self.use_critic:
             self.critic.setup()

         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
@@ -426,6 +561,9 @@ class TrainVAEProcess(BaseTrainProcess):
             "style": [],
             "content": [],
             "mse": [],
+            "mvl": [],
+            "ltv": [],
+            "lpm": [],
             "kl": [],
             "tv": [],
             "ptn": [],
@@ -451,27 +589,83 @@ class TrainVAEProcess(BaseTrainProcess):
                 batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

                 # forward pass
-                dgd = self.vae.encode(batch).latent_dist
-                mu, logvar = dgd.mean, dgd.logvar
-                latents = dgd.sample()
-                latents.detach().requires_grad_(True)
+                # grad only if training the encoder
+                with torch.set_grad_enabled(self.train_encoder):
+                    dgd = self.vae.encode(batch).latent_dist
+                    mu, logvar = dgd.mean, dgd.logvar
+                    latents = dgd.sample()

+                if self.eq_vae:
+                    # process flips, rotate, scale
+                    latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
+                    batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
+                    out_chunks = []
+                    for i in range(len(latent_chunks)):
+                        try:
+                            do_rotate = random.randint(0, 3)
+                            do_flip_x = random.randint(0, 1)
+                            do_flip_y = random.randint(0, 1)
+                            do_scale = random.randint(0, 1)
+                            if do_rotate > 0:
+                                latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3))
+                                batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3))
+                            if do_flip_x > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [2])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [2])
+                            if do_flip_y > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [3])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+                            # if do_scale > 0:
+                            #     scale = 2
+                            #     start_latent_h = latent_chunks[i].shape[2]
+                            #     start_latent_w = latent_chunks[i].shape[3]
+                            #     start_batch_h = batch_chunks[i].shape[2]
+                            #     start_batch_w = batch_chunks[i].shape[3]
+                            #     latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     # random crop. latent is smaller than batch but the crops need to match
+                            #     latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h)
+                            #     latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w)
+                            #     batch_x = latent_x * self.vae_scale_factor
+                            #     batch_y = latent_y * self.vae_scale_factor
+                            #     # crop
+                            #     latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w]
+                            #     batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w]
+                        except Exception as e:
+                            print(f"Error processing image {i}: {e}")
+                            traceback.print_exc()
+                            raise e
+                        out_chunks.append(latent_chunks[i])
+                    latents = torch.cat(out_chunks, dim=0)

+                    # do dropout
+                    if self.dropout > 0:
+                        forward_latents = channel_dropout(latents, self.dropout)
+                    else:
+                        forward_latents = latents

+                    batch = torch.cat(batch_chunks, dim=0)
+                else:
+                    latents.detach().requires_grad_(True)
+                    forward_latents = latents

+                forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)

+                if not self.train_encoder:
+                    # detach latents if not training encoder
+                    forward_latents = forward_latents.detach()

-                pred = self.vae.decode(latents).sample
+                pred = self.vae.decode(forward_latents).sample

+                with torch.no_grad():
+                    show_tensors(
+                        pred.clamp(-1, 1).clone(),
+                        "combined tensor"
+                    )

                 # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+                if self.style_weight > 0 or self.content_weight > 0:
                     stacked = torch.cat([pred, batch], dim=0)
                     stacked = (stacked / 2 + 0.5).clamp(0, 1)
                     self.vgg_19(stacked)

                 if self.use_critic:
-                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_d_loss = self.critic.step(stacked.detach())
                 else:
                     critic_d_loss = 0.0
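The eq_vae path above is the heart of the change: each latent is rotated/flipped after encoding while the pixel target receives the identical transform, so the decoder is trained toward decode(T(encode(x))) ≈ T(x), an equivariance constraint on the latent space. Stripped to its essence (a sketch assuming a loaded vae, not the trainer's exact code):

import random
import torch

def eq_pair(vae, batch):
    latents = vae.encode(batch).latent_dist.sample()
    k = random.randint(0, 3)
    latents = torch.rot90(latents, k, (2, 3))  # transform the latent...
    target = torch.rot90(batch, k, (2, 3))     # ...and the pixel target identically
    pred = vae.decode(latents).sample
    return pred, target  # reconstruction losses compare these two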
@@ -489,7 +683,8 @@ class TrainVAEProcess(BaseTrainProcess):
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
                 pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
                 if self.use_critic:
-                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight

                     # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
                     if self.lpips_weight > 0:
@@ -502,8 +697,42 @@ class TrainVAEProcess(BaseTrainProcess):
                         critic_gen_loss *= crit_g_scaler
                 else:
                     critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

+                if self.mv_loss_weight > 0:
+                    mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight
+                else:
+                    mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

+                if self.ltv_weight > 0:
+                    ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
+                else:
+                    ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

+                if self.lpm_weight > 0:
+                    lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight
+                else:
+                    lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
-                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
+                loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + lpm_loss
+                # check if loss is NaN or Inf
+                if torch.isnan(loss) or torch.isinf(loss):
+                    self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}")
+                    self.print(f" - Style loss: {style_loss.item()}")
+                    self.print(f" - Content loss: {content_loss.item()}")
+                    self.print(f" - KLD loss: {kld_loss.item()}")
+                    self.print(f" - MSE loss: {mse_loss.item()}")
+                    self.print(f" - LPIPS loss: {lpips_loss.item()}")
+                    self.print(f" - TV loss: {tv_loss.item()}")
+                    self.print(f" - Pattern loss: {pattern_loss.item()}")
+                    self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
+                    self.print(f" - Critic D loss: {critic_d_loss}")
+                    self.print(f" - Mean variance loss: {mv_loss.item()}")
+                    self.print(f" - Latent TV loss: {ltv_loss.item()}")
+                    self.print(f" - Latent pixel matching loss: {lpm_loss.item()}")
+                    self.print(f" - Total loss: {loss.item()}")
+                    self.print(f" - Stopping training")
+                    exit(1)

                 # Backward pass and optimization
                 optimizer.zero_grad()
@@ -533,8 +762,17 @@ class TrainVAEProcess(BaseTrainProcess):
                     loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
                     loss_string += f" crD: {critic_d_loss:.2e}"
+                if self.mv_loss_weight > 0:
+                    loss_string += f" mvl: {mv_loss:.2e}"
+                if self.ltv_weight > 0:
+                    loss_string += f" ltv: {ltv_loss:.2e}"
+                if self.lpm_weight > 0:
+                    loss_string += f" lpm: {lpm_loss:.2e}"

-                if self.optimizer_type.startswith('dadaptation') or \
+                if hasattr(optimizer, 'get_avg_learning_rate'):
+                    learning_rate = optimizer.get_avg_learning_rate()
+                elif self.optimizer_type.startswith('dadaptation') or \
                         self.optimizer_type.lower().startswith('prodigy'):
                     learning_rate = (
                         optimizer.param_groups[0]["d"] *
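The LR readout now duck-types: an optimizer exposing get_avg_learning_rate wins, then D-Adaptation/Prodigy's effective rate d * lr, then the plain param-group value. The same cascade as a standalone helper (a sketch, not code from the repo):

def current_lr(optimizer, optimizer_type: str) -> float:
    if hasattr(optimizer, 'get_avg_learning_rate'):
        return optimizer.get_avg_learning_rate()
    if optimizer_type.startswith('dadaptation') or optimizer_type.lower().startswith('prodigy'):
        group = optimizer.param_groups[0]
        return group["d"] * group["lr"]  # effective step size for these adaptive optimizers
    return optimizer.param_groups[0]["lr"]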
@@ -562,6 +800,9 @@ class TrainVAEProcess(BaseTrainProcess):
                     epoch_losses["ptn"].append(pattern_loss.item())
                     epoch_losses["crG"].append(critic_gen_loss.item())
                     epoch_losses["crD"].append(critic_d_loss)
+                    epoch_losses["mvl"].append(mv_loss.item())
+                    epoch_losses["ltv"].append(ltv_loss.item())
+                    epoch_losses["lpm"].append(lpm_loss.item())

                     log_losses["total"].append(loss_value)
                     log_losses["lpips"].append(lpips_loss.item())
@@ -573,6 +814,9 @@ class TrainVAEProcess(BaseTrainProcess):
                     log_losses["ptn"].append(pattern_loss.item())
                     log_losses["crG"].append(critic_gen_loss.item())
                     log_losses["crD"].append(critic_d_loss)
+                    log_losses["mvl"].append(mv_loss.item())
+                    log_losses["ltv"].append(ltv_loss.item())
+                    log_losses["lpm"].append(lpm_loss.item())

                 # don't do on first step
                 if self.step_num != start_step:

View File

@@ -0,0 +1,229 @@
import glob
import os
from typing import TYPE_CHECKING, Union

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file

from toolkit.losses import get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.train_tools import get_torch_dtype


class MeanReduce(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        # global mean over spatial dims (keeps channel/batch)
        return torch.mean(inputs, dim=(2, 3), keepdim=True)


class SelfAttention2d(nn.Module):
    """
    Lightweight self-attention layer (SAGAN-style) that keeps spatial
    resolution unchanged. Adds minimal params / compute but improves
    long-range modelling, helpful for variable-sized inputs.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.query = nn.Conv1d(in_channels, in_channels // 8, 1)
        self.key = nn.Conv1d(in_channels, in_channels // 8, 1)
        self.value = nn.Conv1d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        flat = x.view(B, C, H * W)                 # (B, C, N)
        q = self.query(flat).permute(0, 2, 1)      # (B, N, C//8)
        k = self.key(flat)                         # (B, C//8, N)
        attn = torch.bmm(q, k)                     # (B, N, N)
        attn = attn.softmax(dim=-1)                # softmax along last dim
        v = self.value(flat)                       # (B, C, N)
        out = torch.bmm(v, attn.permute(0, 2, 1))  # (B, C, N)
        out = out.view(B, C, H, W)                 # restore spatial dims
        return self.gamma * out + x                # residual
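# Hypothetical shape check (not part of the commit): gamma is initialised to
# zero, so the block starts as an identity and attention blends in as gamma
# is learned.
#   attn = SelfAttention2d(256)
#   x = torch.randn(2, 256, 16, 16)
#   assert torch.equal(attn(x), x)  # gamma == 0 -> pure residual passthrough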
class CriticModel(nn.Module):
    def __init__(self, base_channels: int = 64):
        super().__init__()

        def sn_conv(in_c, out_c, k, s, p):
            return nn.utils.spectral_norm(
                nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p)
            )

        layers = [
            # initial down-sample
            sn_conv(3, base_channels, 3, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        in_c = base_channels
        # progressive downsamples ×3 (64→128→256→512)
        for _ in range(3):
            out_c = min(in_c * 2, 1024)
            layers += [
                sn_conv(in_c, out_c, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            # single attention block after reaching 256 channels
            if out_c == 256:
                layers += [SelfAttention2d(out_c)]
            in_c = out_c
        # extra depth (keeps spatial size)
        layers += [
            sn_conv(in_c, 1024, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            # final 1-channel prediction map
            sn_conv(1024, 1, 3, 1, 1),
            MeanReduce(),  # → (B, 1, 1, 1)
            nn.Flatten(),  # → (B, 1)
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, inputs):
        # force full-precision inside AMP ctx for stability
        with torch.cuda.amp.autocast(False):
            return self.main(inputs.float())
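# Usage sketch (illustrative): MeanReduce averages the prediction map over
# space, so the critic is size-agnostic and always returns one score per image.
#   critic = CriticModel()
#   critic(torch.randn(4, 3, 256, 256)).shape  # -> torch.Size([4, 1])
#   critic(torch.randn(4, 3, 512, 512)).shape  # -> torch.Size([4, 1]) as well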
if TYPE_CHECKING:
    from jobs.process.TrainVAEProcess import TrainVAEProcess
    from jobs.process.TrainESRGANProcess import TrainESRGANProcess


class Critic:
    process: Union['TrainVAEProcess', 'TrainESRGANProcess']

    def __init__(
            self,
            learning_rate=1e-5,
            device='cpu',
            optimizer='adam',
            num_critic_per_gen=1,
            dtype='float32',
            lambda_gp=10,
            start_step=0,
            warmup_steps=1000,
            process=None,
            optimizer_params=None,
    ):
        self.learning_rate = learning_rate
        self.device = device
        self.optimizer_type = optimizer
        self.num_critic_per_gen = num_critic_per_gen
        self.dtype = dtype
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.process = process
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.warmup_steps = warmup_steps
        self.start_step = start_step
        self.lambda_gp = lambda_gp
        if optimizer_params is None:
            optimizer_params = {}
        self.optimizer_params = optimizer_params
        self.print = self.process.print
        print(f" Critic config: {self.__dict__}")

    def setup(self):
        self.model = CriticModel().to(self.device)
        self.load_weights()
        self.model.train()
        self.model.requires_grad_(True)
        params = self.model.parameters()
        self.optimizer = get_optimizer(
            params,
            self.optimizer_type,
            self.learning_rate,
            optimizer_params=self.optimizer_params,
        )
        self.scheduler = torch.optim.lr_scheduler.ConstantLR(
            self.optimizer,
            total_iters=self.process.max_steps * self.num_critic_per_gen,
            factor=1,
            verbose=False,
        )

    def load_weights(self):
        path_to_load = None
        self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
        files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
        if files:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
        else:
            self.print(" - No checkpoint found, starting from scratch")
        if path_to_load:
            self.model.load_state_dict(load_file(path_to_load))

    def save(self, step=None):
        self.process.update_training_metadata()
        save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
        step_num = f"_{str(step).zfill(9)}" if step is not None else ''
        save_path = os.path.join(
            self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors"
        )
        save_file(self.model.state_dict(), save_path, save_meta)
        self.print(f"Saved critic to {save_path}")

    def get_critic_loss(self, vgg_output):
        # (caller still passes combined [pred|target] images)
        if self.start_step > self.process.step_num:
            return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
        warmup_scaler = 1.0
        if self.process.step_num < self.start_step + self.warmup_steps:
            warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
        self.model.eval()
        self.model.requires_grad_(False)
        vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0)
        stacked_output = self.model(vgg_pred)
        return (-torch.mean(stacked_output)) * warmup_scaler
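# Warmup arithmetic (illustrative numbers): with start_step=0 and
# warmup_steps=1000, the generator-side critic loss is scaled by 0.25 at
# step 250 and reaches full strength from step 1000 onward.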
    def step(self, vgg_output):
        self.model.train()
        self.model.requires_grad_(True)
        self.optimizer.zero_grad()
        critic_losses = []

        inputs = vgg_output.detach().to(self.device, dtype=torch.float32)
        vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

        stacked_output = self.model(inputs).float()
        out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

        # hinge loss + gradient penalty
        loss_real = torch.relu(1.0 - out_target).mean()
        loss_fake = torch.relu(1.0 + out_pred).mean()
        gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
        critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty

        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.scheduler.step()
        critic_losses.append(critic_loss.item())
        return float(np.mean(critic_losses))

    def get_lr(self):
        if self.optimizer_type.startswith('dadaptation'):
            return (
                self.optimizer.param_groups[0]["d"]
                * self.optimizer.param_groups[0]["lr"]
            )
        return self.optimizer.param_groups[0]["lr"]

View File

@@ -33,11 +33,20 @@ class Vgg19Critic(nn.Module):
         super(Vgg19Critic, self).__init__()
         self.main = nn.Sequential(
             # input (bs, 512, 32, 32)
-            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(  # SN keeps D's scale in check
+                nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             nn.LeakyReLU(0.2),  # (bs, 512, 16, 16)
-            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(
+                nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             nn.LeakyReLU(0.2),  # (bs, 512, 8, 8)
-            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
+            nn.utils.spectral_norm(
+                nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
+            ),
             # (bs, 1, 4, 4)
             MeanReduce(),  # (bs, 1, 1, 1)
             nn.Flatten(),  # (bs, 1)
@@ -47,7 +56,9 @@ class Vgg19Critic(nn.Module):
         )

     def forward(self, inputs):
-        return self.main(inputs)
+        # return self.main(inputs)
+        with torch.cuda.amp.autocast(False):
+            return self.main(inputs.float())


 if TYPE_CHECKING:
@@ -92,7 +103,7 @@ class Critic:
         print(f" Critic config: {self.__dict__}")

     def setup(self):
-        self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
+        self.model = Vgg19Critic().to(self.device)
         self.load_weights()
         self.model.train()
         self.model.requires_grad_(True)
@@ -142,7 +153,8 @@ class Critic:
         # set model to not train for generator loss
         self.model.eval()
         self.model.requires_grad_(False)
-        vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+        # vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
+        vgg_pred, vgg_target = torch.chunk(vgg_output.float(), 2, dim=0)

         # run model
         stacked_output = self.model(vgg_pred)
@@ -157,20 +169,34 @@ class Critic:
         self.optimizer.zero_grad()
         critic_losses = []

-        inputs = vgg_output.detach()
-        inputs = inputs.to(self.device, dtype=self.torch_dtype)
+        # inputs = vgg_output.detach()
+        # inputs = inputs.to(self.device, dtype=self.torch_dtype)
+        inputs = vgg_output.detach().to(self.device, dtype=torch.float32)

         self.optimizer.zero_grad()
         vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)

+        # stacked_output = self.model(inputs).float()
+        # out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
+        # # Compute gradient penalty
+        # gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
+        # # Compute WGAN-GP critic loss
+        # critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty

         stacked_output = self.model(inputs).float()
         out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)

-        # Compute gradient penalty
+        # ── hinge loss ──
+        loss_real = torch.relu(1.0 - out_target).mean()
+        loss_fake = torch.relu(1.0 + out_pred).mean()

+        # gradient penalty (unchanged helper)
         gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)

-        # Compute WGAN-GP critic loss
-        critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
+        critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty

         critic_loss.backward()
         torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
         self.optimizer.step()
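The critic objective thus moves from the WGAN form -(E[D(real)] - E[D(fake)]) to a hinge loss, keeping the gradient penalty term. Isolated for clarity (a sketch, not code from the repo; out_real/out_fake are raw critic scores):

import torch

def hinge_d_loss(out_real: torch.Tensor, out_fake: torch.Tensor) -> torch.Tensor:
    # real scores are pushed above +1 and fake scores below -1; scores already
    # beyond the margin contribute zero loss and zero gradient
    return torch.relu(1.0 - out_real).mean() + torch.relu(1.0 + out_fake).mean()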