From b6d25fcd10b30682bcb79ba7db6b050498bbddf0 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Fri, 30 May 2025 12:06:47 -0600
Subject: [PATCH] Improvements to VAE trainer. Adjust denoise prediction of
 DFE v3

---
 jobs/process/TrainVAEProcess.py                 | 35 ++++++++++++++--
 jobs/process/models/critic.py                   | 17 +++---
 .../models/diffusion_feature_extraction.py      | 40 +++++++++----------
 3 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index 64566db9..0b6d5fab 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -89,6 +89,7 @@ class TrainVAEProcess(BaseTrainProcess):
         self.vae_config = self.get_conf('vae_config', None)
         self.dropout = self.get_conf('dropout', 0.0, as_type=float)
         self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+        self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
 
         if not self.train_encoder:
             # remove losses that only target encoder
@@ -159,7 +160,11 @@ class TrainVAEProcess(BaseTrainProcess):
         for dataset in self.datasets_objects:
             print(f" - Dataset: {dataset['path']}")
             ds = copy.copy(dataset)
-            ds['resolution'] = self.resolution
+            dataset_res = self.resolution
+            if self.random_scaling:
+                # load at 2x resolution to leave room for random downscaling
+                dataset_res = int(dataset_res * 2)
+            ds['resolution'] = dataset_res
 
             image_dataset = ImageDataset(ds)
             datasets.append(image_dataset)
@@ -168,7 +173,7 @@ class TrainVAEProcess(BaseTrainProcess):
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=8
+            num_workers=16
         )
 
     def remove_oldest_checkpoint(self):
@@ -573,6 +578,9 @@ class TrainVAEProcess(BaseTrainProcess):
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)
         # range start at self.epoch_num go to self.epochs
+
+        latent_size = self.resolution // self.vae_scale_factor
+
         for epoch in range(self.epoch_num, self.epochs, 1):
             if self.step_num >= self.max_steps:
                 break
@@ -580,8 +588,20 @@ class TrainVAEProcess(BaseTrainProcess):
                 if self.step_num >= self.max_steps:
                     break
                 with torch.no_grad():
-                    batch = batch.to(self.device, dtype=self.torch_dtype)
+
+                    if self.random_scaling:
+                        # scale down to 0.25 half the time, otherwise to 0.5
+                        # (0.5 of the 2x dataset resolution is the base resolution)
+                        if random.random() < 0.5:
+                            scale_factor = 0.25
+                        else:
+                            scale_factor = 0.5
+                        new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
+                        # make sure it is divisible by the vae scale factor
+                        new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor,
+                                    new_size[1] // self.vae_scale_factor * self.vae_scale_factor)
+
                     # resize so it matches size of vae evenly
                     if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
@@ -615,6 +635,11 @@ class TrainVAEProcess(BaseTrainProcess):
                         if do_flip_y > 0:
                             latent_chunks[i] = torch.flip(latent_chunks[i], [3])
                             batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+
+                        # resize latent to fit
+                        if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size:
+                            latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False)
+
                         # if do_scale > 0:
                         #     scale = 2
                         #     start_latent_h = latent_chunks[i].shape[2]
@@ -643,6 +668,10 @@ class TrainVAEProcess(BaseTrainProcess):
                     forward_latents = channel_dropout(latents, self.dropout)
                 else:
                     forward_latents = latents
+
+                # resize batch to resolution if needed
+                if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution:
+                    batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks]
 
                 batch = torch.cat(batch_chunks, dim=0)
             else:
diff --git a/jobs/process/models/critic.py b/jobs/process/models/critic.py
index 42bdb637..c792a9be 100644
--- a/jobs/process/models/critic.py
+++ b/jobs/process/models/critic.py
@@ -220,10 +220,15 @@ class Critic:
 
         return float(np.mean(critic_losses))
 
-    def get_lr(self):
-        if self.optimizer_type.startswith('dadaptation'):
-            return (
-                self.optimizer.param_groups[0]["d"]
-                * self.optimizer.param_groups[0]["lr"]
+    def get_lr(self):
+        if hasattr(self.optimizer, 'get_avg_learning_rate'):
+            learning_rate = self.optimizer.get_avg_learning_rate()
+        elif self.optimizer_type.startswith('dadaptation') or \
+                self.optimizer_type.lower().startswith('prodigy'):
+            learning_rate = (
+                self.optimizer.param_groups[0]["d"] *
+                self.optimizer.param_groups[0]["lr"]
             )
-        return self.optimizer.param_groups[0]["lr"]
+        else:
+            learning_rate = self.optimizer.param_groups[0]['lr']
+        return learning_rate
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 2ea29276..17b259e6 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -255,30 +255,30 @@ class DiffusionFeatureExtractor3(nn.Module):
         dtype = torch.bfloat16
         device = self.vae.device
 
-        # first we step the scheduler from current timestep to the very end for a full denoise
-        # bs = noise_pred.shape[0]
-        # noise_pred_chunks = torch.chunk(noise_pred, bs)
-        # timestep_chunks = torch.chunk(timesteps, bs)
-        # noisy_latent_chunks = torch.chunk(noisy_latents, bs)
-        # stepped_chunks = []
-        # for idx in range(bs):
-        #     model_output = noise_pred_chunks[idx]
-        #     timestep = timestep_chunks[idx]
-        #     scheduler._step_index = None
-        #     scheduler._init_step_index(timestep)
-        #     sample = noisy_latent_chunks[idx].to(torch.float32)
-
-        #     sigma = scheduler.sigmas[scheduler.step_index]
-        #     sigma_next = scheduler.sigmas[-1] # use last sigma for final step
-        #     prev_sample = sample + (sigma_next - sigma) * model_output
-        #     stepped_chunks.append(prev_sample)
-
-        # stepped_latents = torch.cat(stepped_chunks, dim=0)
 
         if model is not None and hasattr(model, 'get_stepped_pred'):
             stepped_latents = model.get_stepped_pred(noise_pred, noise)
         else:
-            stepped_latents = noise - noise_pred
+            # stepped_latents = noise - noise_pred
+            # first we step the scheduler from current timestep to the very end for a full denoise
+            bs = noise_pred.shape[0]
+            noise_pred_chunks = torch.chunk(noise_pred, bs)
+            timestep_chunks = torch.chunk(timesteps, bs)
+            noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+            stepped_chunks = []
+            for idx in range(bs):
+                model_output = noise_pred_chunks[idx]
+                timestep = timestep_chunks[idx]
+                scheduler._step_index = None
+                scheduler._init_step_index(timestep)
+                sample = noisy_latent_chunks[idx].to(torch.float32)
+
+                sigma = scheduler.sigmas[scheduler.step_index]
+                sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
+                prev_sample = sample + (sigma_next - sigma) * model_output
+                stepped_chunks.append(prev_sample)
+
+            stepped_latents = torch.cat(stepped_chunks, dim=0)
 
         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
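
Note on the TrainVAEProcess.py changes: with random_scaling enabled, datasets are
loaded at twice the training resolution and each batch is randomly downscaled
before encoding. Below is a minimal standalone sketch of that path. The
vae_scale_factor value of 8 and the final interpolate call that applies new_size
are illustrative assumptions, not code from this commit (the hunk computes
new_size, but its application falls outside the visible context lines).

import random

import torch
import torch.nn.functional as F

VAE_SCALE_FACTOR = 8  # assumed; SD-style VAEs downsample images by 8x


def random_scale(batch: torch.Tensor) -> torch.Tensor:
    # Datasets are loaded at 2x the training resolution, so a factor of 0.5
    # returns to the base resolution and 0.25 drops below it.
    scale_factor = 0.25 if random.random() < 0.5 else 0.5
    new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
    # Snap both dims down to a multiple of the VAE scale factor so the
    # encoder produces integer-sized latents.
    new_size = (new_size[0] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
                new_size[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR)
    # Assumed application of new_size for illustration.
    return F.interpolate(batch, size=new_size, mode='bilinear', align_corners=False)


# usage: a batch of 2x-resolution (1024px) images becomes 512px or 256px
batch = torch.randn(4, 3, 1024, 1024)
print(random_scale(batch).shape)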
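
Note on the critic.py change: get_lr now reports an effective learning rate for
adaptive optimizers instead of handling only dadaptation. A condensed sketch of
the fallback chain follows; get_avg_learning_rate is taken from the patch and
assumed to exist on some of the toolkit's optimizers, and for D-Adaptation and
Prodigy the effective rate is the adapted multiplier d times the configured lr.

def report_lr(optimizer, optimizer_type: str) -> float:
    # Prefer an optimizer-provided average when available (assumed method).
    if hasattr(optimizer, 'get_avg_learning_rate'):
        return optimizer.get_avg_learning_rate()
    # D-Adaptation and Prodigy keep lr near 1.0 and learn a multiplier d,
    # so the effective learning rate is d * lr.
    if optimizer_type.startswith('dadaptation') or optimizer_type.lower().startswith('prodigy'):
        group = optimizer.param_groups[0]
        return group["d"] * group["lr"]
    # Plain optimizers: the configured lr is the effective rate.
    return optimizer.param_groups[0]["lr"]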
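
Note on the DiffusionFeatureExtractor3 change: the old fallback
stepped_latents = noise - noise_pred is only exact when the sample is pure noise
(sigma = 1); the revived code instead takes a single Euler step of the
flow-matching ODE from the current sigma all the way to the schedule's last
sigma. A small sketch of the math follows, assuming the flow-matching convention
x_t = (1 - sigma) * x0 + sigma * noise with the model predicting the velocity
v = noise - x0; the linspace schedule is a stand-in for scheduler.sigmas, not
the scheduler's actual values.

import torch

sigmas = torch.linspace(1.0, 0.0, 1001)  # assumed schedule; last sigma is 0


def full_denoise(sample: torch.Tensor, noise_pred: torch.Tensor, step_index: int) -> torch.Tensor:
    sigma = sigmas[step_index]
    sigma_last = sigmas[-1]  # step straight to the end of the schedule
    # One Euler step of dx/dsigma = v from sigma to sigma_last.
    return sample + (sigma_last - sigma) * noise_pred


# With sigma_last == 0 this reduces to sample - sigma * noise_pred, and at
# sigma == 1 (pure noise) it matches the old shortcut noise - noise_pred.
x_t = torch.randn(1, 4, 64, 64)
v = torch.randn_like(x_t)
x0_pred = full_denoise(x_t, v, step_index=500)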