From f321de7bdb1cdc22e52e3d52b468f8646336023a Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 4 Aug 2024 10:37:23 -0600
Subject: [PATCH] Setup to retrain guidance embedding for flux. Use default timestep distribution for flux

---
 extensions_built_in/sd_trainer/SDTrainer.py |  7 ++++++-
 jobs/process/BaseSDTrainProcess.py          | 18 +++++++++---------
 toolkit/stable_diffusion_model.py           | 15 +++++++++++++--
 3 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 08089d85..5fd4dd76 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -330,7 +330,12 @@ class SDTrainer(BaseSDTrainProcess):
         elif self.sd.is_rectified_flow:
             # only if preconditioning model outputs
-            # if not preconditioning, (target = noise - batch.latents) is used
+            # if not preconditioning, (target = noise - batch.latents)
+
+
+            # target = noise - batch.latents
+
+            # if preconditioning outputs, target latents
             target = batch.latents.detach()
         else:
             target = noise
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index fd89ce37..e1dba7b8 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -959,15 +959,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 raise ValueError(f"Unknown content_or_style {content_or_style}")
 
             # do flow matching
-            if self.sd.is_rectified_flow:
-                u = compute_density_for_timestep_sampling(
-                    weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
-                    batch_size=batch_size,
-                    logit_mean=0.0,
-                    logit_std=1.0,
-                    mode_scale=1.29,
-                )
-                timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
+            # if self.sd.is_rectified_flow:
+            #     u = compute_density_for_timestep_sampling(
+            #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+            #         batch_size=batch_size,
+            #         logit_mean=0.0,
+            #         logit_std=1.0,
+            #         mode_scale=1.29,
+            #     )
+            #     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
             # convert the timestep_indices to a timestep
             timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
             timesteps = torch.stack(timesteps, dim=0)
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index ca329cf5..f48ecdbb 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -464,7 +464,13 @@ class StableDiffusion:
                     subfolder = None
                     transformer_path = os.path.join(transformer_path, 'transformer')
 
-                transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
+                transformer = FluxTransformer2DModel.from_pretrained(
+                    transformer_path,
+                    subfolder=subfolder,
+                    torch_dtype=dtype,
+                    low_cpu_mem_usage=False,
+                    device_map=None
+                )
                 transformer.to(self.device_torch, dtype=dtype)
 
                 flush()
@@ -1609,7 +1615,6 @@ class StableDiffusion:
                 vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
             )
 
-            # todo we do this on sd3 training. I think we do it here too? No paper
             noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
         elif self.is_v3:
             noise_pred = self.unet(
@@ -2053,6 +2058,12 @@ class StableDiffusion:
                 #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                 #         named_params[name] = param
 
+                # train the guidance embedding
+                if self.unet.config.guidance_embeds:
+                    transformer: FluxTransformer2DModel = self.unet
+                    for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                        named_params[name] = param
+
                 for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
                 for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
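For context, a minimal standalone sketch (not part of the patch) of what the two functional changes amount to. The first helper mirrors the @@ -2053 hunk above: when the Flux transformer was built with guidance_embeds=True, its time_text_embed module (which holds the guidance embedder) is exposed to the optimizer. The second helper is an assumption about what "default timestep distribution" means here, i.e. indices drawn uniformly over the scheduler's timestep table instead of the logit-normal density sampling that is now commented out; the actual default path lives earlier in BaseSDTrainProcess.py and is not shown in this patch. The helper names themselves are hypothetical; names like noise_scheduler and named_params echo the toolkit code.

    import torch
    from diffusers import FluxTransformer2DModel

    def collect_guidance_embed_params(transformer: FluxTransformer2DModel,
                                      prefix: str = "transformer") -> dict:
        # Mirrors the @@ -2053 hunk: with guidance_embeds=True, the
        # time_text_embed module (guidance + timestep + text projections)
        # is added to the set of trainable parameters.
        named_params = {}
        if transformer.config.guidance_embeds:
            for name, param in transformer.time_text_embed.named_parameters(
                    recurse=True, prefix=prefix):
                named_params[name] = param
        return named_params

    def sample_default_timesteps(noise_scheduler, batch_size: int) -> torch.Tensor:
        # Assumed default: uniform indices over the scheduler's timestep table
        # (length num_train_timesteps, as the trainer code above also assumes),
        # replacing the logit-normal weighting for Flux.
        num_train_timesteps = noise_scheduler.config.num_train_timesteps
        timestep_indices = torch.randint(0, num_train_timesteps, (batch_size,))
        timesteps = [noise_scheduler.timesteps[i.item()] for i in timestep_indices]
        return torch.stack(timesteps, dim=0)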