diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 5fd4dd76..f198e126 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -31,7 +31,7 @@ from jobs.process import BaseSDTrainProcess
 from torchvision import transforms
 from diffusers import EMAModel
 import math
-
+from toolkit.train_tools import precondition_model_outputs_flow_match
 
 
 def flush():
@@ -328,14 +328,24 @@ class SDTrainer(BaseSDTrainProcess):
                 # v-parameterization training
                 target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
-            elif self.sd.is_rectified_flow:
+            elif self.sd.is_flow_matching:
                 # only if preconditioning model outputs
                 # if not preconditioning, (target = noise - batch.latents)
-
-                # target = noise - batch.latents
-                # if preconditioning outputs, target latents
+                # model_pred = model_pred * (-sigmas) + noisy_model_input
+                if self.train_config.target_noise_multiplier != 1.0:
+                    # we are adjusting the target noise, need to recompute the noisy latents with
+                    # the noise adjusted above
+                    with torch.no_grad():
+                        noisy_latents = self.sd.add_noise(batch.latents, noise, timesteps).detach()
+
+                noise_pred = precondition_model_outputs_flow_match(
+                    noise_pred,
+                    noisy_latents,
+                    timesteps,
+                    self.sd.noise_scheduler
+                )
                 target = batch.latents.detach()
             else:
                 target = noise
@@ -383,7 +393,7 @@ class SDTrainer(BaseSDTrainProcess):
                 loss = loss_per_element
             else:
                 # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
-                if self.sd.is_rectified_flow and prior_pred is None:
+                if self.sd.is_flow_matching and prior_pred is None:
                     # outputs should be preprocessed latents
                     sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
                     weighting = torch.ones_like(sigmas)
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e1dba7b8..853e8177 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -959,7 +959,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             raise ValueError(f"Unknown content_or_style {content_or_style}")
 
         # do flow matching
-        # if self.sd.is_rectified_flow:
+        # if self.sd.is_flow_matching:
         #     u = compute_density_for_timestep_sampling(
         #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
         #         batch_size=batch_size,
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 0f6d23c2..f3426714 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -166,9 +166,9 @@ class StableDiffusion:
 
         self.config_file = None
 
-        self.is_rectified_flow = False
+        self.is_flow_matching = False
         if self.is_flux or self.is_v3 or self.is_auraflow:
-            self.is_rectified_flow = True
+            self.is_flow_matching = True
 
     def load_model(self):
         if self.is_loaded:
@@ -1337,20 +1337,6 @@ class StableDiffusion:
             )
             return torch.cat(out_chunks, dim=0)
 
-        def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
-            mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
-            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
-            timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
-            out_chunks = []
-            # unsqueeze if timestep is zero dim
-            for idx in range(model_output.shape[0]):
-                sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
-                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
-                # Preconditioning of the model outputs.
-                out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
-                out_chunks.append(out)
-            return torch.cat(out_chunks, dim=0)
-
         if self.is_xl:
             with torch.no_grad():
                 # 16, 6 for bs of 4
@@ -1614,8 +1600,6 @@ class StableDiffusion:
                 width=width,  # 1024
                 vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
             )
-
-            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
         elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -1624,7 +1608,6 @@ class StableDiffusion:
                 pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
                 **kwargs,
             ).sample
-            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
         elif self.is_auraflow:
             # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 83b88444..6b906fff 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -1008,3 +1008,19 @@ def apply_snr_weight(
 
     snr_adjusted_loss = loss * snr_weight
     return snr_adjusted_loss
+
+
+def precondition_model_outputs_flow_match(model_output, model_input, timestep_tensor, noise_scheduler):
+    mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
+    mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+    timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+    out_chunks = []
+    # unsqueeze if timestep is zero dim
+    for idx in range(model_output.shape[0]):
+        sigmas = noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim,
+                                            dtype=model_output.dtype, device=model_output.device)
+        # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+        # Preconditioning of the model outputs.
+        out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
+        out_chunks.append(out)
+    return torch.cat(out_chunks, dim=0)
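
Usage note: a minimal sketch of exercising the relocated helper on its own. The import path `toolkit.train_tools.precondition_model_outputs_flow_match` comes from this diff; `StubFlowScheduler` and its linear sigma schedule are hypothetical stand-ins for the `get_sigmas()` interface of `self.sd.noise_scheduler`, not part of the toolkit.

```python
import torch

from toolkit.train_tools import precondition_model_outputs_flow_match


class StubFlowScheduler:
    """Hypothetical stand-in for the trainer's self.sd.noise_scheduler.

    Only mimics the get_sigmas() signature that
    precondition_model_outputs_flow_match() relies on; the toy linear
    schedule below is an assumption, not the toolkit's real schedule.
    """

    def get_sigmas(self, timesteps, n_dim, dtype, device):
        # toy schedule: sigma = t / 1000, reshaped so it broadcasts
        # over a (1, C, H, W) latent chunk
        sigmas = timesteps.to(device=device, dtype=dtype) / 1000.0
        return sigmas.reshape(-1, *([1] * (n_dim - 1)))


model_output = torch.randn(2, 16, 64, 64)   # raw transformer prediction
noisy_latents = torch.randn(2, 16, 64, 64)  # the x_t fed to the model
timesteps = torch.tensor([500, 250])        # one timestep per batch element

# pred * (-sigma) + x_t: for a rectified-flow velocity prediction
# (noise - x0) this recovers an x0-style sample prediction, which is why
# the trainer can then use target = batch.latents.detach()
preconditioned = precondition_model_outputs_flow_match(
    model_output, noisy_latents, timesteps, StubFlowScheduler()
)
assert preconditioned.shape == model_output.shape
```

Passing the scheduler in as an argument, rather than closing over `self.noise_scheduler` as the removed method-local `precondition_model_outputs_sd3` did, is what lets `SDTrainer` and `StableDiffusion` share a single implementation.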