Flux training should work now... maybe

This commit is contained in:
Jaret Burkett
2024-08-03 09:17:34 -06:00
parent 369aa143bc
commit 9beea1c268
4 changed files with 55 additions and 49 deletions

View File

@@ -327,6 +327,11 @@ class SDTrainer(BaseSDTrainProcess):
elif self.sd.prediction_type == 'v_prediction':
# v-parameterization training
target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
elif self.sd.is_rectified_flow:
# only if preconditioning model outputs
# if not preconditioning, (target = noise - batch.latents) is used
target = batch.latents.detach()
else:
target = noise
@@ -373,26 +378,10 @@ class SDTrainer(BaseSDTrainProcess):
loss = loss_per_element
else:
# handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
if self.sd.is_v3:
target = noisy_latents.detach()
bsz = pred.shape[0]
# todo implement others
# weighing_scheme =
# 3 just do mode for now?
# if args.weighting_scheme == "sigma_sqrt":
if self.sd.is_rectified_flow and prior_pred is None:
# outputs should be preprocessed latents
sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
# weighting = (sigmas ** -2.0).float()
weighting = torch.ones_like(sigmas)
# elif args.weighting_scheme == "logit_normal":
# # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
# u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
# weighting = torch.nn.functional.sigmoid(u)
# elif args.weighting_scheme == "mode":
# mode_scale = 1.29
# See sec 3.1 in the SD3 paper (20).
# u = torch.rand(size=(bsz,), device=pred.device)
# weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
elif self.train_config.loss_type == "mae":

View File

@@ -11,6 +11,7 @@ from typing import Union, List, Optional
import numpy as np
import yaml
from diffusers import T2IAdapter, ControlNetModel
from diffusers.training_utils import compute_density_for_timestep_sampling
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
@@ -957,6 +958,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
raise ValueError(f"Unknown content_or_style {content_or_style}")
# do flow matching
if self.sd.is_rectified_flow:
u = compute_density_for_timestep_sampling(
weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
batch_size=batch_size,
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)

View File

@@ -23,9 +23,16 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
n_dim = original_samples.ndim
sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:

View File

@@ -166,6 +166,10 @@ class StableDiffusion:
self.config_file = None
self.is_rectified_flow = False
if self.is_flux or self.is_v3 or self.is_auraflow:
self.is_rectified_flow = True
def load_model(self):
if self.is_loaded:
return
@@ -448,7 +452,7 @@ class StableDiffusion:
elif self.model_config.is_flux:
print("Loading Flux model")
base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
base_model_path = "black-forest-labs/FLUX.1-schnell"
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
@@ -1223,14 +1227,6 @@ class StableDiffusion:
noise: torch.FloatTensor,
timesteps: torch.IntTensor
) -> torch.FloatTensor:
# we handle adding noise for the various schedulers here. Some
# schedulers reference timesteps while others reference idx
# so we need to handle both cases
# get scheduler class name
scheduler_class_name = self.noise_scheduler.__class__.__name__
# todo handle if timestep is single value
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
@@ -1582,7 +1578,7 @@ class StableDiffusion:
# sigmas,
# mu=mu,
# )
latent_model_input = self.pipeline._pack_latents(
latent_model_input_packed = self.pipeline._pack_latents(
latent_model_input,
batch_size=latent_model_input.shape[0],
num_channels_latents=latent_model_input.shape[1], # 16
@@ -1592,7 +1588,7 @@ class StableDiffusion:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# todo make sure this doesnt change
timestep=timestep / 1000, # timestep is 1000 scale
@@ -1609,9 +1605,12 @@ class StableDiffusion:
noise_pred = self.pipeline._unpack_latents(
noise_pred,
height=height, # 1024
width=height, # 1024
width=width, # 1024
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
)
# todo we do this on sd3 training. I think we do it here too? No paper
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
elif self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -2039,25 +2038,25 @@ class StableDiffusion:
if unet:
if self.is_flux:
# Just train the middle 2 blocks of each transformer block
block_list = []
num_transformer_blocks = 2
start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
for i in range(num_transformer_blocks):
block_list.append(self.unet.transformer_blocks[start_block + i])
# block_list = []
# num_transformer_blocks = 2
# start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
# for i in range(num_transformer_blocks):
# block_list.append(self.unet.transformer_blocks[start_block + i])
#
# num_single_transformer_blocks = 4
# start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
# for i in range(num_single_transformer_blocks):
# block_list.append(self.unet.single_transformer_blocks[start_block + i])
#
# for block in block_list:
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
num_single_transformer_blocks = 4
start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
for i in range(num_single_transformer_blocks):
block_list.append(self.unet.single_transformer_blocks[start_block + i])
for block in block_list:
for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
# for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
# for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
else:
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param