mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Flux training should work now... maybe
@@ -327,6 +327,11 @@ class SDTrainer(BaseSDTrainProcess):
 elif self.sd.prediction_type == 'v_prediction':
     # v-parameterization training
     target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
+elif self.sd.is_rectified_flow:
+    # only if preconditioning model outputs
+    # if not preconditioning, (target = noise - batch.latents) is used
+    target = batch.latents.detach()
 else:
     target = noise
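For context, the rectified-flow branch regresses against the clean latents only when the model outputs are preconditioned; otherwise the usual flow-matching velocity target (noise minus latents) applies, as the comment notes. A minimal sketch of both targets, with generic tensor names (`latents`, `noise`) that are assumptions, not the trainer's actual variables:

    import torch

    def flow_matching_target(latents: torch.Tensor, noise: torch.Tensor,
                             precondition_outputs: bool) -> torch.Tensor:
        # preconditioned model outputs are compared directly against the clean latents
        if precondition_outputs:
            return latents.detach()
        # otherwise regress the velocity that carries data toward noise
        return (noise - latents).detach()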
@@ -373,26 +378,10 @@ class SDTrainer(BaseSDTrainProcess):
     loss = loss_per_element
 else:
     # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
-    if self.sd.is_v3:
-        target = noisy_latents.detach()
-        bsz = pred.shape[0]
-        # todo implement others
-        # weighing_scheme =
-        # 3 just do mode for now?
-        # if args.weighting_scheme == "sigma_sqrt":
+    if self.sd.is_rectified_flow and prior_pred is None:
+        # outputs should be preprocessed latents
         sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
         # weighting = (sigmas ** -2.0).float()
         weighting = torch.ones_like(sigmas)
-        # elif args.weighting_scheme == "logit_normal":
-        #     # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        #     u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
-        #     weighting = torch.nn.functional.sigmoid(u)
-        # elif args.weighting_scheme == "mode":
-        #     mode_scale = 1.29
-        #     # See sec 3.1 in the SD3 paper (20).
-        #     u = torch.rand(size=(bsz,), device=pred.device)
-        #     weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)

         loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)

 elif self.train_config.loss_type == "mae":
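The deleted commented-out block refers to the loss-weighting schemes used in the diffusers SD3 training scripts; the hunk above keeps only the uniform case (`torch.ones_like(sigmas)`). A hedged sketch of how such a per-sample weighted MSE could look; the scheme names and the helper itself are illustrative, not this trainer's API:

    import math
    import torch

    def weighted_flow_loss(pred: torch.Tensor, target: torch.Tensor,
                           sigmas: torch.Tensor, weighting_scheme: str = "none") -> torch.Tensor:
        if weighting_scheme == "sigma_sqrt":
            weighting = (sigmas ** -2.0).float()
        elif weighting_scheme == "cosmap":
            weighting = 2.0 / (math.pi * (1 - 2 * sigmas + 2 * sigmas ** 2))
        else:
            weighting = torch.ones_like(sigmas)  # plain MSE
        loss = weighting.float() * (pred.float() - target.float()) ** 2
        return loss.reshape(target.shape[0], -1).mean(dim=1)  # one loss value per sample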
@@ -11,6 +11,7 @@ from typing import Union, List, Optional
 import numpy as np
 import yaml
 from diffusers import T2IAdapter, ControlNetModel
+from diffusers.training_utils import compute_density_for_timestep_sampling
 from safetensors.torch import save_file, load_file
 # from lycoris.config import PRESET
 from torch.utils.data import DataLoader
@@ -957,6 +958,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
 else:
     raise ValueError(f"Unknown content_or_style {content_or_style}")

+# do flow matching
+if self.sd.is_rectified_flow:
+    u = compute_density_for_timestep_sampling(
+        weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+        batch_size=batch_size,
+        logit_mean=0.0,
+        logit_std=1.0,
+        mode_scale=1.29,
+    )
+    timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
 # convert the timestep_indices to a timestep
 timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
 timesteps = torch.stack(timesteps, dim=0)
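With `weighting_scheme="logit_normal"`, `compute_density_for_timestep_sampling` draws u in (0, 1) from a sigmoid-transformed normal, which biases sampling toward the middle of the schedule before scaling to integer timestep indices. A self-contained sketch of the same idea (the default of 1000 train timesteps is an assumption for illustration):

    import torch

    def sample_logit_normal_timestep_indices(batch_size: int,
                                              num_train_timesteps: int = 1000,
                                              logit_mean: float = 0.0,
                                              logit_std: float = 1.0) -> torch.Tensor:
        # u ~ sigmoid(N(logit_mean, logit_std)) lies in (0, 1), concentrated mid-schedule
        u = torch.sigmoid(torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,)))
        # scale to integer indices into the scheduler's timestep table
        return (u * num_train_timesteps).long()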
@@ -23,9 +23,16 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     noise: torch.Tensor,
     timesteps: torch.Tensor,
 ) -> torch.Tensor:
+    ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
+    ## Add noise according to flow matching.
+    ## zt = (1 - texp) * x + texp * z1
+
+    # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+    # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
     n_dim = original_samples.ndim
     sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
-    noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
+    noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
     return noisy_model_input

 def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
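The one-line change above only reorders mathematically identical terms: the flow-matching interpolation is zt = (1 - sigma) * x + sigma * noise either way. A tiny numeric check with made-up shapes:

    import torch

    x = torch.randn(2, 16, 32, 32)        # clean latents (illustrative shape)
    noise = torch.randn_like(x)           # z1, pure noise
    sigmas = torch.rand(2, 1, 1, 1)       # one sigma per sample, broadcast over C, H, W

    zt_a = sigmas * noise + (1.0 - sigmas) * x
    zt_b = (1.0 - sigmas) * x + sigmas * noise
    assert torch.allclose(zt_a, zt_b)     # sigma=0 gives clean data, sigma=1 gives pure noise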
@@ -166,6 +166,10 @@ class StableDiffusion:

     self.config_file = None

+    self.is_rectified_flow = False
+    if self.is_flux or self.is_v3 or self.is_auraflow:
+        self.is_rectified_flow = True
+
 def load_model(self):
     if self.is_loaded:
         return
@@ -448,7 +452,7 @@ class StableDiffusion:

 elif self.model_config.is_flux:
     print("Loading Flux model")
-    base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
+    base_model_path = "black-forest-labs/FLUX.1-schnell"
     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
     print("Loading vae")
     vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
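The path change swaps a hard-coded local checkout for the Hugging Face Hub id. A hedged sketch of loading the FLUX.1-schnell sub-components by subfolder in the same pattern (requires a diffusers version with Flux support; the transformer line is an assumption about which class is used, and bfloat16 is an illustrative choice):

    import torch
    from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel

    base_model_path = "black-forest-labs/FLUX.1-schnell"
    dtype = torch.bfloat16

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
    transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer", torch_dtype=dtype)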
@@ -1223,14 +1227,6 @@ class StableDiffusion:
     noise: torch.FloatTensor,
     timesteps: torch.IntTensor
 ) -> torch.FloatTensor:
-    # we handle adding noise for the various schedulers here. Some
-    # schedulers reference timesteps while others reference idx
-    # so we need to handle both cases
-    # get scheduler class name
-    scheduler_class_name = self.noise_scheduler.__class__.__name__
-
-    # todo handle if timestep is single value
-
     original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
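The surviving chunking lines split the batch into size-1 pieces so each sample can be noised with its own timestep. An illustrative version of that pattern with a generic callable (the helper name and signature are hypothetical, not the repo's code):

    import torch

    def add_noise_per_sample(samples, noise, timesteps, add_noise_fn):
        sample_chunks = torch.chunk(samples, samples.shape[0], dim=0)
        noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
        timestep_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
        # noise each sample independently, then stitch the batch back together
        noisy = [add_noise_fn(s, n, t)
                 for s, n, t in zip(sample_chunks, noise_chunks, timestep_chunks)]
        return torch.cat(noisy, dim=0)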
@@ -1582,7 +1578,7 @@ class StableDiffusion:
     #     sigmas,
     #     mu=mu,
     # )
-    latent_model_input = self.pipeline._pack_latents(
+    latent_model_input_packed = self.pipeline._pack_latents(
         latent_model_input,
         batch_size=latent_model_input.shape[0],
         num_channels_latents=latent_model_input.shape[1],  # 16
@@ -1592,7 +1588,7 @@ class StableDiffusion:

     noise_pred = self.unet(
-        hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
+        hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype),  # [1, 4096, 64]
         # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
         # todo make sure this doesn't change
         timestep=timestep / 1000,  # timestep is 1000 scale
@@ -1609,9 +1605,12 @@ class StableDiffusion:
     noise_pred = self.pipeline._unpack_latents(
         noise_pred,
         height=height,  # 1024
-        width=height,  # 1024
+        width=width,  # 1024
         vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
     )
+
+    # todo we do this on sd3 training. I think we do it here too? No paper
+    noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
 elif self.is_v3:
     noise_pred = self.unet(
         hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
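The `_pack_latents` / `_unpack_latents` calls around this hunk perform Flux's 2x2 patchify of the VAE latents, which is why the effective scale factor is the VAE's 8 times 2 = 16 and why a 1024px image becomes a [1, 4096, 64] sequence. A hedged sketch of the packing step only, with illustrative shapes (not the pipeline's actual method):

    import torch

    def pack_latents(latents: torch.Tensor) -> torch.Tensor:
        # [B, C, H, W] -> [B, (H/2)*(W/2), C*4] via non-overlapping 2x2 patches
        b, c, h, w = latents.shape
        x = latents.view(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(b, (h // 2) * (w // 2), c * 4)

    packed = pack_latents(torch.randn(1, 16, 128, 128))  # 1024px image / VAE factor 8
    print(packed.shape)  # torch.Size([1, 4096, 64])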
@@ -2039,25 +2038,25 @@ class StableDiffusion:
 if unet:
     if self.is_flux:
         # Just train the middle 2 blocks of each transformer block
-        block_list = []
-        num_transformer_blocks = 2
-        start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
-        for i in range(num_transformer_blocks):
-            block_list.append(self.unet.transformer_blocks[start_block + i])
-
-        num_single_transformer_blocks = 4
-        start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
-        for i in range(num_single_transformer_blocks):
-            block_list.append(self.unet.single_transformer_blocks[start_block + i])
-
-        for block in block_list:
-            for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-                named_params[name] = param
-
-        # for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-        #     named_params[name] = param
-        # for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
-        #     named_params[name] = param
+        # block_list = []
+        # num_transformer_blocks = 2
+        # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
+        # for i in range(num_transformer_blocks):
+        #     block_list.append(self.unet.transformer_blocks[start_block + i])
+        #
+        # num_single_transformer_blocks = 4
+        # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
+        # for i in range(num_single_transformer_blocks):
+        #     block_list.append(self.unet.single_transformer_blocks[start_block + i])
+        #
+        # for block in block_list:
+        #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+        #         named_params[name] = param
+
+        for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            named_params[name] = param
+        for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            named_params[name] = param
     else:
         for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
            named_params[name] = param
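For reference, the commented-out selective path trains only a middle slice of the double- and single-stream blocks. An illustrative helper with the same arithmetic (not part of the repo; the prefix format is made up):

    import torch.nn as nn

    def middle_blocks(blocks: nn.ModuleList, n: int) -> list:
        # pick n blocks centered in the module list
        start = len(blocks) // 2 - (n // 2)
        return [blocks[start + i] for i in range(n)]

    def collect_params(blocks: list, prefix: str) -> dict:
        named = {}
        for i, block in enumerate(blocks):
            for name, param in block.named_parameters(recurse=True, prefix=f"{prefix}.{i}"):
                named[name] = param
        return named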