mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Work on vae trainer
This commit is contained in:
@@ -7,6 +7,7 @@ from collections import OrderedDict
|
||||
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from einops import rearrange
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.utils.data import DataLoader, ConcatDataset
|
||||
import torch
|
||||
@@ -17,18 +18,22 @@ from jobs.process import BaseTrainProcess
|
||||
from toolkit.image_utils import show_tensors
|
||||
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
|
||||
from toolkit.data_loader import ImageDataset
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
|
||||
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.style import get_style_model_and_losses
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
from diffusers import AutoencoderKL
|
||||
from tqdm import tqdm
|
||||
import math
|
||||
import torchvision.utils
|
||||
import time
|
||||
import numpy as np
|
||||
from .models.vgg19_critic import Critic
|
||||
from .models.critic import Critic
|
||||
from torchvision.transforms import Resize
|
||||
import lpips
|
||||
import random
|
||||
import traceback
|
||||
|
||||
IMAGE_TRANSFORMS = transforms.Compose(
|
||||
[
|
||||
@@ -42,13 +47,21 @@ def unnormalize(tensor):
|
||||
return (tensor / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
|
||||
def channel_dropout(x, p=0.5):
|
||||
keep_prob = 1 - p
|
||||
mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
|
||||
mask = mask / keep_prob # scale
|
||||
return x * mask
|
||||
|
||||
|
||||
class TrainVAEProcess(BaseTrainProcess):
|
||||
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||
super().__init__(process_id, job, config)
|
||||
self.data_loader = None
|
||||
self.vae = None
|
||||
self.device = self.get_conf('device', self.job.device)
|
||||
self.vae_path = self.get_conf('vae_path', required=True)
|
||||
self.vae_path = self.get_conf('vae_path', None)
|
||||
self.eq_vae = self.get_conf('eq_vae', False)
|
||||
self.datasets_objects = self.get_conf('datasets', required=True)
|
||||
self.batch_size = self.get_conf('batch_size', 1, as_type=int)
|
||||
self.resolution = self.get_conf('resolution', 256, as_type=int)
|
||||
@@ -65,11 +78,24 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.content_weight = self.get_conf('content_weight', 0, as_type=float)
|
||||
self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
|
||||
self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
|
||||
self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
|
||||
self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
|
||||
self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
|
||||
self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
|
||||
self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float) # latent pixel matching
|
||||
self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
|
||||
self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
|
||||
self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
|
||||
self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
|
||||
self.optimizer_params = self.get_conf('optimizer_params', {})
|
||||
self.vae_config = self.get_conf('vae_config', None)
|
||||
self.dropout = self.get_conf('dropout', 0.0, as_type=float)
|
||||
self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
|
||||
|
||||
if not self.train_encoder:
|
||||
# remove losses that only target encoder
|
||||
self.kld_weight = 0
|
||||
self.mv_loss_weight = 0
|
||||
self.ltv_weight = 0
|
||||
self.lpm_weight = 0
|
||||
|
||||
self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
@@ -142,7 +168,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
concatenated_dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=6
|
||||
num_workers=8
|
||||
)
|
||||
|
||||
def remove_oldest_checkpoint(self):
|
||||
@@ -153,6 +179,13 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
for folder in folders[:-max_to_keep]:
|
||||
print(f"Removing {folder}")
|
||||
shutil.rmtree(folder)
|
||||
# also handle CRITIC_vae_42_000000500.safetensors format for critic
|
||||
critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors"))
|
||||
if len(critic_files) > max_to_keep:
|
||||
critic_files.sort(key=os.path.getmtime)
|
||||
for file in critic_files[:-max_to_keep]:
|
||||
print(f"Removing {file}")
|
||||
os.remove(file)
|
||||
|
||||
def setup_vgg19(self):
|
||||
if self.vgg_19 is None:
|
||||
@@ -218,6 +251,62 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_mean_variance_loss(self, latents: torch.Tensor):
|
||||
if self.mv_loss_weight > 0:
|
||||
# collapse rows into channels
|
||||
latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
|
||||
mean_col = latents_col.mean(dim=(2, 3), keepdim=True)
|
||||
std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False)
|
||||
mean_loss_col = (mean_col ** 2).mean()
|
||||
std_loss_col = ((std_col - 1) ** 2).mean()
|
||||
|
||||
# collapse columns into channels
|
||||
latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
|
||||
mean_row = latents_row.mean(dim=(2, 3), keepdim=True)
|
||||
std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False)
|
||||
mean_loss_row = (mean_row ** 2).mean()
|
||||
std_loss_row = ((std_row - 1) ** 2).mean()
|
||||
|
||||
# do a global one
|
||||
|
||||
mean = latents.mean(dim=(2, 3), keepdim=True)
|
||||
std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
|
||||
mean_loss_global = (mean ** 2).mean()
|
||||
std_loss_global = ((std - 1) ** 2).mean()
|
||||
|
||||
return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_ltv_loss(self, latent):
|
||||
# loss to reduce the latent space variance
|
||||
if self.ltv_weight > 0:
|
||||
return total_variation(latent).mean()
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_latent_pixel_matching_loss(self, latent, pixels):
|
||||
if self.lpm_weight > 0:
|
||||
with torch.no_grad():
|
||||
pixels = pixels.to(latent.device, dtype=latent.dtype)
|
||||
# resize down to latent size
|
||||
pixels = torch.nn.functional.interpolate(pixels, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
|
||||
|
||||
# mean the color channel and then expand to latent size
|
||||
pixels = pixels.mean(dim=1, keepdim=True)
|
||||
pixels = pixels.repeat(1, latent.shape[1], 1, 1)
|
||||
# match the mean std of latent
|
||||
latent_mean = latent.mean(dim=(2, 3), keepdim=True)
|
||||
latent_std = latent.std(dim=(2, 3), keepdim=True)
|
||||
pixels_mean = pixels.mean(dim=(2, 3), keepdim=True)
|
||||
pixels_std = pixels.std(dim=(2, 3), keepdim=True)
|
||||
pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean
|
||||
|
||||
return torch.nn.functional.mse_loss(latent.float(), pixels.float())
|
||||
|
||||
else:
|
||||
return torch.tensor(0.0, device=self.device)
|
||||
|
||||
def get_tv_loss(self, pred, target):
|
||||
if self.tv_weight > 0:
|
||||
get_tv_loss = ComparativeTotalVariation()
|
||||
@@ -277,7 +366,39 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
input_img = img
|
||||
img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
|
||||
img = img
|
||||
decoded = self.vae(img).sample
|
||||
latent = self.vae.encode(img).latent_dist.sample()
|
||||
|
||||
latent_img = latent.clone()
|
||||
bs, ch, h, w = latent_img.shape
|
||||
grid_size = math.ceil(math.sqrt(ch))
|
||||
pad = grid_size * grid_size - ch
|
||||
|
||||
# take first item in batch
|
||||
latent_img = latent_img[0] # shape: (ch, h, w)
|
||||
|
||||
if pad > 0:
|
||||
padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device)
|
||||
latent_img = torch.cat([latent_img, padding], dim=0)
|
||||
|
||||
# make grid
|
||||
new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device)
|
||||
for x in range(grid_size):
|
||||
for y in range(grid_size):
|
||||
if x * grid_size + y < ch:
|
||||
new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y]
|
||||
latent_img = new_img
|
||||
# make rgb
|
||||
latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0)
|
||||
latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# resize to 256x256
|
||||
latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
|
||||
latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
|
||||
latent_img = (latent_img * 255).astype(np.uint8)
|
||||
# convert to pillow image
|
||||
latent_img = Image.fromarray(latent_img)
|
||||
|
||||
decoded = self.vae.decode(latent).sample
|
||||
decoded = (decoded / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
||||
decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
|
||||
@@ -289,9 +410,10 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
input_img = input_img.resize((self.resolution, self.resolution))
|
||||
decoded = decoded.resize((self.resolution, self.resolution))
|
||||
|
||||
output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
|
||||
output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
|
||||
output_img.paste(input_img, (0, 0))
|
||||
output_img.paste(decoded, (self.resolution, 0))
|
||||
output_img.paste(latent_img, (self.resolution * 2, 0))
|
||||
|
||||
scale_up = 2
|
||||
if output_img.height <= 300:
|
||||
@@ -326,12 +448,20 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
self.print(f"Loading VAE")
|
||||
self.print(f" - Loading VAE: {path_to_load}")
|
||||
if self.vae is None:
|
||||
self.vae = AutoencoderKL.from_pretrained(path_to_load)
|
||||
if path_to_load is not None:
|
||||
self.vae = AutoencoderKL.from_pretrained(path_to_load)
|
||||
elif self.vae_config is not None:
|
||||
self.vae = AutoencoderKL(**self.vae_config)
|
||||
else:
|
||||
raise ValueError('vae_path or ae_config must be specified')
|
||||
|
||||
# set decoder to train
|
||||
self.vae.to(self.device, dtype=self.torch_dtype)
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae.eval()
|
||||
if self.eq_vae:
|
||||
self.vae.encoder.train()
|
||||
else:
|
||||
self.vae.requires_grad_(False)
|
||||
self.vae.eval()
|
||||
self.vae.decoder.train()
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
|
||||
|
||||
@@ -374,6 +504,10 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
if train_all:
|
||||
params = list(self.vae.decoder.parameters())
|
||||
self.vae.decoder.requires_grad_(True)
|
||||
if self.train_encoder:
|
||||
# encoder
|
||||
params += list(self.vae.encoder.parameters())
|
||||
self.vae.encoder.requires_grad_(True)
|
||||
else:
|
||||
# mid_block
|
||||
if train_all or 'mid_block' in self.blocks_to_train:
|
||||
@@ -388,12 +522,13 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
params += list(self.vae.decoder.conv_out.parameters())
|
||||
self.vae.decoder.conv_out.requires_grad_(True)
|
||||
|
||||
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
|
||||
if self.style_weight > 0 or self.content_weight > 0:
|
||||
self.setup_vgg19()
|
||||
self.vgg_19.requires_grad_(False)
|
||||
# self.vgg_19.requires_grad_(False)
|
||||
self.vgg_19.eval()
|
||||
if self.use_critic:
|
||||
self.critic.setup()
|
||||
|
||||
if self.use_critic:
|
||||
self.critic.setup()
|
||||
|
||||
if self.lpips_weight > 0 and self.lpips_loss is None:
|
||||
# self.lpips_loss = lpips.LPIPS(net='vgg')
|
||||
@@ -426,6 +561,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
"style": [],
|
||||
"content": [],
|
||||
"mse": [],
|
||||
"mvl": [],
|
||||
"ltv": [],
|
||||
"lpm": [],
|
||||
"kl": [],
|
||||
"tv": [],
|
||||
"ptn": [],
|
||||
@@ -451,27 +589,83 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
|
||||
|
||||
# forward pass
|
||||
# grad only if eq_vae
|
||||
with torch.set_grad_enabled(self.train_encoder):
|
||||
dgd = self.vae.encode(batch).latent_dist
|
||||
mu, logvar = dgd.mean, dgd.logvar
|
||||
latents = dgd.sample()
|
||||
latents.detach().requires_grad_(True)
|
||||
|
||||
if self.eq_vae:
|
||||
# process flips, rotate, scale
|
||||
latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
|
||||
batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
|
||||
out_chunks = []
|
||||
for i in range(len(latent_chunks)):
|
||||
try:
|
||||
do_rotate = random.randint(0, 3)
|
||||
do_flip_x = random.randint(0, 1)
|
||||
do_flip_y = random.randint(0, 1)
|
||||
do_scale = random.randint(0, 1)
|
||||
if do_rotate > 0:
|
||||
latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3))
|
||||
batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3))
|
||||
if do_flip_x > 0:
|
||||
latent_chunks[i] = torch.flip(latent_chunks[i], [2])
|
||||
batch_chunks[i] = torch.flip(batch_chunks[i], [2])
|
||||
if do_flip_y > 0:
|
||||
latent_chunks[i] = torch.flip(latent_chunks[i], [3])
|
||||
batch_chunks[i] = torch.flip(batch_chunks[i], [3])
|
||||
# if do_scale > 0:
|
||||
# scale = 2
|
||||
# start_latent_h = latent_chunks[i].shape[2]
|
||||
# start_latent_w = latent_chunks[i].shape[3]
|
||||
# start_batch_h = batch_chunks[i].shape[2]
|
||||
# start_batch_w = batch_chunks[i].shape[3]
|
||||
# latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
|
||||
# batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
|
||||
# # random crop. latent is smaller than match but crops need to match
|
||||
# latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h)
|
||||
# latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w)
|
||||
# batch_x = latent_x * self.vae_scale_factor
|
||||
# batch_y = latent_y * self.vae_scale_factor
|
||||
|
||||
# # crop
|
||||
# latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w]
|
||||
# batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w]
|
||||
except Exception as e:
|
||||
print(f"Error processing image {i}: {e}")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
out_chunks.append(latent_chunks[i])
|
||||
latents = torch.cat(out_chunks, dim=0)
|
||||
# do dropout
|
||||
if self.dropout > 0:
|
||||
forward_latents = channel_dropout(latents, self.dropout)
|
||||
else:
|
||||
forward_latents = latents
|
||||
batch = torch.cat(batch_chunks, dim=0)
|
||||
|
||||
else:
|
||||
latents.detach().requires_grad_(True)
|
||||
forward_latents = latents
|
||||
|
||||
forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
|
||||
|
||||
if not self.train_encoder:
|
||||
# detach latents if not training encoder
|
||||
forward_latents = forward_latents.detach()
|
||||
|
||||
pred = self.vae.decode(latents).sample
|
||||
|
||||
with torch.no_grad():
|
||||
show_tensors(
|
||||
pred.clamp(-1, 1).clone(),
|
||||
"combined tensor"
|
||||
)
|
||||
pred = self.vae.decode(forward_latents).sample
|
||||
|
||||
# Run through VGG19
|
||||
if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
|
||||
if self.style_weight > 0 or self.content_weight > 0:
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
stacked = (stacked / 2 + 0.5).clamp(0, 1)
|
||||
self.vgg_19(stacked)
|
||||
|
||||
if self.use_critic:
|
||||
critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
critic_d_loss = self.critic.step(stacked.detach())
|
||||
else:
|
||||
critic_d_loss = 0.0
|
||||
|
||||
@@ -489,7 +683,8 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
|
||||
pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
|
||||
if self.use_critic:
|
||||
critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
|
||||
stacked = torch.cat([pred, batch], dim=0)
|
||||
critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight
|
||||
|
||||
# do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
|
||||
if self.lpips_weight > 0:
|
||||
@@ -502,8 +697,42 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
critic_gen_loss *= crit_g_scaler
|
||||
else:
|
||||
critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
if self.mv_loss_weight > 0:
|
||||
mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight
|
||||
else:
|
||||
mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
if self.ltv_weight > 0:
|
||||
ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
|
||||
else:
|
||||
ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
if self.lpm_weight > 0:
|
||||
lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight
|
||||
else:
|
||||
lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss
|
||||
loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
|
||||
|
||||
# check if loss is NaN or Inf
|
||||
if torch.isnan(loss) or torch.isinf(loss):
|
||||
self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}")
|
||||
self.print(f" - Style loss: {style_loss.item()}")
|
||||
self.print(f" - Content loss: {content_loss.item()}")
|
||||
self.print(f" - KLD loss: {kld_loss.item()}")
|
||||
self.print(f" - MSE loss: {mse_loss.item()}")
|
||||
self.print(f" - LPIPS loss: {lpips_loss.item()}")
|
||||
self.print(f" - TV loss: {tv_loss.item()}")
|
||||
self.print(f" - Pattern loss: {pattern_loss.item()}")
|
||||
self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
|
||||
self.print(f" - Critic D loss: {critic_d_loss}")
|
||||
self.print(f" - Mean variance loss: {mv_loss.item()}")
|
||||
self.print(f" - Latent TV loss: {ltv_loss.item()}")
|
||||
self.print(f" - Latent pixel matching loss: {lpm_loss.item()}")
|
||||
self.print(f" - Total loss: {loss.item()}")
|
||||
self.print(f" - Stopping training")
|
||||
exit(1)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
@@ -533,8 +762,17 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
loss_string += f" crG: {critic_gen_loss.item():.2e}"
|
||||
if self.use_critic:
|
||||
loss_string += f" crD: {critic_d_loss:.2e}"
|
||||
if self.mv_loss_weight > 0:
|
||||
loss_string += f" mvl: {mv_loss:.2e}"
|
||||
if self.ltv_weight > 0:
|
||||
loss_string += f" ltv: {ltv_loss:.2e}"
|
||||
if self.lpm_weight > 0:
|
||||
loss_string += f" lpm: {lpm_loss:.2e}"
|
||||
|
||||
|
||||
if self.optimizer_type.startswith('dadaptation') or \
|
||||
if hasattr(optimizer, 'get_avg_learning_rate'):
|
||||
learning_rate = optimizer.get_avg_learning_rate()
|
||||
elif self.optimizer_type.startswith('dadaptation') or \
|
||||
self.optimizer_type.lower().startswith('prodigy'):
|
||||
learning_rate = (
|
||||
optimizer.param_groups[0]["d"] *
|
||||
@@ -562,6 +800,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
epoch_losses["ptn"].append(pattern_loss.item())
|
||||
epoch_losses["crG"].append(critic_gen_loss.item())
|
||||
epoch_losses["crD"].append(critic_d_loss)
|
||||
epoch_losses["mvl"].append(mv_loss.item())
|
||||
epoch_losses["ltv"].append(ltv_loss.item())
|
||||
epoch_losses["lpm"].append(lpm_loss.item())
|
||||
|
||||
log_losses["total"].append(loss_value)
|
||||
log_losses["lpips"].append(lpips_loss.item())
|
||||
@@ -573,6 +814,9 @@ class TrainVAEProcess(BaseTrainProcess):
|
||||
log_losses["ptn"].append(pattern_loss.item())
|
||||
log_losses["crG"].append(critic_gen_loss.item())
|
||||
log_losses["crD"].append(critic_d_loss)
|
||||
log_losses["mvl"].append(mv_loss.item())
|
||||
log_losses["ltv"].append(ltv_loss.item())
|
||||
log_losses["lpm"].append(lpm_loss.item())
|
||||
|
||||
# don't do on first step
|
||||
if self.step_num != start_step:
|
||||
|
||||
229
jobs/process/models/critic.py
Normal file
229
jobs/process/models/critic.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from toolkit.losses import get_gradient_penalty
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
|
||||
class MeanReduce(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, inputs):
|
||||
# global mean over spatial dims (keeps channel/batch)
|
||||
return torch.mean(inputs, dim=(2, 3), keepdim=True)
|
||||
|
||||
|
||||
class SelfAttention2d(nn.Module):
|
||||
"""
|
||||
Lightweight self-attention layer (SAGAN-style) that keeps spatial
|
||||
resolution unchanged. Adds minimal params / compute but improves
|
||||
long-range modelling – helpful for variable-sized inputs.
|
||||
"""
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.query = nn.Conv1d(in_channels, in_channels // 8, 1)
|
||||
self.key = nn.Conv1d(in_channels, in_channels // 8, 1)
|
||||
self.value = nn.Conv1d(in_channels, in_channels, 1)
|
||||
self.gamma = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
flat = x.view(B, C, H * W) # (B,C,N)
|
||||
q = self.query(flat).permute(0, 2, 1) # (B,N,C//8)
|
||||
k = self.key(flat) # (B,C//8,N)
|
||||
attn = torch.bmm(q, k) # (B,N,N)
|
||||
attn = attn.softmax(dim=-1) # softmax along last dim
|
||||
v = self.value(flat) # (B,C,N)
|
||||
out = torch.bmm(v, attn.permute(0, 2, 1)) # (B,C,N)
|
||||
out = out.view(B, C, H, W) # restore spatial dims
|
||||
return self.gamma * out + x # residual
|
||||
|
||||
|
||||
class CriticModel(nn.Module):
|
||||
def __init__(self, base_channels: int = 64):
|
||||
super().__init__()
|
||||
|
||||
def sn_conv(in_c, out_c, k, s, p):
|
||||
return nn.utils.spectral_norm(
|
||||
nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p)
|
||||
)
|
||||
|
||||
layers = [
|
||||
# initial down-sample
|
||||
sn_conv(3, base_channels, 3, 2, 1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
|
||||
in_c = base_channels
|
||||
# progressive downsamples ×3 (64→128→256→512)
|
||||
for _ in range(3):
|
||||
out_c = min(in_c * 2, 1024)
|
||||
layers += [
|
||||
sn_conv(in_c, out_c, 3, 2, 1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
# single attention block after reaching 256 channels
|
||||
if out_c == 256:
|
||||
layers += [SelfAttention2d(out_c)]
|
||||
in_c = out_c
|
||||
|
||||
# extra depth (keeps spatial size)
|
||||
layers += [
|
||||
sn_conv(in_c, 1024, 3, 1, 1),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
|
||||
# final 1-channel prediction map
|
||||
sn_conv(1024, 1, 3, 1, 1),
|
||||
MeanReduce(), # → (B,1,1,1)
|
||||
nn.Flatten(), # → (B,1)
|
||||
]
|
||||
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, inputs):
|
||||
# force full-precision inside AMP ctx for stability
|
||||
with torch.cuda.amp.autocast(False):
|
||||
return self.main(inputs.float())
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jobs.process.TrainVAEProcess import TrainVAEProcess
|
||||
from jobs.process.TrainESRGANProcess import TrainESRGANProcess
|
||||
|
||||
|
||||
class Critic:
|
||||
process: Union['TrainVAEProcess', 'TrainESRGANProcess']
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate=1e-5,
|
||||
device='cpu',
|
||||
optimizer='adam',
|
||||
num_critic_per_gen=1,
|
||||
dtype='float32',
|
||||
lambda_gp=10,
|
||||
start_step=0,
|
||||
warmup_steps=1000,
|
||||
process=None,
|
||||
optimizer_params=None,
|
||||
):
|
||||
self.learning_rate = learning_rate
|
||||
self.device = device
|
||||
self.optimizer_type = optimizer
|
||||
self.num_critic_per_gen = num_critic_per_gen
|
||||
self.dtype = dtype
|
||||
self.torch_dtype = get_torch_dtype(self.dtype)
|
||||
self.process = process
|
||||
self.model = None
|
||||
self.optimizer = None
|
||||
self.scheduler = None
|
||||
self.warmup_steps = warmup_steps
|
||||
self.start_step = start_step
|
||||
self.lambda_gp = lambda_gp
|
||||
|
||||
if optimizer_params is None:
|
||||
optimizer_params = {}
|
||||
self.optimizer_params = optimizer_params
|
||||
self.print = self.process.print
|
||||
print(f" Critic config: {self.__dict__}")
|
||||
|
||||
def setup(self):
|
||||
self.model = CriticModel().to(self.device)
|
||||
self.load_weights()
|
||||
self.model.train()
|
||||
self.model.requires_grad_(True)
|
||||
params = self.model.parameters()
|
||||
self.optimizer = get_optimizer(
|
||||
params,
|
||||
self.optimizer_type,
|
||||
self.learning_rate,
|
||||
optimizer_params=self.optimizer_params,
|
||||
)
|
||||
self.scheduler = torch.optim.lr_scheduler.ConstantLR(
|
||||
self.optimizer,
|
||||
total_iters=self.process.max_steps * self.num_critic_per_gen,
|
||||
factor=1,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
def load_weights(self):
|
||||
path_to_load = None
|
||||
self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}")
|
||||
files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors"))
|
||||
if files:
|
||||
latest_file = max(files, key=os.path.getmtime)
|
||||
print(f" - Latest checkpoint is: {latest_file}")
|
||||
path_to_load = latest_file
|
||||
else:
|
||||
self.print(" - No checkpoint found, starting from scratch")
|
||||
if path_to_load:
|
||||
self.model.load_state_dict(load_file(path_to_load))
|
||||
|
||||
def save(self, step=None):
|
||||
self.process.update_training_metadata()
|
||||
save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name)
|
||||
step_num = f"_{str(step).zfill(9)}" if step is not None else ''
|
||||
save_path = os.path.join(
|
||||
self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors"
|
||||
)
|
||||
save_file(self.model.state_dict(), save_path, save_meta)
|
||||
self.print(f"Saved critic to {save_path}")
|
||||
|
||||
def get_critic_loss(self, vgg_output):
|
||||
# (caller still passes combined [pred|target] images)
|
||||
if self.start_step > self.process.step_num:
|
||||
return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
warmup_scaler = 1.0
|
||||
if self.process.step_num < self.start_step + self.warmup_steps:
|
||||
warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps
|
||||
|
||||
self.model.eval()
|
||||
self.model.requires_grad_(False)
|
||||
|
||||
vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0)
|
||||
stacked_output = self.model(vgg_pred)
|
||||
return (-torch.mean(stacked_output)) * warmup_scaler
|
||||
|
||||
def step(self, vgg_output):
|
||||
self.model.train()
|
||||
self.model.requires_grad_(True)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
critic_losses = []
|
||||
inputs = vgg_output.detach().to(self.device, dtype=torch.float32)
|
||||
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
stacked_output = self.model(inputs).float()
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
|
||||
# hinge loss + gradient penalty
|
||||
loss_real = torch.relu(1.0 - out_target).mean()
|
||||
loss_fake = torch.relu(1.0 + out_pred).mean()
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty
|
||||
|
||||
critic_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
self.scheduler.step()
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
return float(np.mean(critic_losses))
|
||||
|
||||
def get_lr(self):
|
||||
if self.optimizer_type.startswith('dadaptation'):
|
||||
return (
|
||||
self.optimizer.param_groups[0]["d"]
|
||||
* self.optimizer.param_groups[0]["lr"]
|
||||
)
|
||||
return self.optimizer.param_groups[0]["lr"]
|
||||
@@ -33,11 +33,20 @@ class Vgg19Critic(nn.Module):
|
||||
super(Vgg19Critic, self).__init__()
|
||||
self.main = nn.Sequential(
|
||||
# input (bs, 512, 32, 32)
|
||||
nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
|
||||
# nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
|
||||
nn.utils.spectral_norm( # SN keeps D’s scale in check
|
||||
nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1)
|
||||
),
|
||||
nn.LeakyReLU(0.2), # (bs, 512, 16, 16)
|
||||
nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
|
||||
# nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
|
||||
nn.utils.spectral_norm(
|
||||
nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
|
||||
),
|
||||
nn.LeakyReLU(0.2), # (bs, 512, 8, 8)
|
||||
nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
|
||||
# nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
|
||||
nn.utils.spectral_norm(
|
||||
nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1)
|
||||
),
|
||||
# (bs, 1, 4, 4)
|
||||
MeanReduce(), # (bs, 1, 1, 1)
|
||||
nn.Flatten(), # (bs, 1)
|
||||
@@ -47,7 +56,9 @@ class Vgg19Critic(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, inputs):
|
||||
return self.main(inputs)
|
||||
# return self.main(inputs)
|
||||
with torch.cuda.amp.autocast(False):
|
||||
return self.main(inputs.float())
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -92,7 +103,7 @@ class Critic:
|
||||
print(f" Critic config: {self.__dict__}")
|
||||
|
||||
def setup(self):
|
||||
self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype)
|
||||
self.model = Vgg19Critic().to(self.device)
|
||||
self.load_weights()
|
||||
self.model.train()
|
||||
self.model.requires_grad_(True)
|
||||
@@ -142,7 +153,8 @@ class Critic:
|
||||
# set model to not train for generator loss
|
||||
self.model.eval()
|
||||
self.model.requires_grad_(False)
|
||||
vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
|
||||
# vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0)
|
||||
vgg_pred, vgg_target = torch.chunk(vgg_output.float(), 2, dim=0)
|
||||
|
||||
# run model
|
||||
stacked_output = self.model(vgg_pred)
|
||||
@@ -157,20 +169,34 @@ class Critic:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
critic_losses = []
|
||||
inputs = vgg_output.detach()
|
||||
inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
# inputs = vgg_output.detach()
|
||||
# inputs = inputs.to(self.device, dtype=self.torch_dtype)
|
||||
inputs = vgg_output.detach().to(self.device, dtype=torch.float32)
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0)
|
||||
|
||||
# stacked_output = self.model(inputs).float()
|
||||
# out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
|
||||
# # Compute gradient penalty
|
||||
# gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
|
||||
# # Compute WGAN-GP critic loss
|
||||
# critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
|
||||
stacked_output = self.model(inputs).float()
|
||||
out_pred, out_target = torch.chunk(stacked_output, 2, dim=0)
|
||||
|
||||
# Compute gradient penalty
|
||||
# ── hinge loss ──
|
||||
loss_real = torch.relu(1.0 - out_target).mean()
|
||||
loss_fake = torch.relu(1.0 + out_pred).mean()
|
||||
|
||||
# gradient penalty (unchanged helper)
|
||||
gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device)
|
||||
|
||||
# Compute WGAN-GP critic loss
|
||||
critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty
|
||||
critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty
|
||||
|
||||
critic_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
|
||||
self.optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user