From 787bb37e7633d345b8e61557a57caee79016eb0b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 12 Feb 2025 09:27:44 -0700
Subject: [PATCH] Small fixes for DFE, polarity guidance, and other things

---
 extensions_built_in/sd_trainer/SDTrainer.py  |  4 +-
 jobs/process/BaseSDTrainProcess.py           |  7 +-
 toolkit/config_modules.py                    |  2 +-
 toolkit/guidance.py                          | 27 ++++++--
 .../models/diffusion_feature_extraction.py   | 69 ++++++++++---------
 toolkit/sampler.py                           | 14 ++++
 toolkit/stable_diffusion_model.py            |  7 +-
 7 files changed, 87 insertions(+), 43 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 09156909..6dd4d683 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -404,13 +404,14 @@ class SDTrainer(BaseSDTrainProcess):
                     additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight * 100.0
                 elif self.dfe.version == 3:
                     dfe_loss = self.dfe(
+                        noise=noise,
                         noise_pred=noise_pred,
                         noisy_latents=noisy_latents,
                         timesteps=timesteps,
                         batch=batch,
                         scheduler=self.sd.noise_scheduler
                     )
-                    additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
+                    additional_loss += dfe_loss * self.train_config.diffusion_feature_extractor_weight
                 else:
                     raise ValueError(f"Unknown diffusion feature extractor version {self.dfe.version}")
@@ -563,6 +564,7 @@ class SDTrainer(BaseSDTrainProcess):
                 noise=noise,
                 sd=self.sd,
                 unconditional_embeds=unconditional_embeds,
+                train_config=self.train_config,
                 **kwargs
             )
 
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 89ceac4f..ea2d2539 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1387,12 +1387,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.load_training_state_from_metadata(latest_save_path)
 
         # get the noise scheduler
+        arch = 'sd'
+        if self.model_config.is_pixart:
+            arch = 'pixart'
+        if self.model_config.is_flux:
+            arch = 'flux'
         sampler = get_sampler(
             self.train_config.noise_scheduler,
             {
                 "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon",
             },
-            'sd' if not self.model_config.is_pixart else 'pixart'
+            arch
         )
 
         if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 30762c72..4661ce46 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -403,7 +403,7 @@ class TrainConfig:
 
         # diffusion feature extractor
         self.diffusion_feature_extractor_path = kwargs.get('diffusion_feature_extractor_path', None)
-        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 0.1)
+        self.diffusion_feature_extractor_weight = kwargs.get('diffusion_feature_extractor_weight', 1.0)
 
         # optimal noise pairing
         self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
diff --git a/toolkit/guidance.py b/toolkit/guidance.py
index dcf28204..b9971dc1 100644
--- a/toolkit/guidance.py
+++ b/toolkit/guidance.py
@@ -6,6 +6,7 @@ from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
 from toolkit.prompt_utils import PromptEmbeds, concat_prompt_embeds
 from toolkit.stable_diffusion_model import StableDiffusion
 from toolkit.train_tools import get_torch_dtype
+from toolkit.config_modules import TrainConfig
 
 GuidanceType = Literal["targeted", "polarity", "targeted_polarity", "direct"]
 
@@ -407,6 +408,7 @@ def get_guided_loss_polarity(
     batch: 'DataLoaderBatchDTO',
     noise: torch.Tensor,
     sd: 'StableDiffusion',
+    train_config: 'TrainConfig',
     scaler=None,
     **kwargs
 ):
@@ -423,8 +425,22 @@ def get_guided_loss_polarity(
     target_neg = noise
 
     if sd.is_flow_matching:
-        # set the timesteps for flow matching as linear since we will do weighing
-        sd.noise_scheduler.set_train_timesteps(1000, device, linear=True)
+        linear_timesteps = any([
+            train_config.linear_timesteps,
+            train_config.linear_timesteps2,
+            train_config.timestep_type == 'linear',
+        ])
+
+        timestep_type = 'linear' if linear_timesteps else None
+        if timestep_type is None:
+            timestep_type = train_config.timestep_type
+
+        sd.noise_scheduler.set_train_timesteps(
+            1000,
+            device=device,
+            timestep_type=timestep_type,
+            latents=conditional_latents
+        )
 
         target_pos = (noise - conditional_latents).detach()
         target_neg = (noise - unconditional_latents).detach()
@@ -481,11 +497,6 @@ def get_guided_loss_polarity(
 
     loss = pred_loss + pred_neg_loss
 
-    # if sd.is_flow_matching:
-    #     timestep_weight = sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype).detach()
-    #     loss = loss * timestep_weight
-
-
     loss = loss.mean([1, 2, 3])
     loss = loss.mean()
     if scaler is not None:
@@ -609,6 +620,7 @@ def get_guidance_loss(
     mask_multiplier=None,
     prior_pred=None,
     scaler=None,
+    train_config=None,
     **kwargs
 ):
     # TODO add others and process individual batch items separately
@@ -641,6 +653,7 @@ def get_guidance_loss(
             noise,
             sd,
             scaler=scaler,
+            train_config=train_config,
             **kwargs
         )
     elif guidance_type == "tnt":
diff --git a/toolkit/models/diffusion_feature_extraction.py b/toolkit/models/diffusion_feature_extraction.py
index 8c6fd966..88ffbd0b 100644
--- a/toolkit/models/diffusion_feature_extraction.py
+++ b/toolkit/models/diffusion_feature_extraction.py
@@ -226,45 +226,48 @@ class DiffusionFeatureExtractor3(nn.Module):
             return feats_list
 
         # do lpips
-        lpips_feat_list = [x.detach() for x in get_lpips_features(
+        lpips_feat_list = [x for x in get_lpips_features(
             tensors_n1p1.to(device, dtype=torch.float32))]
         return lpips_feat_list
 
     def forward(
-        self,
+        self,
+        noise,
         noise_pred,
         noisy_latents,
         timesteps,
         batch: DataLoaderBatchDTO,
         scheduler: CustomFlowMatchEulerDiscreteScheduler,
-        lpips_weight=20.0,
+        lpips_weight=1.0,
         clip_weight=0.1,
-        pixel_weight=1.0
+        pixel_weight=0.1
     ):
         dtype = torch.bfloat16
         device = self.vae.device
 
         # first we step the scheduler from current timestep to the very end for a full denoise
-        bs = noise_pred.shape[0]
-        noise_pred_chunks = torch.chunk(noise_pred, bs)
-        timestep_chunks = torch.chunk(timesteps, bs)
-        noisy_latent_chunks = torch.chunk(noisy_latents, bs)
-        stepped_chunks = []
-        for idx in range(bs):
-            model_output = noise_pred_chunks[idx]
-            timestep = timestep_chunks[idx]
-            scheduler._step_index = None
-            scheduler._init_step_index(timestep)
-            sample = noisy_latent_chunks[idx].to(torch.float32)
+        # bs = noise_pred.shape[0]
+        # noise_pred_chunks = torch.chunk(noise_pred, bs)
+        # timestep_chunks = torch.chunk(timesteps, bs)
+        # noisy_latent_chunks = torch.chunk(noisy_latents, bs)
+        # stepped_chunks = []
+        # for idx in range(bs):
+        #     model_output = noise_pred_chunks[idx]
+        #     timestep = timestep_chunks[idx]
+        #     scheduler._step_index = None
+        #     scheduler._init_step_index(timestep)
+        #     sample = noisy_latent_chunks[idx].to(torch.float32)
 
-            sigma = scheduler.sigmas[scheduler.step_index]
-            sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
-            prev_sample = sample + (sigma_next - sigma) * model_output
-            stepped_chunks.append(prev_sample)
+        #     sigma = scheduler.sigmas[scheduler.step_index]
+        #     sigma_next = scheduler.sigmas[-1]  # use last sigma for final step
+        #     prev_sample = sample + (sigma_next - sigma) * model_output
+        #     stepped_chunks.append(prev_sample)
 
-        stepped_latents = torch.cat(stepped_chunks, dim=0)
+        # stepped_latents = torch.cat(stepped_chunks, dim=0)
+
+        stepped_latents = noise - noise_pred
 
         latents = stepped_latents.to(self.vae.device, dtype=self.vae.dtype)
 
@@ -274,16 +277,18 @@ class DiffusionFeatureExtractor3(nn.Module):
 
         pred_images = (tensors_n1p1 + 1) / 2  # 0 to 1
 
-        pred_clip_output = self.get_siglip_features(pred_images)
         lpips_feat_list_pred = self.get_lpips_features(pred_images.float())
 
+        total_loss = 0
+
         with torch.no_grad():
             target_img = batch.tensor.to(device, dtype=dtype)
             # go from -1 to 1 to 0 to 1
             target_img = (target_img + 1) / 2
-            target_clip_output = self.get_siglip_features(target_img).detach()
             lpips_feat_list_target = self.get_lpips_features(target_img.float())
-
+            target_clip_output = self.get_siglip_features(target_img).detach()
+
+        pred_clip_output = self.get_siglip_features(pred_images)
         clip_loss = torch.nn.functional.mse_loss(
             pred_clip_output.float(), target_clip_output.float()
         ) * clip_weight
@@ -293,7 +298,7 @@ class DiffusionFeatureExtractor3(nn.Module):
         else:
             self.losses['clip_loss'] += clip_loss.item()
 
-        total_loss = clip_loss
+        total_loss += clip_loss
 
         lpips_loss = 0
         for idx, lpips_feat in enumerate(lpips_feat_list_pred):
@@ -308,14 +313,14 @@ class DiffusionFeatureExtractor3(nn.Module):
 
         total_loss += lpips_loss
 
-        mse_loss = torch.nn.functional.mse_loss(
-            stepped_latents.float(), batch.latents.float()
-        ) * pixel_weight
+        # mse_loss = torch.nn.functional.mse_loss(
+        #     stepped_latents.float(), batch.latents.float()
+        # ) * pixel_weight
 
-        if 'pixel_loss' not in self.losses:
-            self.losses['pixel_loss'] = mse_loss.item()
-        else:
-            self.losses['pixel_loss'] += mse_loss.item()
+        # if 'pixel_loss' not in self.losses:
+        #     self.losses['pixel_loss'] = mse_loss.item()
+        # else:
+        #     self.losses['pixel_loss'] += mse_loss.item()
 
         if self.step % self.log_every == 0 and self.step > 0:
             print(f"DFE losses:")
@@ -325,7 +330,7 @@ class DiffusionFeatureExtractor3(nn.Module):
                 print(f" - {key}: {self.losses[key]:.3e}")
                 self.losses[key] = 0.0
 
-        total_loss += mse_loss
+        # total_loss += mse_loss
 
         self.step += 1
         return total_loss
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
index aae6e379..f5bb0649 100644
--- a/toolkit/sampler.py
+++ b/toolkit/sampler.py
@@ -88,6 +88,18 @@ flux_config = {
     "use_dynamic_shifting": True
 }
 
+sd_flow_config = {
+    "_class_name": "FlowMatchEulerDiscreteScheduler",
+    "_diffusers_version": "0.30.0.dev0",
+    "base_image_seq_len": 256,
+    "base_shift": 0.5,
+    "max_image_seq_len": 4096,
+    "max_shift": 1.15,
+    "num_train_timesteps": 1000,
+    "shift": 3.0,
+    "use_dynamic_shifting": False
+}
+
 
 def get_sampler(
     sampler: str,
@@ -133,6 +145,8 @@ def get_sampler(
     elif sampler == "flowmatch":
         scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
         config_to_use = copy.deepcopy(flux_config)
+        if arch == "sd":
+            config_to_use = copy.deepcopy(sd_flow_config)
     else:
         raise ValueError(f"Sampler {sampler} not supported")
 
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 7bdc586b..c990113b 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -974,12 +974,17 @@ class StableDiffusion:
                 "prediction_type": self.prediction_type,
             })
         else:
+            arch = 'sd'
+            if self.model_config.is_pixart:
+                arch = 'pixart'
+            if self.model_config.is_flux:
+                arch = 'flux'
             noise_scheduler = get_sampler(
                 sampler,
                 {
                     "prediction_type": self.prediction_type,
                 },
-                'sd' if not self.is_pixart else 'pixart'
+                arch
            )
 
        try:
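
Note (explanatory, not part of the commit): the one-step preview in
DiffusionFeatureExtractor3.forward replaces the commented-out per-sample
scheduler stepping with "stepped_latents = noise - noise_pred". This relies
on the flow-matching parameterization used by the flux-style schedulers in
this repo, where the model predicts the velocity v = noise - x0, so the clean
latents are recovered in closed form as x0 = noise - v at any timestep. A
minimal sketch of that identity (illustrative only; tensor shapes and the
sampled time are hypothetical):

    import torch

    x0 = torch.randn(2, 4, 64, 64)        # clean latents
    noise = torch.randn_like(x0)          # sampled noise
    t = 0.37                              # any flow-matching time in (0, 1)
    x_t = (1 - t) * x0 + t * noise        # noisy latents fed to the model
    v = noise - x0                        # ideal velocity target
    assert torch.allclose(noise - v, x0)  # recovery is exact for a perfect prediction

With an imperfect prediction, noise - noise_pred is the model's estimate of
x0, which is what the LPIPS and SigLIP feature losses above compare against
the batch images.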