diff --git a/.vscode/launch.json b/.vscode/launch.json index 483703eb..02d5cacf 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -16,6 +16,22 @@ "console": "integratedTerminal", "justMyCode": false }, + { + "name": "Run current config (cuda:1)", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/run.py", + "args": [ + "${file}" + ], + "env": { + "CUDA_LAUNCH_BLOCKING": "1", + "DEBUG_TOOLKIT": "1", + "CUDA_VISIBLE_DEVICES": "1" + }, + "console": "integratedTerminal", + "justMyCode": false + }, { "name": "Python: Debug Current File", "type": "python", diff --git a/build_and_push_docker_dev b/build_and_push_docker_dev new file mode 100644 index 00000000..6a1a17d0 --- /dev/null +++ b/build_and_push_docker_dev @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +VERSION=dev +GIT_COMMIT=dev + +echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo." +echo "Building version: $VERSION and latest" +# wait 2 seconds +sleep 2 + +# Build the image with cache busting +docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile . + +# Tag with version and latest +docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION + +# Push both tags +echo "Pushing images to Docker Hub..." +docker push ostris/aitoolkit:$VERSION + +echo "Successfully built and pushed ostris/aitoolkit:$VERSION" \ No newline at end of file diff --git a/extensions_built_in/diffusion_models/chroma/chroma_model.py b/extensions_built_in/diffusion_models/chroma/chroma_model.py index 0072e2e5..3888909a 100644 --- a/extensions_built_in/diffusion_models/chroma/chroma_model.py +++ b/extensions_built_in/diffusion_models/chroma/chroma_model.py @@ -141,19 +141,38 @@ class ChromaModel(BaseModel): extras_path = 'ostris/Flex.1-alpha' self.print_and_status_update("Loading transformer") + + chroma_state_dict = load_file(model_path, 'cpu') + + # determine number of double and single blocks + double_blocks = 0 + single_blocks = 0 + for key in chroma_state_dict.keys(): + if "double_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > double_blocks: + double_blocks = block_num + elif "single_blocks" in key: + block_num = int(key.split(".")[1]) + 1 + if block_num > single_blocks: + single_blocks = block_num + print(f"Double Blocks: {double_blocks}") + print(f"Single Blocks: {single_blocks}") + chroma_params.depth = double_blocks + chroma_params.depth_single_blocks = single_blocks transformer = Chroma(chroma_params) # add dtype, not sure why it doesnt have it transformer.dtype = dtype - - chroma_state_dict = load_file(model_path, 'cpu') # load the state dict into the model transformer.load_state_dict(chroma_state_dict) transformer.to(self.quantize_device, dtype=dtype) transformer.config = FakeConfig() + transformer.config.num_layers = double_blocks + transformer.config.num_single_layers = single_blocks if self.model_config.quantize: # patch the state dict method @@ -392,6 +411,8 @@ class ChromaModel(BaseModel): return self.text_encoder[1].encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad def save_model(self, output_path, meta, save_dtype): + if not output_path.endswith(".safetensors"): + output_path = output_path + ".safetensors" # only save the unet transformer: Chroma = unwrap_model(self.model) state_dict = transformer.state_dict() diff --git a/extensions_built_in/diffusion_models/chroma/pipeline.py b/extensions_built_in/diffusion_models/chroma/pipeline.py index 52b9b817..215be798 100644 --- a/extensions_built_in/diffusion_models/chroma/pipeline.py +++ b/extensions_built_in/diffusion_models/chroma/pipeline.py @@ -61,6 +61,8 @@ class ChromaPipeline(FluxPipeline): batch_size = prompt_embeds.shape[0] device = self._execution_device + if isinstance(device, str): + device = torch.device(device) text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=torch.bfloat16) if guidance_scale > 1.00001: diff --git a/extensions_built_in/diffusion_models/chroma/src/math.py b/extensions_built_in/diffusion_models/chroma/src/math.py index b46bca57..31205341 100644 --- a/extensions_built_in/diffusion_models/chroma/src/math.py +++ b/extensions_built_in/diffusion_models/chroma/src/math.py @@ -2,14 +2,32 @@ import torch from einops import rearrange from torch import Tensor +# Flash-Attention 2 (optional) +try: + from flash_attn.flash_attn_interface import flash_attn_func # type: ignore + _HAS_FLASH = True +except (ImportError, ModuleNotFoundError): + _HAS_FLASH = False + def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor) -> Tensor: q, k = apply_rope(q, k, pe) # mask should have shape [B, H, L, D] - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) - x = rearrange(x, "B H L D -> B L (H D)") + if _HAS_FLASH and mask is None and q.is_cuda: + x = flash_attn_func( + rearrange(q, "B H L D -> B L H D").contiguous(), + rearrange(k, "B H L D -> B L H D").contiguous(), + rearrange(v, "B H L D -> B L H D").contiguous(), + dropout_p=0.0, + softmax_scale=None, + causal=False, + ) + x = rearrange(x, "B L H D -> B H L D") + else: + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask) + x = rearrange(x, "B H L D -> B L (H D)") return x diff --git a/extensions_built_in/diffusion_models/chroma/src/model.py b/extensions_built_in/diffusion_models/chroma/src/model.py index 3b6c29bb..33cdbe62 100644 --- a/extensions_built_in/diffusion_models/chroma/src/model.py +++ b/extensions_built_in/diffusion_models/chroma/src/model.py @@ -96,6 +96,7 @@ class Chroma(nn.Module): self.params = params self.in_channels = params.in_channels self.out_channels = self.in_channels + self.gradient_checkpointing = False if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -162,11 +163,14 @@ class Chroma(nn.Module): torch.tensor(list(range(self.mod_index_length)), device="cpu"), persistent=False, ) - + @property def device(self): # Get the device of the module (assumes all parameters are on the same device) return next(self.parameters()).device + + def enable_gradient_checkpointing(self, enable: bool = True): + self.gradient_checkpointing = enable def forward( self, @@ -246,8 +250,7 @@ class Chroma(nn.Module): txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"] double_mod = [img_mod, txt_mod] - # just in case in different GPU for simple pipeline parallel - if self.training: + if torch.is_grad_enabled() and self.gradient_checkpointing: img.requires_grad_(True) img, txt = ckpt.checkpoint( block, img, txt, pe, double_mod, txt_img_mask @@ -260,7 +263,7 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) for i, block in enumerate(self.single_blocks): single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] - if self.training: + if torch.is_grad_enabled() and self.gradient_checkpointing: img.requires_grad_(True) img = ckpt.checkpoint(block, img, pe, single_mod, txt_img_mask) else: diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 8a8ba738..7468e0fc 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -35,6 +35,7 @@ import math from toolkit.train_tools import precondition_model_outputs_flow_match from toolkit.models.diffusion_feature_extraction import DiffusionFeatureExtractor, load_dfe from toolkit.util.wavelet_loss import wavelet_loss +import torch.nn.functional as F def flush(): @@ -60,6 +61,7 @@ class SDTrainer(BaseSDTrainProcess): self._clip_image_embeds_unconditional: Union[List[str], None] = None self.negative_prompt_pool: Union[List[str], None] = None self.batch_negative_prompt: Union[List[str], None] = None + self.cfm_cache = None self.is_bfloat = self.train_config.dtype == "bfloat16" or self.train_config.dtype == "bf16" @@ -197,7 +199,7 @@ class SDTrainer(BaseSDTrainProcess): flush() if self.train_config.diffusion_feature_extractor_path is not None: - vae = None + vae = self.sd.vae # if not (self.model_config.arch in ["flux"]) or self.sd.vae.__class__.__name__ == "AutoencoderPixelMixer": # vae = self.sd.vae self.dfe = load_dfe(self.train_config.diffusion_feature_extractor_path, vae=vae) @@ -756,13 +758,13 @@ class SDTrainer(BaseSDTrainProcess): pass def predict_noise( - self, - noisy_latents: torch.Tensor, - timesteps: Union[int, torch.Tensor] = 1, - conditional_embeds: Union[PromptEmbeds, None] = None, - unconditional_embeds: Union[PromptEmbeds, None] = None, - batch: Optional['DataLoaderBatchDTO'] = None, - **kwargs, + self, + noisy_latents: torch.Tensor, + timesteps: Union[int, torch.Tensor] = 1, + conditional_embeds: Union[PromptEmbeds, None] = None, + unconditional_embeds: Union[PromptEmbeds, None] = None, + batch: Optional['DataLoaderBatchDTO'] = None, + **kwargs, ): dtype = get_torch_dtype(self.train_config.dtype) return self.sd.predict_noise( @@ -778,6 +780,81 @@ class SDTrainer(BaseSDTrainProcess): batch=batch, **kwargs ) + + def cfm_augment_tensors( + self, + images: torch.Tensor + ) -> torch.Tensor: + if self.cfm_cache is None: + # flip the current one. Only need this for first time + self.cfm_cache = torch.flip(images, [3]).clone() + augmented_tensor_list = [] + for i in range(images.shape[0]): + # get a random one + idx = random.randint(0, self.cfm_cache.shape[0] - 1) + augmented_tensor_list.append(self.cfm_cache[idx:idx + 1]) + augmented = torch.cat(augmented_tensor_list, dim=0) + # resize to match the input + augmented = torch.nn.functional.interpolate(augmented, size=(images.shape[2], images.shape[3]), mode='bilinear') + self.cfm_cache = images.clone() + return augmented + + def get_cfm_loss( + self, + noisy_latents: torch.Tensor, + noise: torch.Tensor, + noise_pred: torch.Tensor, + conditional_embeds: PromptEmbeds, + timesteps: torch.Tensor, + batch: 'DataLoaderBatchDTO', + alpha: float = 0.1, + ): + dtype = get_torch_dtype(self.train_config.dtype) + if hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + + elif self.sd.is_flow_matching: + # forward ODE + target = (noise - batch.latents).detach() + else: + raise ValueError("CFM loss only works with flow matching") + fm_loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") + with torch.no_grad(): + # we need to compute the contrast + cfm_batch_tensors = self.cfm_augment_tensors(batch.tensor).to(self.device_torch, dtype=dtype) + cfm_latents = self.sd.encode_images(cfm_batch_tensors).to(self.device_torch, dtype=dtype) + cfm_noisy_latents = self.sd.add_noise( + original_samples=cfm_latents, + noise=noise, + timesteps=timesteps, + ) + cfm_pred = self.predict_noise( + noisy_latents=cfm_noisy_latents, + timesteps=timesteps, + conditional_embeds=conditional_embeds, + unconditional_embeds=None, + batch=batch, + ) + + # v_neg = torch.nn.functional.normalize(cfm_pred.float(), dim=1) + # v_pos = torch.nn.functional.normalize(noise_pred.float(), dim=1) # shape: (B, C, H, W) + + # # Compute cosine similarity at each pixel + # sim = (v_pos * v_neg).sum(dim=1) # shape: (B, H, W) + + cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) + # Compute cosine similarity at each pixel + sim = cos(cfm_pred.float(), noise_pred.float()) # shape: (B, H, W) + + # Average over spatial dimensions, then batch + contrastive_loss = -sim.mean() + + loss = fm_loss.mean() + alpha * contrastive_loss + return loss def train_single_accumulation(self, batch: DataLoaderBatchDTO): self.timer.start('preprocess_batch') @@ -1431,6 +1508,44 @@ class SDTrainer(BaseSDTrainProcess): if self.adapter and isinstance(self.adapter, CustomAdapter): noisy_latents = self.adapter.condition_noisy_latents(noisy_latents, batch) + if self.train_config.timestep_type == 'next_sample': + with self.timer('next_sample_step'): + with torch.no_grad(): + + stepped_timestep_indicies = [self.sd.noise_scheduler.index_for_timestep(t) + 1 for t in timesteps] + stepped_timesteps = [self.sd.noise_scheduler.timesteps[x] for x in stepped_timestep_indicies] + stepped_timesteps = torch.stack(stepped_timesteps, dim=0) + + # do a sample at the current timestep and step it, then determine new noise + next_sample_pred = self.predict_noise( + noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), + timesteps=timesteps, + conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype), + unconditional_embeds=unconditional_embeds, + batch=batch, + **pred_kwargs + ) + stepped_latents = self.sd.step_scheduler( + next_sample_pred, + noisy_latents, + timesteps, + self.sd.noise_scheduler + ) + # stepped latents is our new noisy latents. Now we need to determine noise in the current sample + noisy_latents = stepped_latents + original_samples = batch.latents.to(self.device_torch, dtype=dtype) + # todo calc next timestep, for now this may work as it + t_01 = (stepped_timesteps / 1000).to(original_samples.device) + if len(stepped_latents.shape) == 4: + t_01 = t_01.view(-1, 1, 1, 1) + elif len(stepped_latents.shape) == 5: + t_01 = t_01.view(-1, 1, 1, 1, 1) + else: + raise ValueError("Unknown stepped latents shape", stepped_latents.shape) + next_sample_noise = (stepped_latents - (1.0 - t_01) * original_samples) / t_01 + noise = next_sample_noise + timesteps = stepped_timesteps + with self.timer('predict_unet'): noise_pred = self.predict_noise( noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype), @@ -1450,15 +1565,25 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.diff_output_preservation and not do_inverted_masked_prior: prior_to_calculate_loss = None - loss = self.calculate_loss( - noise_pred=noise_pred, - noise=noise, - noisy_latents=noisy_latents, - timesteps=timesteps, - batch=batch, - mask_multiplier=mask_multiplier, - prior_pred=prior_to_calculate_loss, - ) + if self.train_config.loss_type == 'cfm': + loss = self.get_cfm_loss( + noisy_latents=noisy_latents, + noise=noise, + noise_pred=noise_pred, + conditional_embeds=conditional_embeds, + timesteps=timesteps, + batch=batch, + ) + else: + loss = self.calculate_loss( + noise_pred=noise_pred, + noise=noise, + noisy_latents=noisy_latents, + timesteps=timesteps, + batch=batch, + mask_multiplier=mask_multiplier, + prior_pred=prior_to_calculate_loss, + ) if self.train_config.diff_output_preservation: # send the loss backwards otherwise checkpointing will fail diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 1ccb0c3d..b4f768d9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -629,7 +629,10 @@ class BaseSDTrainProcess(BaseTrainProcess): try: filename = f'optimizer.pt' file_path = os.path.join(self.save_root, filename) - state_dict = unwrap_model(self.optimizer).state_dict() + try: + state_dict = unwrap_model(self.optimizer).state_dict() + except Exception as e: + state_dict = self.optimizer.state_dict() torch.save(state_dict, file_path) print_acc(f"Saved optimizer to {file_path}") except Exception as e: @@ -931,16 +934,16 @@ class BaseSDTrainProcess(BaseTrainProcess): noise_offset=self.train_config.noise_offset, ).to(self.device_torch, dtype=dtype) - if self.train_config.random_noise_shift > 0.0: - # get random noise -1 to 1 - noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, - dtype=noise.dtype) * 2 - 1 + # if self.train_config.random_noise_shift > 0.0: + # # get random noise -1 to 1 + # noise_shift = torch.rand((noise.shape[0], noise.shape[1], 1, 1), device=noise.device, + # dtype=noise.dtype) * 2 - 1 - # multiply by shift amount - noise_shift *= self.train_config.random_noise_shift + # # multiply by shift amount + # noise_shift *= self.train_config.random_noise_shift - # add to noise - noise += noise_shift + # # add to noise + # noise += noise_shift if self.train_config.blended_blur_noise: noise = get_blended_blur_noise( @@ -1011,6 +1014,7 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) imgs = None is_reg = any(batch.get_is_reg_list()) + cfm_batch = None if batch.tensor is not None: imgs = batch.tensor imgs = imgs.to(self.device_torch, dtype=dtype) @@ -1118,8 +1122,13 @@ class BaseSDTrainProcess(BaseTrainProcess): if timestep_type is None: timestep_type = self.train_config.timestep_type + if self.train_config.timestep_type == 'next_sample': + # simulate a sample + num_train_timesteps = self.train_config.next_sample_timesteps + timestep_type = 'shift' + patch_size = 1 - if self.sd.is_flux: + if self.sd.is_flux or 'flex' in self.sd.arch: # flux is a patch size of 1, but latents are divided by 2, so we need to double it patch_size = 2 elif hasattr(self.sd.unet.config, 'patch_size'): @@ -1142,7 +1151,15 @@ class BaseSDTrainProcess(BaseTrainProcess): content_or_style = self.train_config.content_or_style_reg # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content': - if content_or_style in ['style', 'content']: + if self.train_config.timestep_type == 'next_sample': + timestep_indices = torch.randint( + 0, + num_train_timesteps - 2, # -1 for 0 idx, -1 so we can step + (batch_size,), + device=self.device_torch + ) + timestep_indices = timestep_indices.long() + elif content_or_style in ['style', 'content']: # this is from diffusers training code # Cubic sampling for favoring later or earlier timesteps # For more details about why cubic sampling is used for content / structure, @@ -1169,7 +1186,7 @@ class BaseSDTrainProcess(BaseTrainProcess): min_noise_steps + 1, max_noise_steps - 1 ) - + elif content_or_style == 'balanced': if min_noise_steps == max_noise_steps: timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps @@ -1185,16 +1202,6 @@ class BaseSDTrainProcess(BaseTrainProcess): else: raise ValueError(f"Unknown content_or_style {content_or_style}") - # do flow matching - # if self.sd.is_flow_matching: - # u = compute_density_for_timestep_sampling( - # weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"] - # batch_size=batch_size, - # logit_mean=0.0, - # logit_std=1.0, - # mode_scale=1.29, - # ) - # timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long() # convert the timestep_indices to a timestep timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices] timesteps = torch.stack(timesteps, dim=0) @@ -1218,8 +1225,32 @@ class BaseSDTrainProcess(BaseTrainProcess): latents = unaugmented_latents noise_multiplier = self.train_config.noise_multiplier + + s = (noise.shape[0], noise.shape[1], 1, 1) + if len(noise.shape) == 5: + # if we have a 5d tensor, then we need to do it on a per batch item, per channel basis, per frame + s = (noise.shape[0], noise.shape[1], noise.shape[2], 1, 1) + + if self.train_config.random_noise_multiplier > 0.0: + + # do it on a per batch item, per channel basis + noise_multiplier = 1 + torch.randn( + s, + device=noise.device, + dtype=noise.dtype + ) * self.train_config.random_noise_multiplier noise = noise * noise_multiplier + + if self.train_config.random_noise_shift > 0.0: + # get random noise -1 to 1 + noise_shift = torch.randn( + s, + device=noise.device, + dtype=noise.dtype + ) * self.train_config.random_noise_shift + # add to noise + noise += noise_shift latent_multiplier = self.train_config.latent_multiplier diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py index fb6536cd..0b6d5fab 100644 --- a/jobs/process/TrainVAEProcess.py +++ b/jobs/process/TrainVAEProcess.py @@ -7,6 +7,7 @@ from collections import OrderedDict from PIL import Image from PIL.ImageOps import exif_transpose +from einops import rearrange from safetensors.torch import save_file, load_file from torch.utils.data import DataLoader, ConcatDataset import torch @@ -17,18 +18,22 @@ from jobs.process import BaseTrainProcess from toolkit.image_utils import show_tensors from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm from toolkit.data_loader import ImageDataset -from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss +from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation from toolkit.metadata import get_meta_for_safetensors from toolkit.optimizer import get_optimizer from toolkit.style import get_style_model_and_losses from toolkit.train_tools import get_torch_dtype from diffusers import AutoencoderKL from tqdm import tqdm +import math +import torchvision.utils import time import numpy as np -from .models.vgg19_critic import Critic +from .models.critic import Critic from torchvision.transforms import Resize import lpips +import random +import traceback IMAGE_TRANSFORMS = transforms.Compose( [ @@ -42,13 +47,21 @@ def unnormalize(tensor): return (tensor / 2 + 0.5).clamp(0, 1) +def channel_dropout(x, p=0.5): + keep_prob = 1 - p + mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob + mask = mask / keep_prob # scale + return x * mask + + class TrainVAEProcess(BaseTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): super().__init__(process_id, job, config) self.data_loader = None self.vae = None self.device = self.get_conf('device', self.job.device) - self.vae_path = self.get_conf('vae_path', required=True) + self.vae_path = self.get_conf('vae_path', None) + self.eq_vae = self.get_conf('eq_vae', False) self.datasets_objects = self.get_conf('datasets', required=True) self.batch_size = self.get_conf('batch_size', 1, as_type=int) self.resolution = self.get_conf('resolution', 256, as_type=int) @@ -65,11 +78,25 @@ class TrainVAEProcess(BaseTrainProcess): self.content_weight = self.get_conf('content_weight', 0, as_type=float) self.kld_weight = self.get_conf('kld_weight', 0, as_type=float) self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float) - self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float) + self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float) + self.tv_weight = self.get_conf('tv_weight', 0, as_type=float) + self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float) + self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float) # latent pixel matching self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float) self.critic_weight = self.get_conf('critic_weight', 1, as_type=float) - self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float) + self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float) self.optimizer_params = self.get_conf('optimizer_params', {}) + self.vae_config = self.get_conf('vae_config', None) + self.dropout = self.get_conf('dropout', 0.0, as_type=float) + self.train_encoder = self.get_conf('train_encoder', False, as_type=bool) + self.random_scaling = self.get_conf('random_scaling', False, as_type=bool) + + if not self.train_encoder: + # remove losses that only target encoder + self.kld_weight = 0 + self.mv_loss_weight = 0 + self.ltv_weight = 0 + self.lpm_weight = 0 self.blocks_to_train = self.get_conf('blocks_to_train', ['all']) self.torch_dtype = get_torch_dtype(self.dtype) @@ -133,7 +160,11 @@ class TrainVAEProcess(BaseTrainProcess): for dataset in self.datasets_objects: print(f" - Dataset: {dataset['path']}") ds = copy.copy(dataset) - ds['resolution'] = self.resolution + dataset_res = self.resolution + if self.random_scaling: + # scale 2x to allow for random scaling + dataset_res = int(dataset_res * 2) + ds['resolution'] = dataset_res image_dataset = ImageDataset(ds) datasets.append(image_dataset) @@ -142,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess): concatenated_dataset, batch_size=self.batch_size, shuffle=True, - num_workers=6 + num_workers=16 ) def remove_oldest_checkpoint(self): @@ -153,6 +184,13 @@ class TrainVAEProcess(BaseTrainProcess): for folder in folders[:-max_to_keep]: print(f"Removing {folder}") shutil.rmtree(folder) + # also handle CRITIC_vae_42_000000500.safetensors format for critic + critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors")) + if len(critic_files) > max_to_keep: + critic_files.sort(key=os.path.getmtime) + for file in critic_files[:-max_to_keep]: + print(f"Removing {file}") + os.remove(file) def setup_vgg19(self): if self.vgg_19 is None: @@ -218,6 +256,62 @@ class TrainVAEProcess(BaseTrainProcess): else: return torch.tensor(0.0, device=self.device) + def get_mean_variance_loss(self, latents: torch.Tensor): + if self.mv_loss_weight > 0: + # collapse rows into channels + latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1]) + mean_col = latents_col.mean(dim=(2, 3), keepdim=True) + std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False) + mean_loss_col = (mean_col ** 2).mean() + std_loss_col = ((std_col - 1) ** 2).mean() + + # collapse columns into channels + latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2]) + mean_row = latents_row.mean(dim=(2, 3), keepdim=True) + std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False) + mean_loss_row = (mean_row ** 2).mean() + std_loss_row = ((std_row - 1) ** 2).mean() + + # do a global one + + mean = latents.mean(dim=(2, 3), keepdim=True) + std = latents.std(dim=(2, 3), keepdim=True, unbiased=False) + mean_loss_global = (mean ** 2).mean() + std_loss_global = ((std - 1) ** 2).mean() + + return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3 + else: + return torch.tensor(0.0, device=self.device) + + def get_ltv_loss(self, latent): + # loss to reduce the latent space variance + if self.ltv_weight > 0: + return total_variation(latent).mean() + else: + return torch.tensor(0.0, device=self.device) + + def get_latent_pixel_matching_loss(self, latent, pixels): + if self.lpm_weight > 0: + with torch.no_grad(): + pixels = pixels.to(latent.device, dtype=latent.dtype) + # resize down to latent size + pixels = torch.nn.functional.interpolate(pixels, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False) + + # mean the color channel and then expand to latent size + pixels = pixels.mean(dim=1, keepdim=True) + pixels = pixels.repeat(1, latent.shape[1], 1, 1) + # match the mean std of latent + latent_mean = latent.mean(dim=(2, 3), keepdim=True) + latent_std = latent.std(dim=(2, 3), keepdim=True) + pixels_mean = pixels.mean(dim=(2, 3), keepdim=True) + pixels_std = pixels.std(dim=(2, 3), keepdim=True) + pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean + + return torch.nn.functional.mse_loss(latent.float(), pixels.float()) + + else: + return torch.tensor(0.0, device=self.device) + def get_tv_loss(self, pred, target): if self.tv_weight > 0: get_tv_loss = ComparativeTotalVariation() @@ -277,7 +371,39 @@ class TrainVAEProcess(BaseTrainProcess): input_img = img img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype) img = img - decoded = self.vae(img).sample + latent = self.vae.encode(img).latent_dist.sample() + + latent_img = latent.clone() + bs, ch, h, w = latent_img.shape + grid_size = math.ceil(math.sqrt(ch)) + pad = grid_size * grid_size - ch + + # take first item in batch + latent_img = latent_img[0] # shape: (ch, h, w) + + if pad > 0: + padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device) + latent_img = torch.cat([latent_img, padding], dim=0) + + # make grid + new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device) + for x in range(grid_size): + for y in range(grid_size): + if x * grid_size + y < ch: + new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y] + latent_img = new_img + # make rgb + latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0) + latent_img = (latent_img / 2 + 0.5).clamp(0, 1) + + # resize to 256x256 + latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest') + latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy() + latent_img = (latent_img * 255).astype(np.uint8) + # convert to pillow image + latent_img = Image.fromarray(latent_img) + + decoded = self.vae.decode(latent).sample decoded = (decoded / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy() @@ -289,9 +415,10 @@ class TrainVAEProcess(BaseTrainProcess): input_img = input_img.resize((self.resolution, self.resolution)) decoded = decoded.resize((self.resolution, self.resolution)) - output_img = Image.new('RGB', (self.resolution * 2, self.resolution)) + output_img = Image.new('RGB', (self.resolution * 3, self.resolution)) output_img.paste(input_img, (0, 0)) output_img.paste(decoded, (self.resolution, 0)) + output_img.paste(latent_img, (self.resolution * 2, 0)) scale_up = 2 if output_img.height <= 300: @@ -326,12 +453,20 @@ class TrainVAEProcess(BaseTrainProcess): self.print(f"Loading VAE") self.print(f" - Loading VAE: {path_to_load}") if self.vae is None: - self.vae = AutoencoderKL.from_pretrained(path_to_load) + if path_to_load is not None: + self.vae = AutoencoderKL.from_pretrained(path_to_load) + elif self.vae_config is not None: + self.vae = AutoencoderKL(**self.vae_config) + else: + raise ValueError('vae_path or ae_config must be specified') # set decoder to train self.vae.to(self.device, dtype=self.torch_dtype) - self.vae.requires_grad_(False) - self.vae.eval() + if self.eq_vae: + self.vae.encoder.train() + else: + self.vae.requires_grad_(False) + self.vae.eval() self.vae.decoder.train() self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1) @@ -374,6 +509,10 @@ class TrainVAEProcess(BaseTrainProcess): if train_all: params = list(self.vae.decoder.parameters()) self.vae.decoder.requires_grad_(True) + if self.train_encoder: + # encoder + params += list(self.vae.encoder.parameters()) + self.vae.encoder.requires_grad_(True) else: # mid_block if train_all or 'mid_block' in self.blocks_to_train: @@ -388,12 +527,13 @@ class TrainVAEProcess(BaseTrainProcess): params += list(self.vae.decoder.conv_out.parameters()) self.vae.decoder.conv_out.requires_grad_(True) - if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + if self.style_weight > 0 or self.content_weight > 0: self.setup_vgg19() - self.vgg_19.requires_grad_(False) + # self.vgg_19.requires_grad_(False) self.vgg_19.eval() - if self.use_critic: - self.critic.setup() + + if self.use_critic: + self.critic.setup() if self.lpips_weight > 0 and self.lpips_loss is None: # self.lpips_loss = lpips.LPIPS(net='vgg') @@ -426,6 +566,9 @@ class TrainVAEProcess(BaseTrainProcess): "style": [], "content": [], "mse": [], + "mvl": [], + "ltv": [], + "lpm": [], "kl": [], "tv": [], "ptn": [], @@ -435,6 +578,9 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses = copy.deepcopy(blank_losses) log_losses = copy.deepcopy(blank_losses) # range start at self.epoch_num go to self.epochs + + latent_size = self.resolution // self.vae_scale_factor + for epoch in range(self.epoch_num, self.epochs, 1): if self.step_num >= self.max_steps: break @@ -442,8 +588,20 @@ class TrainVAEProcess(BaseTrainProcess): if self.step_num >= self.max_steps: break with torch.no_grad(): - batch = batch.to(self.device, dtype=self.torch_dtype) + + if self.random_scaling: + # only random scale 0.5 of the time + if random.random() < 0.5: + # random scale the batch + scale_factor = 0.25 + else: + scale_factor = 0.5 + new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor)) + # make sure it is vae divisible + new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor, + new_size[1] // self.vae_scale_factor * self.vae_scale_factor) + # resize so it matches size of vae evenly if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0: @@ -451,27 +609,92 @@ class TrainVAEProcess(BaseTrainProcess): batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch) # forward pass + # grad only if eq_vae + with torch.set_grad_enabled(self.train_encoder): dgd = self.vae.encode(batch).latent_dist mu, logvar = dgd.mean, dgd.logvar latents = dgd.sample() - latents.detach().requires_grad_(True) + + if self.eq_vae: + # process flips, rotate, scale + latent_chunks = list(latents.chunk(latents.shape[0], dim=0)) + batch_chunks = list(batch.chunk(batch.shape[0], dim=0)) + out_chunks = [] + for i in range(len(latent_chunks)): + try: + do_rotate = random.randint(0, 3) + do_flip_x = random.randint(0, 1) + do_flip_y = random.randint(0, 1) + do_scale = random.randint(0, 1) + if do_rotate > 0: + latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3)) + batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3)) + if do_flip_x > 0: + latent_chunks[i] = torch.flip(latent_chunks[i], [2]) + batch_chunks[i] = torch.flip(batch_chunks[i], [2]) + if do_flip_y > 0: + latent_chunks[i] = torch.flip(latent_chunks[i], [3]) + batch_chunks[i] = torch.flip(batch_chunks[i], [3]) + + # resize latent to fit + if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size: + latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False) + + # if do_scale > 0: + # scale = 2 + # start_latent_h = latent_chunks[i].shape[2] + # start_latent_w = latent_chunks[i].shape[3] + # start_batch_h = batch_chunks[i].shape[2] + # start_batch_w = batch_chunks[i].shape[3] + # latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False) + # batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False) + # # random crop. latent is smaller than match but crops need to match + # latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h) + # latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w) + # batch_x = latent_x * self.vae_scale_factor + # batch_y = latent_y * self.vae_scale_factor + + # # crop + # latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w] + # batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w] + except Exception as e: + print(f"Error processing image {i}: {e}") + traceback.print_exc() + raise e + out_chunks.append(latent_chunks[i]) + latents = torch.cat(out_chunks, dim=0) + # do dropout + if self.dropout > 0: + forward_latents = channel_dropout(latents, self.dropout) + else: + forward_latents = latents + + # resize batch to resolution if needed + if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution: + batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks] + batch = torch.cat(batch_chunks, dim=0) + + else: + latents.detach().requires_grad_(True) + forward_latents = latents + + forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype) + + if not self.train_encoder: + # detach latents if not training encoder + forward_latents = forward_latents.detach() - pred = self.vae.decode(latents).sample - - with torch.no_grad(): - show_tensors( - pred.clamp(-1, 1).clone(), - "combined tensor" - ) + pred = self.vae.decode(forward_latents).sample # Run through VGG19 - if self.style_weight > 0 or self.content_weight > 0 or self.use_critic: + if self.style_weight > 0 or self.content_weight > 0: stacked = torch.cat([pred, batch], dim=0) stacked = (stacked / 2 + 0.5).clamp(0, 1) self.vgg_19(stacked) if self.use_critic: - critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach()) + stacked = torch.cat([pred, batch], dim=0) + critic_d_loss = self.critic.step(stacked.detach()) else: critic_d_loss = 0.0 @@ -489,7 +712,8 @@ class TrainVAEProcess(BaseTrainProcess): tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight if self.use_critic: - critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight + stacked = torch.cat([pred, batch], dim=0) + critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it if self.lpips_weight > 0: @@ -502,8 +726,42 @@ class TrainVAEProcess(BaseTrainProcess): critic_gen_loss *= crit_g_scaler else: critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.mv_loss_weight > 0: + mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight + else: + mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.ltv_weight > 0: + ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight + else: + ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) + + if self.lpm_weight > 0: + lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight + else: + lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype) - loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + + # check if loss is NaN or Inf + if torch.isnan(loss) or torch.isinf(loss): + self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}") + self.print(f" - Style loss: {style_loss.item()}") + self.print(f" - Content loss: {content_loss.item()}") + self.print(f" - KLD loss: {kld_loss.item()}") + self.print(f" - MSE loss: {mse_loss.item()}") + self.print(f" - LPIPS loss: {lpips_loss.item()}") + self.print(f" - TV loss: {tv_loss.item()}") + self.print(f" - Pattern loss: {pattern_loss.item()}") + self.print(f" - Critic gen loss: {critic_gen_loss.item()}") + self.print(f" - Critic D loss: {critic_d_loss}") + self.print(f" - Mean variance loss: {mv_loss.item()}") + self.print(f" - Latent TV loss: {ltv_loss.item()}") + self.print(f" - Latent pixel matching loss: {lpm_loss.item()}") + self.print(f" - Total loss: {loss.item()}") + self.print(f" - Stopping training") + exit(1) # Backward pass and optimization optimizer.zero_grad() @@ -533,8 +791,17 @@ class TrainVAEProcess(BaseTrainProcess): loss_string += f" crG: {critic_gen_loss.item():.2e}" if self.use_critic: loss_string += f" crD: {critic_d_loss:.2e}" + if self.mv_loss_weight > 0: + loss_string += f" mvl: {mv_loss:.2e}" + if self.ltv_weight > 0: + loss_string += f" ltv: {ltv_loss:.2e}" + if self.lpm_weight > 0: + loss_string += f" lpm: {lpm_loss:.2e}" + - if self.optimizer_type.startswith('dadaptation') or \ + if hasattr(optimizer, 'get_avg_learning_rate'): + learning_rate = optimizer.get_avg_learning_rate() + elif self.optimizer_type.startswith('dadaptation') or \ self.optimizer_type.lower().startswith('prodigy'): learning_rate = ( optimizer.param_groups[0]["d"] * @@ -562,6 +829,9 @@ class TrainVAEProcess(BaseTrainProcess): epoch_losses["ptn"].append(pattern_loss.item()) epoch_losses["crG"].append(critic_gen_loss.item()) epoch_losses["crD"].append(critic_d_loss) + epoch_losses["mvl"].append(mv_loss.item()) + epoch_losses["ltv"].append(ltv_loss.item()) + epoch_losses["lpm"].append(lpm_loss.item()) log_losses["total"].append(loss_value) log_losses["lpips"].append(lpips_loss.item()) @@ -573,6 +843,9 @@ class TrainVAEProcess(BaseTrainProcess): log_losses["ptn"].append(pattern_loss.item()) log_losses["crG"].append(critic_gen_loss.item()) log_losses["crD"].append(critic_d_loss) + log_losses["mvl"].append(mv_loss.item()) + log_losses["ltv"].append(ltv_loss.item()) + log_losses["lpm"].append(lpm_loss.item()) # don't do on first step if self.step_num != start_step: diff --git a/jobs/process/models/critic.py b/jobs/process/models/critic.py new file mode 100644 index 00000000..c792a9be --- /dev/null +++ b/jobs/process/models/critic.py @@ -0,0 +1,234 @@ +import glob +import os +from typing import TYPE_CHECKING, Union + +import numpy as np +import torch +import torch.nn as nn +from safetensors.torch import load_file, save_file + +from toolkit.losses import get_gradient_penalty +from toolkit.metadata import get_meta_for_safetensors +from toolkit.optimizer import get_optimizer +from toolkit.train_tools import get_torch_dtype + + +class MeanReduce(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs): + # global mean over spatial dims (keeps channel/batch) + return torch.mean(inputs, dim=(2, 3), keepdim=True) + + +class SelfAttention2d(nn.Module): + """ + Lightweight self-attention layer (SAGAN-style) that keeps spatial + resolution unchanged. Adds minimal params / compute but improves + long-range modelling – helpful for variable-sized inputs. + """ + def __init__(self, in_channels: int): + super().__init__() + self.query = nn.Conv1d(in_channels, in_channels // 8, 1) + self.key = nn.Conv1d(in_channels, in_channels // 8, 1) + self.value = nn.Conv1d(in_channels, in_channels, 1) + self.gamma = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + B, C, H, W = x.shape + flat = x.view(B, C, H * W) # (B,C,N) + q = self.query(flat).permute(0, 2, 1) # (B,N,C//8) + k = self.key(flat) # (B,C//8,N) + attn = torch.bmm(q, k) # (B,N,N) + attn = attn.softmax(dim=-1) # softmax along last dim + v = self.value(flat) # (B,C,N) + out = torch.bmm(v, attn.permute(0, 2, 1)) # (B,C,N) + out = out.view(B, C, H, W) # restore spatial dims + return self.gamma * out + x # residual + + +class CriticModel(nn.Module): + def __init__(self, base_channels: int = 64): + super().__init__() + + def sn_conv(in_c, out_c, k, s, p): + return nn.utils.spectral_norm( + nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p) + ) + + layers = [ + # initial down-sample + sn_conv(3, base_channels, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ] + + in_c = base_channels + # progressive downsamples ×3 (64→128→256→512) + for _ in range(3): + out_c = min(in_c * 2, 1024) + layers += [ + sn_conv(in_c, out_c, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True), + ] + # single attention block after reaching 256 channels + if out_c == 256: + layers += [SelfAttention2d(out_c)] + in_c = out_c + + # extra depth (keeps spatial size) + layers += [ + sn_conv(in_c, 1024, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + + # final 1-channel prediction map + sn_conv(1024, 1, 3, 1, 1), + MeanReduce(), # → (B,1,1,1) + nn.Flatten(), # → (B,1) + ] + + self.main = nn.Sequential(*layers) + + def forward(self, inputs): + # force full-precision inside AMP ctx for stability + with torch.cuda.amp.autocast(False): + return self.main(inputs.float()) + + +if TYPE_CHECKING: + from jobs.process.TrainVAEProcess import TrainVAEProcess + from jobs.process.TrainESRGANProcess import TrainESRGANProcess + + +class Critic: + process: Union['TrainVAEProcess', 'TrainESRGANProcess'] + + def __init__( + self, + learning_rate=1e-5, + device='cpu', + optimizer='adam', + num_critic_per_gen=1, + dtype='float32', + lambda_gp=10, + start_step=0, + warmup_steps=1000, + process=None, + optimizer_params=None, + ): + self.learning_rate = learning_rate + self.device = device + self.optimizer_type = optimizer + self.num_critic_per_gen = num_critic_per_gen + self.dtype = dtype + self.torch_dtype = get_torch_dtype(self.dtype) + self.process = process + self.model = None + self.optimizer = None + self.scheduler = None + self.warmup_steps = warmup_steps + self.start_step = start_step + self.lambda_gp = lambda_gp + + if optimizer_params is None: + optimizer_params = {} + self.optimizer_params = optimizer_params + self.print = self.process.print + print(f" Critic config: {self.__dict__}") + + def setup(self): + self.model = CriticModel().to(self.device) + self.load_weights() + self.model.train() + self.model.requires_grad_(True) + params = self.model.parameters() + self.optimizer = get_optimizer( + params, + self.optimizer_type, + self.learning_rate, + optimizer_params=self.optimizer_params, + ) + self.scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, + total_iters=self.process.max_steps * self.num_critic_per_gen, + factor=1, + verbose=False, + ) + + def load_weights(self): + path_to_load = None + self.print(f"Critic: Looking for latest checkpoint in {self.process.save_root}") + files = glob.glob(os.path.join(self.process.save_root, f"CRITIC_{self.process.job.name}*.safetensors")) + if files: + latest_file = max(files, key=os.path.getmtime) + print(f" - Latest checkpoint is: {latest_file}") + path_to_load = latest_file + else: + self.print(" - No checkpoint found, starting from scratch") + if path_to_load: + self.model.load_state_dict(load_file(path_to_load)) + + def save(self, step=None): + self.process.update_training_metadata() + save_meta = get_meta_for_safetensors(self.process.meta, self.process.job.name) + step_num = f"_{str(step).zfill(9)}" if step is not None else '' + save_path = os.path.join( + self.process.save_root, f"CRITIC_{self.process.job.name}{step_num}.safetensors" + ) + save_file(self.model.state_dict(), save_path, save_meta) + self.print(f"Saved critic to {save_path}") + + def get_critic_loss(self, vgg_output): + # (caller still passes combined [pred|target] images) + if self.start_step > self.process.step_num: + return torch.tensor(0.0, dtype=self.torch_dtype, device=self.device) + + warmup_scaler = 1.0 + if self.process.step_num < self.start_step + self.warmup_steps: + warmup_scaler = (self.process.step_num - self.start_step) / self.warmup_steps + + self.model.eval() + self.model.requires_grad_(False) + + vgg_pred, _ = torch.chunk(vgg_output.float(), 2, dim=0) + stacked_output = self.model(vgg_pred) + return (-torch.mean(stacked_output)) * warmup_scaler + + def step(self, vgg_output): + self.model.train() + self.model.requires_grad_(True) + self.optimizer.zero_grad() + + critic_losses = [] + inputs = vgg_output.detach().to(self.device, dtype=torch.float32) + + vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + stacked_output = self.model(inputs).float() + out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # hinge loss + gradient penalty + loss_real = torch.relu(1.0 - out_target).mean() + loss_fake = torch.relu(1.0 + out_pred).mean() + gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty + + critic_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + critic_losses.append(critic_loss.item()) + + return float(np.mean(critic_losses)) + + def get_lr(self): + if hasattr(self.optimizer, 'get_avg_learning_rate'): + learning_rate = self.optimizer.get_avg_learning_rate() + elif self.optimizer_type.startswith('dadaptation') or \ + self.optimizer_type.lower().startswith('prodigy'): + learning_rate = ( + self.optimizer.param_groups[0]["d"] * + self.optimizer.param_groups[0]["lr"] + ) + else: + learning_rate = self.optimizer.param_groups[0]['lr'] + return learning_rate diff --git a/jobs/process/models/vgg19_critic.py b/jobs/process/models/vgg19_critic.py index 8cf438bf..4d7f74f9 100644 --- a/jobs/process/models/vgg19_critic.py +++ b/jobs/process/models/vgg19_critic.py @@ -33,11 +33,20 @@ class Vgg19Critic(nn.Module): super(Vgg19Critic, self).__init__() self.main = nn.Sequential( # input (bs, 512, 32, 32) - nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), + # nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( # SN keeps D’s scale in check + nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1) + ), nn.LeakyReLU(0.2), # (bs, 512, 16, 16) - nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1) + ), nn.LeakyReLU(0.2), # (bs, 512, 8, 8) - nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + # nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1), + nn.utils.spectral_norm( + nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1) + ), # (bs, 1, 4, 4) MeanReduce(), # (bs, 1, 1, 1) nn.Flatten(), # (bs, 1) @@ -47,7 +56,9 @@ class Vgg19Critic(nn.Module): ) def forward(self, inputs): - return self.main(inputs) + # return self.main(inputs) + with torch.cuda.amp.autocast(False): + return self.main(inputs.float()) if TYPE_CHECKING: @@ -92,7 +103,7 @@ class Critic: print(f" Critic config: {self.__dict__}") def setup(self): - self.model = Vgg19Critic().to(self.device, dtype=self.torch_dtype) + self.model = Vgg19Critic().to(self.device) self.load_weights() self.model.train() self.model.requires_grad_(True) @@ -142,7 +153,8 @@ class Critic: # set model to not train for generator loss self.model.eval() self.model.requires_grad_(False) - vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) + # vgg_pred, vgg_target = torch.chunk(vgg_output, 2, dim=0) + vgg_pred, vgg_target = torch.chunk(vgg_output.float(), 2, dim=0) # run model stacked_output = self.model(vgg_pred) @@ -157,20 +169,34 @@ class Critic: self.optimizer.zero_grad() critic_losses = [] - inputs = vgg_output.detach() - inputs = inputs.to(self.device, dtype=self.torch_dtype) + # inputs = vgg_output.detach() + # inputs = inputs.to(self.device, dtype=self.torch_dtype) + inputs = vgg_output.detach().to(self.device, dtype=torch.float32) self.optimizer.zero_grad() vgg_pred, vgg_target = torch.chunk(inputs, 2, dim=0) + # stacked_output = self.model(inputs).float() + # out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) + + # # Compute gradient penalty + # gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) + + # # Compute WGAN-GP critic loss + # critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty + stacked_output = self.model(inputs).float() out_pred, out_target = torch.chunk(stacked_output, 2, dim=0) - # Compute gradient penalty + # ── hinge loss ── + loss_real = torch.relu(1.0 - out_target).mean() + loss_fake = torch.relu(1.0 + out_pred).mean() + + # gradient penalty (unchanged helper) gradient_penalty = get_gradient_penalty(self.model, vgg_target, vgg_pred, self.device) - # Compute WGAN-GP critic loss - critic_loss = -(torch.mean(out_target) - torch.mean(out_pred)) + self.lambda_gp * gradient_penalty + critic_loss = loss_real + loss_fake + self.lambda_gp * gradient_penalty + critic_loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) self.optimizer.step() diff --git a/toolkit/accelerator.py b/toolkit/accelerator.py index ebcf0095..0736f016 100644 --- a/toolkit/accelerator.py +++ b/toolkit/accelerator.py @@ -11,7 +11,10 @@ def get_accelerator() -> Accelerator: return global_accelerator def unwrap_model(model): - accelerator = get_accelerator() - model = accelerator.unwrap_model(model) - model = model._orig_mod if is_compiled_module(model) else model + try: + accelerator = get_accelerator() + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + except Exception as e: + pass return model diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index e3f80ae3..aa3d84a6 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -325,6 +325,8 @@ class TrainConfig: self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0) + self.random_noise_multiplier = kwargs.get('random_noise_multiplier', 0.0) + self.random_noise_shift = kwargs.get('random_noise_shift', 0.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0) self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0) self.latent_multiplier = kwargs.get('latent_multiplier', 1.0) @@ -333,7 +335,6 @@ class TrainConfig: # multiplier applied to loos on regularization images self.reg_weight = kwargs.get('reg_weight', 1.0) self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000) - self.random_noise_shift = kwargs.get('random_noise_shift', 0.0) # automatically adapte the vae scaling based on the image norm self.adaptive_scaling_factor = kwargs.get('adaptive_scaling_factor', False) @@ -412,7 +413,7 @@ class TrainConfig: self.correct_pred_norm = kwargs.get('correct_pred_norm', False) self.correct_pred_norm_multiplier = kwargs.get('correct_pred_norm_multiplier', 1.0) - self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace + self.loss_type = kwargs.get('loss_type', 'mse') # mse, mae, wavelet, pixelspace, cfm # scale the prediction by this. Increase for more detail, decrease for less self.pred_scaler = kwargs.get('pred_scaler', 1.0) @@ -436,7 +437,8 @@ class TrainConfig: # adds an additional loss to the network to encourage it output a normalized standard deviation self.target_norm_std = kwargs.get('target_norm_std', None) self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) - self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend + self.timestep_type = kwargs.get('timestep_type', 'sigmoid') # sigmoid, linear, lognorm_blend, next_sample + self.next_sample_timesteps = kwargs.get('next_sample_timesteps', 8) self.linear_timesteps = kwargs.get('linear_timesteps', False) self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) self.disable_sampling = kwargs.get('disable_sampling', False) diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index 27a94503..17550dde 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -343,7 +343,7 @@ class BaseModel: pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None, ): - network = unwrap_model(self.network) + network = self.network merge_multiplier = 1.0 flush() # if using assistant, unfuse it @@ -364,6 +364,7 @@ class BaseModel: self.assistant_lora.force_to(self.device_torch, self.torch_dtype) if network is not None: + network = unwrap_model(self.network) network.eval() # check if we have the same network weight for all samples. If we do, we can merge in th # the network to drastically speed up inference diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py index 2ea29276..17b259e6 100644 --- a/toolkit/models/diffusion_feature_extraction.py +++ b/toolkit/models/diffusion_feature_extraction.py @@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module): dtype = torch.bfloat16 device = self.vae.device - # first we step the scheduler from current timestep to the very end for a full denoise - # bs = noise_pred.shape[0] - # noise_pred_chunks = torch.chunk(noise_pred, bs) - # timestep_chunks = torch.chunk(timesteps, bs) - # noisy_latent_chunks = torch.chunk(noisy_latents, bs) - # stepped_chunks = [] - # for idx in range(bs): - # model_output = noise_pred_chunks[idx] - # timestep = timestep_chunks[idx] - # scheduler._step_index = None - # scheduler._init_step_index(timestep) - # sample = noisy_latent_chunks[idx].to(torch.float32) - - # sigma = scheduler.sigmas[scheduler.step_index] - # sigma_next = scheduler.sigmas[-1] # use last sigma for final step - # prev_sample = sample + (sigma_next - sigma) * model_output - # stepped_chunks.append(prev_sample) - - # stepped_latents = torch.cat(stepped_chunks, dim=0) if model is not None and hasattr(model, 'get_stepped_pred'): stepped_latents = model.get_stepped_pred(noise_pred, noise) else: - stepped_latents = noise - noise_pred + # stepped_latents = noise - noise_pred + # first we step the scheduler from current timestep to the very end for a full denoise + bs = noise_pred.shape[0] + noise_pred_chunks = torch.chunk(noise_pred, bs) + timestep_chunks = torch.chunk(timesteps, bs) + noisy_latent_chunks = torch.chunk(noisy_latents, bs) + stepped_chunks = [] + for idx in range(bs): + model_output = noise_pred_chunks[idx] + timestep = timestep_chunks[idx] + scheduler._step_index = None + scheduler._init_step_index(timestep) + sample = noisy_latent_chunks[idx].to(torch.float32) + + sigma = scheduler.sigmas[scheduler.step_index] + sigma_next = scheduler.sigmas[-1] # use last sigma for final step + prev_sample = sample + (sigma_next - sigma) * model_output + stepped_chunks.append(prev_sample) + + stepped_latents = torch.cat(stepped_chunks, dim=0) latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index d35fc09e..068e747e 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -142,7 +142,9 @@ class StableDiffusion: ): self.accelerator = get_accelerator() self.custom_pipeline = custom_pipeline - self.device = device + self.device = str(device) + if "cuda" in self.device and ":" not in self.device: + self.device = f"{self.device}:0" self.device_torch = torch.device(device) self.dtype = dtype self.torch_dtype = get_torch_dtype(dtype) @@ -251,13 +253,13 @@ class StableDiffusion: def get_bucket_divisibility(self): if self.vae is None: - return 8 + return 16 divisibility = 2 ** (len(self.vae.config['block_out_channels']) - 1) # flux packs this again, if self.is_flux or self.is_v3: divisibility = divisibility * 2 - return divisibility + return divisibility * 2 # todo remove this def load_model(self): @@ -2086,7 +2088,10 @@ class StableDiffusion: noise_pred = noise_pred else: if self.unet.device != self.device_torch: - self.unet.to(self.device_torch) + try: + self.unet.to(self.device_torch) + except Exception as e: + pass if self.unet.dtype != self.torch_dtype: self.unet = self.unet.to(dtype=self.torch_dtype) if self.is_flux: