mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Added support for full finetuning flux with randomized param activation. Examples coming soon
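The feature named in the commit message, randomized param activation, makes full finetuning of Flux feasible by letting only a randomly chosen slice of the parameters train at any given moment. The hunks below wire that in through a do_paramiter_swapping flag, a force argument on ensure_params_requires_grad, and enable_paramiter_swapping / swap_paramiters calls on the optimizer. A minimal toy sketch of the underlying idea (hypothetical names, not ai-toolkit's code):

# Toy sketch of "randomized param activation" (hypothetical, not ai-toolkit's code):
# only a random subset of parameters requires grad on a given step, so gradients are
# materialized only for that subset. With a lazily allocated optimizer such as Adam,
# moment buffers also exist only for parameters that have actually received gradients,
# which is where the memory saving for a full finetune comes from.
import random
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(32, 32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 1)
)
params = list(model.parameters())
opt = torch.optim.Adam(params, lr=1e-3)

def activate_random_subset(params, fraction=0.25):
    for p in params:
        p.requires_grad_(False)
    for p in random.sample(params, max(1, int(len(params) * fraction))):
        p.requires_grad_(True)

activate_random_subset(params)
loss = model(torch.randn(4, 32)).pow(2).mean()
loss.backward()
opt.step()
# Adam has only allocated exp_avg / exp_avg_sq for the currently active parameters
print(f"optimizer state entries: {len(opt.state)} / {len(params)} parameters")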
@@ -56,8 +56,9 @@ import gc
 from tqdm import tqdm

 from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig
+    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs
 from toolkit.logging import create_logger
+from diffusers import FluxTransformer2DModel

 def flush():
     torch.cuda.empty_cache()
@@ -201,6 +202,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.named_lora = True
         self.snr_gos: Union[LearnableSNRGamma, None] = None
         self.ema: ExponentialMovingAverage = None
+
+        validate_configs(self.train_config, self.model_config, self.save_config)

     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
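The constructor now calls the newly imported validate_configs so that incompatible settings fail at startup instead of mid-run. Its actual checks live in toolkit.config_modules and are not part of this diff; a hypothetical sketch of the kind of rule it could enforce, based on the comment in a later hunk that parameter swapping "only works for adafactor":

# Hypothetical sketch of a cross-config validator; the real validate_configs in
# toolkit.config_modules may enforce different rules. The later hunk's comment
# ("only works for adafactor, but it should have thrown an error prior to this")
# suggests this is roughly where such an error would be raised.
def validate_configs(train_config, model_config, save_config):
    if getattr(train_config, "do_paramiter_swapping", False):
        if "adafactor" not in train_config.optimizer.lower():
            raise ValueError("do_paramiter_swapping requires the adafactor optimizer")
        factor = getattr(train_config, "paramiter_swapping_factor", 0.1)  # assumed default
        if not 0.0 < factor <= 1.0:
            raise ValueError("paramiter_swapping_factor must be in (0, 1]")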
@@ -587,9 +590,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def hook_before_train_loop(self):
         self.logger.start()

-    def ensure_params_requires_grad(self):
-        # get param groups
-        # for group in self.optimizer.param_groups:
+    def ensure_params_requires_grad(self, force=False):
+        if self.train_config.do_paramiter_swapping and not force:
+            # the optimizer will handle this if we are not forcing
+            return
         for group in self.params:
             for param in group['params']:
                 if isinstance(param, torch.nn.Parameter):  # Ensure it's a proper parameter
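The force flag exists because, once parameter swapping is enabled, the optimizer owns the requires_grad flags and this helper must not blindly re-enable everything on every call; callers that really do need every parameter trainable (for example right before the optimizer is built, as in the later hunks) pass force=True. A standalone toy illustration of that guard (hypothetical, not the trainer's code):

# Toy illustration of the force semantics above (hypothetical, standalone): when
# swapping is active the optimizer manages requires_grad itself, so the helper is a
# no-op unless explicitly forced.
import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(8)]

def ensure_params_requires_grad(params, swapping_active, force=False):
    if swapping_active and not force:
        return  # the swapping optimizer owns the requires_grad flags
    for p in params:
        if isinstance(p, torch.nn.Parameter):
            p.requires_grad_(True)

# pretend the swapping optimizer has deactivated most parameters
for p in params[2:]:
    p.requires_grad_(False)

ensure_params_requires_grad(params, swapping_active=True)              # no-op
ensure_params_requires_grad(params, swapping_active=True, force=True)  # re-enables all
assert all(p.requires_grad for p in params)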
@@ -1278,6 +1282,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
         torch.backends.cuda.enable_math_sdp(True)
         torch.backends.cuda.enable_flash_sdp(True)
         torch.backends.cuda.enable_mem_efficient_sdp(True)

+        # # check if we have sage and is flux
+        # if self.sd.is_flux:
+        #     # try_to_activate_sage_attn()
+        #     try:
+        #         from sageattention import sageattn
+        #         from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0
+        #         model: FluxTransformer2DModel = self.sd.unet
+        #         # enable sage attention on each block
+        #         for block in model.transformer_blocks:
+        #             processor = FluxSageAttnProcessor2_0()
+        #             block.attn.set_processor(processor)
+        #         for block in model.single_transformer_blocks:
+        #             processor = FluxSageAttnProcessor2_0()
+        #             block.attn.set_processor(processor)
+
+        #     except ImportError:
+        #         print("sage attention is not installed. Using SDP instead")
+
         if self.train_config.gradient_checkpointing:
             if self.sd.is_flux:
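The SageAttention block above is left commented out, but the "# try_to_activate_sage_attn()" note hints at the intended shape: an optional-dependency helper that swaps a SageAttention processor into every Flux transformer block and falls back to the SDP backends enabled just above when the package is missing. A sketch of such a helper, reusing only names that appear in the commented code (not the repo's implementation):

# Sketch of the helper hinted at by "# try_to_activate_sage_attn()" above (hypothetical;
# not ai-toolkit's implementation). It swaps a SageAttention processor into every Flux
# block when the optional dependency is present, otherwise keeps the default SDP path.
def try_to_activate_sage_attn(transformer) -> bool:
    try:
        from sageattention import sageattn  # noqa: F401  (optional dependency check)
        from toolkit.models.flux_sage_attn import FluxSageAttnProcessor2_0
    except ImportError:
        print("sage attention is not installed. Using SDP instead")
        return False
    # enable sage attention on each double and single transformer block
    for block in list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks):
        block.attn.set_processor(FluxSageAttnProcessor2_0())
    return True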
@@ -1539,10 +1561,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         optimizer_type = self.train_config.optimizer.lower()

         # esure params require grad
-        self.ensure_params_requires_grad()
+        self.ensure_params_requires_grad(force=True)
         optimizer = get_optimizer(self.params, optimizer_type, learning_rate=self.train_config.lr,
                                   optimizer_params=self.train_config.optimizer_params)
         self.optimizer = optimizer
+
+        # set it to do paramiter swapping
+        if self.train_config.do_paramiter_swapping:
+            # only works for adafactor, but it should have thrown an error prior to this otherwise
+            self.optimizer.enable_paramiter_swapping(self.train_config.paramiter_swapping_factor)

         # check if it exists
         optimizer_state_filename = f'optimizer.pt'
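enable_paramiter_swapping and the per-step swap_paramiters call in a later hunk are methods on the repo's Adafactor-based optimizer; their implementation is not part of this diff. A rough sketch of the mechanism those names describe, assuming the swapping factor is the fraction of parameters kept active at a time (hypothetical class built on SGD purely for illustration):

# Hypothetical sketch of the swapping API this diff calls (enable_paramiter_swapping /
# swap_paramiters); the real logic lives in ai-toolkit's adafactor optimizer and is not
# shown here. Each swap re-randomizes which parameters require grad.
import random
import torch

class ParamSwappingSGD(torch.optim.SGD):
    def enable_paramiter_swapping(self, paramiter_swapping_factor=0.1):
        self._swap_factor = paramiter_swapping_factor
        self.swap_paramiters()

    def swap_paramiters(self):
        all_params = [p for g in self.param_groups for p in g["params"]]
        for p in all_params:
            p.requires_grad_(False)
        k = max(1, int(len(all_params) * self._swap_factor))
        for p in random.sample(all_params, k):
            p.requires_grad_(True)

# usage mirroring the diff: enable once after construction, swap every step
model = torch.nn.Linear(8, 8)
opt = ParamSwappingSGD(model.parameters(), lr=1e-2)
opt.enable_paramiter_swapping(0.5)
for step in range(3):
    opt.swap_paramiters()                  # swap before forward/backward
    loss = model(torch.randn(2, 8)).pow(2).mean()
    loss.backward()                        # grads only for active params
    opt.step()                             # params without grads are skipped
    opt.zero_grad(set_to_none=True)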
@@ -1648,7 +1675,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # torch.compile(self.sd.unet, dynamic=True)

         # make sure all params require grad
-        self.ensure_params_requires_grad()
+        self.ensure_params_requires_grad(force=True)


         ###################################################################
@@ -1659,6 +1686,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            if self.train_config.do_paramiter_swapping:
+                self.optimizer.swap_paramiters()
             self.timer.start('train_loop')
             if self.train_config.do_random_cfg:
                 self.train_config.do_cfg = True
@@ -1738,6 +1767,7 @@ class BaseSDTrainProcess(BaseTrainProcess):

            # flush()
            ### HOOK ###

            loss_dict = self.hook_train_loop(batch_list)
            self.timer.stop('train_loop')
            if not did_first_flush: