mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Added support for full finetuning flux with randomized param activation. Examples coming soon
This commit is contained in:
@@ -56,8 +56,9 @@ import gc
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
|
||||||
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
|
GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
|
||||||
from toolkit.logging import create_logger
|
from toolkit.logging import create_logger
|
||||||
|
from diffusers import FluxTransformer2DModel
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@@ -201,6 +202,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.named_lora = True
|
self.named_lora = True
|
||||||
self.snr_gos: Union[LearnableSNRGamma, None] = None
|
self.snr_gos: Union[LearnableSNRGamma, None] = None
|
||||||
self.ema: ExponentialMovingAverage = None
|
self.ema: ExponentialMovingAverage = None
|
||||||
|
|
||||||
|
validate_configs(self.train_config, self.model_config, self.save_config)
|
||||||
|
|
||||||
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
|
||||||
# override in subclass
|
# override in subclass
|
||||||
@@ -587,9 +590,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
def hook_before_train_loop(self):
|
def hook_before_train_loop(self):
|
||||||
self.logger.start()
|
self.logger.start()
|
||||||
|
|
||||||
def ensure_params_requires_grad(self):
|
def ensure_params_requires_grad(self, force=False):
|
||||||
# get param groups
|
if self.train_config.do_paramiter_swapping and not force:
|
||||||
# for group in self.optimizer.param_groups:
|
# the optimizer will handle this if we are not forcing
|
||||||
|
return
|
||||||
for group in self.params:
|
for group in self.params:
|
||||||
for param in group['params']:
|
for param in group['params']:
|
||||||
if isinstance(param, torch.nn.Parameter): # Ensure it's a proper parameter
|
if isinstance(param, torch.nn.Parameter): # Ensure it's a proper parameter
|
||||||
@@ -1278,6 +1282,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
torch.backends.cuda.enable_math_sdp(True)
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
torch.backends.cuda.enable_flash_sdp(True)
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
||||||
|
|
||||||
|
# # check if we have sage and is flux
|
||||||
|
# if self.sd.is_flux:
|
||||||
|
# # try_to_activate_sage_attn()
|
||||||
|
# try:
|
||||||
|
# from sageattention import sageattn
|
||||||
|
# from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0
|
||||||
|
# model: FluxTransformer2DModel = self.sd.unet
|
||||||
|
# # enable sage attention on each block
|
||||||
|
# for block in model.transformer_blocks:
|
||||||
|
# processor = FluxSageAttnProcessor2_0()
|
||||||
|
# block.attn.set_processor(processor)
|
||||||
|
# for block in model.single_transformer_blocks:
|
||||||
|
# processor = FluxSageAttnProcessor2_0()
|
||||||
|
# block.attn.set_processor(processor)
|
||||||
|
|
||||||
|
# except ImportError:
|
||||||
|
# print("sage attention is not installed. Using SDP instead")
|
||||||
|
|
||||||
if self.train_config.gradient_checkpointing:
|
if self.train_config.gradient_checkpointing:
|
||||||
if self.sd.is_flux:
|
if self.sd.is_flux:
|
||||||
@@ -1539,10 +1561,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
optimizer_type = self.train_config.optimizer.lower()
|
optimizer_type = self.train_config.optimizer.lower()
|
||||||
|
|
||||||
# esure params require grad
|
# esure params require grad
|
||||||
self.ensure_params_requires_grad()
|
self.ensure_params_requires_grad(force=True)
|
||||||
optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
|
optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
|
||||||
optimizer_params=self.train_config.optimizer_params)
|
optimizer_params=self.train_config.optimizer_params)
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
|
||||||
|
# set it to do paramiter swapping
|
||||||
|
if self.train_config.do_paramiter_swapping:
|
||||||
|
# only works for adafactor, but it should have thrown an error prior to this otherwise
|
||||||
|
self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor)
|
||||||
|
|
||||||
# check if it exists
|
# check if it exists
|
||||||
optimizer_state_filename = f'optimizer.pt'
|
optimizer_state_filename = f'optimizer.pt'
|
||||||
@@ -1648,7 +1675,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# torch.compile(self.sd.unet, dynamic=True)
|
# torch.compile(self.sd.unet, dynamic=True)
|
||||||
|
|
||||||
# make sure all params require grad
|
# make sure all params require grad
|
||||||
self.ensure_params_requires_grad()
|
self.ensure_params_requires_grad(force=True)
|
||||||
|
|
||||||
|
|
||||||
###################################################################
|
###################################################################
|
||||||
@@ -1659,6 +1686,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
start_step_num = self.step_num
|
start_step_num = self.step_num
|
||||||
did_first_flush = False
|
did_first_flush = False
|
||||||
for step in range(start_step_num, self.train_config.steps):
|
for step in range(start_step_num, self.train_config.steps):
|
||||||
|
if self.train_config.do_paramiter_swapping:
|
||||||
|
self.optimizer.swap_paramiters()
|
||||||
self.timer.start('train_loop')
|
self.timer.start('train_loop')
|
||||||
if self.train_config.do_random_cfg:
|
if self.train_config.do_random_cfg:
|
||||||
self.train_config.do_cfg = True
|
self.train_config.do_cfg = True
|
||||||
@@ -1738,6 +1767,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
|
|
||||||
# flush()
|
# flush()
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
|
|
||||||
loss_dict = self.hook_train_loop(batch_list)
|
loss_dict = self.hook_train_loop(batch_list)
|
||||||
self.timer.stop('train_loop')
|
self.timer.stop('train_loop')
|
||||||
if not did_first_flush:
|
if not did_first_flush:
|
||||||
|
|||||||
@@ -389,6 +389,10 @@ class TrainConfig:
|
|||||||
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu
|
# will cache a blank prompt or the trigger word, and unload the text encoder to cpu
|
||||||
# will make training faster and use less vram
|
# will make training faster and use less vram
|
||||||
self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
|
self.unload_text_encoder = kwargs.get('unload_text_encoder', False)
|
||||||
|
# for swapping which parameters are trained during training
|
||||||
|
self.do_paramiter_swapping = kwargs.get('do_paramiter_swapping', False)
|
||||||
|
# 0.1 is 10% of the parameters active at a time lower is less vram, higher is more
|
||||||
|
self.paramiter_swapping_factor = kwargs.get('paramiter_swapping_factor', 0.1)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
@@ -898,4 +902,16 @@ class GenerateImageConfig:
|
|||||||
if self.logger is None:
|
if self.logger is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.logger.log_image(image, count, self.prompt)
|
self.logger.log_image(image, count, self.prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_configs(
|
||||||
|
train_config: TrainConfig,
|
||||||
|
model_config: ModelConfig,
|
||||||
|
save_config: SaveConfig,
|
||||||
|
):
|
||||||
|
if model_config.is_flux:
|
||||||
|
if save_config.save_format != 'diffusers':
|
||||||
|
# make it diffusers
|
||||||
|
save_config.save_format = 'diffusers'
|
||||||
|
|
||||||
94
toolkit/models/flux_sage_attn.py
Normal file
94
toolkit/models/flux_sage_attn.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from diffusers.models.attention_processor import Attention
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class FluxSageAttnProcessor2_0:
|
||||||
|
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: Attention,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
encoder_hidden_states: torch.FloatTensor = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
from sageattention import sageattn
|
||||||
|
|
||||||
|
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
|
||||||
|
# `sample` projections.
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
key = attn.to_k(hidden_states)
|
||||||
|
value = attn.to_v(hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if attn.norm_q is not None:
|
||||||
|
query = attn.norm_q(query)
|
||||||
|
if attn.norm_k is not None:
|
||||||
|
key = attn.norm_k(key)
|
||||||
|
|
||||||
|
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
# `context` projections.
|
||||||
|
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||||
|
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||||
|
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||||
|
|
||||||
|
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||||
|
batch_size, -1, attn.heads, head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||||
|
batch_size, -1, attn.heads, head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||||
|
batch_size, -1, attn.heads, head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
if attn.norm_added_q is not None:
|
||||||
|
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||||
|
if attn.norm_added_k is not None:
|
||||||
|
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||||
|
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||||
|
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||||
|
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
from diffusers.models.embeddings import apply_rotary_emb
|
||||||
|
|
||||||
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
|
hidden_states = sageattn(query, key, value, dropout_p=0.0, is_causal=False)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states, hidden_states = (
|
||||||
|
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||||
|
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear proj
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
# dropout
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, encoder_hidden_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
@@ -3,6 +3,7 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
|
from toolkit.optimizers.optimizer_utils import copy_stochastic, stochastic_grad_accummulation
|
||||||
from optimum.quanto import QBytesTensor
|
from optimum.quanto import QBytesTensor
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class Adafactor(torch.optim.Optimizer):
|
class Adafactor(torch.optim.Optimizer):
|
||||||
@@ -105,6 +106,8 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
scale_parameter=True,
|
scale_parameter=True,
|
||||||
relative_step=True,
|
relative_step=True,
|
||||||
warmup_init=False,
|
warmup_init=False,
|
||||||
|
do_paramiter_swapping=False,
|
||||||
|
paramiter_swapping_factor=0.1,
|
||||||
):
|
):
|
||||||
if lr is not None and relative_step:
|
if lr is not None and relative_step:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -140,6 +143,49 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
param.register_post_accumulate_grad_hook(
|
param.register_post_accumulate_grad_hook(
|
||||||
stochastic_grad_accummulation
|
stochastic_grad_accummulation
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.do_paramiter_swapping = do_paramiter_swapping
|
||||||
|
self.paramiter_swapping_factor = paramiter_swapping_factor
|
||||||
|
self._total_paramiter_size = 0
|
||||||
|
# count total paramiters
|
||||||
|
for group in self.param_groups:
|
||||||
|
for param in group['params']:
|
||||||
|
self._total_paramiter_size += torch.numel(param)
|
||||||
|
# pretty print total paramiters with comma seperation
|
||||||
|
print(f"Total training paramiters: {self._total_paramiter_size:,}")
|
||||||
|
|
||||||
|
# needs to be enabled to count paramiters
|
||||||
|
if self.do_paramiter_swapping:
|
||||||
|
self.enable_paramiter_swapping(self.paramiter_swapping_factor)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
|
||||||
|
self.do_paramiter_swapping = True
|
||||||
|
self.paramiter_swapping_factor = paramiter_swapping_factor
|
||||||
|
# call it an initial time
|
||||||
|
self.swap_paramiters()
|
||||||
|
|
||||||
|
def swap_paramiters(self):
|
||||||
|
all_params = []
|
||||||
|
# deactivate all paramiters
|
||||||
|
for group in self.param_groups:
|
||||||
|
for param in group['params']:
|
||||||
|
param.requires_grad_(False)
|
||||||
|
# remove any grad
|
||||||
|
param.grad = None
|
||||||
|
all_params.append(param)
|
||||||
|
# shuffle all paramiters
|
||||||
|
random.shuffle(all_params)
|
||||||
|
|
||||||
|
# keep activating paramiters until we are going to go over the target paramiters
|
||||||
|
target_paramiters = int(self._total_paramiter_size * self.paramiter_swapping_factor)
|
||||||
|
total_paramiters = 0
|
||||||
|
for param in all_params:
|
||||||
|
total_paramiters += torch.numel(param)
|
||||||
|
if total_paramiters >= target_paramiters:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
param.requires_grad_(True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_lr(param_group, param_state):
|
def _get_lr(param_group, param_state):
|
||||||
@@ -209,7 +255,7 @@ class Adafactor(torch.optim.Optimizer):
|
|||||||
|
|
||||||
for group in self.param_groups:
|
for group in self.param_groups:
|
||||||
for p in group["params"]:
|
for p in group["params"]:
|
||||||
if p.grad is None:
|
if p.grad is None or not p.requires_grad:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
grad = p.grad
|
grad = p.grad
|
||||||
|
|||||||
Reference in New Issue
Block a user