diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index ae0bb153..08089d85 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -327,6 +327,11 @@ class SDTrainer(BaseSDTrainProcess):
         elif self.sd.prediction_type == 'v_prediction':
             # v-parameterization training
             target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
+
+        elif self.sd.is_rectified_flow:
+            # only if preconditioning model outputs
+            # if not preconditioning, (target = noise - batch.latents) is used
+            target = batch.latents.detach()
         else:
             target = noise
 
@@ -373,26 +378,10 @@ class SDTrainer(BaseSDTrainProcess):
             loss = loss_per_element
         else:
             # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
-            if self.sd.is_v3:
-                target = noisy_latents.detach()
-                bsz = pred.shape[0]
-                # todo implement others
-                # weighing_scheme =
-                # 3 just do mode for now?
-                # if args.weighting_scheme == "sigma_sqrt":
+            if self.sd.is_rectified_flow and prior_pred is None:
+                # outputs should be preprocessed latents
                 sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
-                # weighting = (sigmas ** -2.0).float()
                 weighting = torch.ones_like(sigmas)
-                # elif args.weighting_scheme == "logit_normal":
-                #     # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                #     u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-                #     weighting = torch.nn.functional.sigmoid(u)
-                # elif args.weighting_scheme == "mode":
-                #     mode_scale = 1.29
-                #     # See sec 3.1 in the SD3 paper (20).
-                #     u = torch.rand(size=(bsz,), device=pred.device)
-                #     weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-
                 loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
 
         elif self.train_config.loss_type == "mae":
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index cd90808b..fd89ce37 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -11,6 +11,7 @@ from typing import Union, List, Optional
 import numpy as np
 import yaml
 from diffusers import T2IAdapter, ControlNetModel
+from diffusers.training_utils import compute_density_for_timestep_sampling
 from safetensors.torch import save_file, load_file
 # from lycoris.config import PRESET
 from torch.utils.data import DataLoader
@@ -957,6 +958,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             else:
                 raise ValueError(f"Unknown content_or_style {content_or_style}")
 
+            # do flow matching
+            if self.sd.is_rectified_flow:
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+                    batch_size=batch_size,
+                    logit_mean=0.0,
+                    logit_std=1.0,
+                    mode_scale=1.29,
+                )
+                timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
             # convert the timestep_indices to a timestep
             timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
             timesteps = torch.stack(timesteps, dim=0)
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index 1d5750ad..7a513758 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -23,9 +23,16 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
         noise: torch.Tensor,
         timesteps: torch.Tensor,
     ) -> torch.Tensor:
+        ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
+        ## Add noise according to flow matching.
+        ## zt = (1 - texp) * x + texp * z1
+
+        # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+        # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
         n_dim = original_samples.ndim
         sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-        noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
+        noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
         return noisy_model_input
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 8e7c2c2a..ca329cf5 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -166,6 +166,10 @@ class StableDiffusion:
 
         self.config_file = None
 
+        self.is_rectified_flow = False
+        if self.is_flux or self.is_v3 or self.is_auraflow:
+            self.is_rectified_flow = True
+
     def load_model(self):
         if self.is_loaded:
             return
@@ -448,7 +452,7 @@ class StableDiffusion:
 
         elif self.model_config.is_flux:
             print("Loading Flux model")
-            base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
+            base_model_path = "black-forest-labs/FLUX.1-schnell"
             scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
             print("Loading vae")
             vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
@@ -1223,14 +1227,6 @@ class StableDiffusion:
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor
     ) -> torch.FloatTensor:
-        # we handle adding noise for the various schedulers here. Some
-        # schedulers reference timesteps while others reference idx
-        # so we need to handle both cases
-        # get scheduler class name
-        scheduler_class_name = self.noise_scheduler.__class__.__name__
-
-        # todo handle if timestep is single value
-
         original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
         noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
         timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
@@ -1582,7 +1578,7 @@ class StableDiffusion:
             #     sigmas,
             #     mu=mu,
             # )
-            latent_model_input = self.pipeline._pack_latents(
+            latent_model_input_packed = self.pipeline._pack_latents(
                 latent_model_input,
                 batch_size=latent_model_input.shape[0],
                 num_channels_latents=latent_model_input.shape[1],  # 16
@@ -1592,7 +1588,7 @@ class StableDiffusion:
 
 
             noise_pred = self.unet(
-                hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
+                hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000,  # timestep is 1000 scale
@@ -1609,9 +1605,12 @@ class StableDiffusion:
             noise_pred = self.pipeline._unpack_latents(
                 noise_pred,
                 height=height,  # 1024
-                width=height,  # 1024
+                width=width,  # 1024
                 vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
             )
+
+            # todo we do this on sd3 training. I think we do it here too? No paper
+            noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
         elif self.is_v3:
             noise_pred = self.unet(
                 hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -2039,25 +2038,25 @@ class StableDiffusion:
         if unet:
             if self.is_flux:
                 # Just train the middle 2 blocks of each transformer block
-                block_list = []
-                num_transformer_blocks = 2
-                start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
-                for i in range(num_transformer_blocks):
-                    block_list.append(self.unet.transformer_blocks[start_block + i])
+                # block_list = []
+                # num_transformer_blocks = 2
+                # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
+                # for i in range(num_transformer_blocks):
+                #     block_list.append(self.unet.transformer_blocks[start_block + i])
+
+                # num_single_transformer_blocks = 4
+                # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
+                # for i in range(num_single_transformer_blocks):
+                #     block_list.append(self.unet.single_transformer_blocks[start_block + i])
+
+                # for block in block_list:
+                #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                #         named_params[name] = param
 
-                num_single_transformer_blocks = 4
-                start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
-                for i in range(num_single_transformer_blocks):
-                    block_list.append(self.unet.single_transformer_blocks[start_block + i])
-
-                for block in block_list:
-                    for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                        named_params[name] = param
-
-                # for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                #     named_params[name] = param
-                # for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                #     named_params[name] = param
+                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                    named_params[name] = param
+                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+                    named_params[name] = param
             else:
                 for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                     named_params[name] = param
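
Note: the pieces above combine into one rectified-flow training step: BaseSDTrainProcess samples timesteps from a logit-normal density, CustomFlowMatchEulerDiscreteScheduler.add_noise interpolates zt = (1 - sigma) * x0 + sigma * noise, and SDTrainer computes a weighted MSE against the flow target. Below is a minimal, self-contained sketch of that step in plain torch; it is not the repo's actual API, and `model` is a hypothetical stand-in for the transformer forward pass.

import torch

def rectified_flow_training_step(model, latents, num_train_timesteps=1000,
                                 logit_mean=0.0, logit_std=1.0):
    # sample u in (0, 1) with a logit-normal density, mirroring
    # compute_density_for_timestep_sampling(weighting_scheme="logit_normal", ...)
    bsz = latents.shape[0]
    u = torch.sigmoid(torch.randn(bsz, device=latents.device) * logit_std + logit_mean)
    timesteps = (u * num_train_timesteps).long().clamp(max=num_train_timesteps - 1)

    # for rectified flow the noise level is just the normalized timestep
    sigmas = (timesteps.float() / num_train_timesteps).view(bsz, *([1] * (latents.ndim - 1)))

    # zt = (1 - sigma) * x0 + sigma * z1, as in CustomFlowMatchEulerDiscreteScheduler.add_noise
    noise = torch.randn_like(latents)
    noisy_latents = (1.0 - sigmas) * latents + sigmas * noise

    # without output preconditioning, the target is the flow direction noise - x0
    target = noise - latents
    pred = model(noisy_latents, timesteps)  # stand-in for the DiT/transformer call
    return ((pred.float() - target.float()) ** 2).mean()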