mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Improved lorm extraction and training
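The toolkit.lorm implementation itself is not part of this diff, so as orientation only: a minimal sketch of what LoRM-style extraction presumably does, i.e. replacing a full-rank nn.Linear with a pair of low-rank factors obtained from a truncated SVD. The function name and rank handling here are assumptions, not the toolkit's API.

    import torch

    def extract_lorm_linear(linear: torch.nn.Linear, rank: int) -> torch.nn.Sequential:
        # Factor W (out x in) as U_r @ (S_r @ Vh_r) via truncated SVD, so the
        # down/up pair approximates the original layer with fewer parameters.
        U, S, Vh = torch.linalg.svd(linear.weight.data, full_matrices=False)
        down = torch.nn.Linear(linear.in_features, rank, bias=False)
        up = torch.nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
        down.weight.data = (torch.diag(S[:rank]) @ Vh[:rank]).contiguous()
        up.weight.data = U[:, :rank].contiguous()
        if linear.bias is not None:
            up.bias.data = linear.bias.data.clone()  # bias survives, cf. use_bias=is_lorm below
        return torch.nn.Sequential(down, up)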
@@ -7,6 +7,7 @@ from typing import Union, List
import numpy as np
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
@@ -18,7 +19,8 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
    lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
@@ -128,6 +130,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        is_training_adapter = self.adapter_config is not None and self.adapter_config.train

        self.do_lorm = self.get_conf('do_lorm', False)
        self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio')
        self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25)
        # 'ratio', 0.25)

        # get the device state preset based on what we are training
        self.train_device_state_preset = get_train_sd_device_state_preset(
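For reference, the three new options above would be driven by process config entries along these lines. This is a hypothetical fragment; only the key names and defaults come from the get_conf calls.

    # Hypothetical process config fragment; key names/defaults are from the
    # get_conf() calls above, the values are illustrative.
    lorm_options = {
        'do_lorm': True,                  # convert the UNet to LoRM modules
        'lorm_extract_mode': 'ratio',     # how each layer's rank is chosen
        'lorm_extract_mode_param': 0.25,  # e.g. keep 25% of the full rank
    }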
@@ -300,9 +305,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        file_path = os.path.join(self.save_root, filename)
        prev_multiplier = self.network.multiplier
        self.network.multiplier = 1.0
        if self.network_config.normalize:
            # apply the normalization
            self.network.apply_stored_normalizer()

        # if we are doing embedding training as well, add that
        embedding_dict = self.embedding.state_dict() if self.embedding else None
@@ -427,6 +429,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
            print("load_weights not implemented for non-network models")
            return None

    def load_lorm(self):
        latest_save_path = self.get_latest_save_path()
        if latest_save_path is not None:
            # hacky way to reload weights for now
            # todo: do this properly
            state_dict = load_file(latest_save_path, device=self.device)
            self.sd.unet.load_state_dict(state_dict)

            meta = load_metadata_from_safetensors(latest_save_path)
            # if 'training_info' is in the OrderedDict keys
            if 'training_info' in meta and 'step' in meta['training_info']:
                self.step_num = meta['training_info']['step']
                self.start_step = self.step_num
                print(f"Found step {self.step_num} in metadata, starting from there")

    # def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
    #     self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
    #     sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
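Since safetensors metadata is a flat string-to-string mapping, the 'training_info' dict read back in load_lorm has to be serialized when the checkpoint is written. A minimal sketch of that round trip, assuming a JSON encoding; load_metadata_from_safetensors is the toolkit's helper and its internals are not shown in this diff.

    import json
    import torch
    from safetensors.torch import save_file
    from safetensors import safe_open

    # Saving: embed the step counter as JSON inside the string-only metadata.
    tensors = {'w': torch.zeros(2, 2)}  # stand-in for the real state dict
    save_file(tensors, 'checkpoint.safetensors',
              metadata={'training_info': json.dumps({'step': 1200})})

    # Loading: decode it again, mirroring what the helper presumably does.
    with safe_open('checkpoint.safetensors', framework='pt') as f:
        raw = f.metadata() or {}
    step = json.loads(raw.get('training_info', '{}')).get('step')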
@@ -610,7 +627,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            batch.mask_tensor = double_up_tensor(batch.mask_tensor)
            batch.control_tensor = double_up_tensor(batch.control_tensor)

        # remove grads for these
        noisy_latents.requires_grad = False
        noisy_latents = noisy_latents.detach()
@@ -712,15 +728,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # run base sd process run
        self.sd.load_model()

        if self.do_lorm:
            train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
            for module in train_modules:
                p = module.parameters()
                for param in p:
                    param.requires_grad_(True)
                    params.append(param)

        dtype = get_torch_dtype(self.train_config.dtype)

        # model is loaded from BaseSDProcess
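The hard-coded ('ratio', 0.27) call above is the knob that the constructor change makes configurable. In ratio mode the per-layer rank is presumably a fraction of the full rank, and the factorization only pays off when rank * (in + out) < in * out. A quick sketch of that arithmetic; the rank formula is an assumption about convert_diffusers_unet_to_lorm, not taken from this diff.

    def lorm_param_counts(in_features: int, out_features: int, ratio: float = 0.27):
        # Full Linear vs. a rank-limited down/up pair.
        full = in_features * out_features
        rank = max(1, int(min(in_features, out_features) * ratio))
        factored = rank * (in_features + out_features)
        return full, factored, rank

    # A 1280x1280 projection: 1,638,400 params full vs.
    # 345 * (1280 + 1280) = 883,200 factored, roughly a 46% reduction.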
@@ -783,14 +790,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if not self.is_fine_tuning:
            if self.network_config is not None:
                # TODO should we completely switch to LycorisSpecialNetwork?

                network_kwargs = {}
                is_lycoris = False
                is_lorm = self.network_config.type.lower() == 'lorm'
                # default to LoCON if there are any conv layers or if it is named
                NetworkClass = LoRASpecialNetwork
                if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
                    NetworkClass = LycorisSpecialNetwork
                    is_lycoris = True

                if is_lorm:
                    network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains
                    network_kwargs['parameter_threshold'] = lorm_parameter_threshold
                    network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE

                # if is_lycoris:
                #     preset = PRESET['full']
                #     NetworkClass.apply_preset(preset)
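The three kwargs set for LoRM above suggest a name-and-size filter over candidate modules: skip anything whose qualified name hits the ignore list, and only factor layers large enough to be worth it. A hypothetical predicate in that spirit; should_convert and its signature are illustrative, not the network's API.

    import torch

    def should_convert(name: str, module: torch.nn.Module,
                       ignore_if_contains, parameter_threshold: int) -> bool:
        # Skip modules on the ignore list, e.g. small or sensitive projections.
        if any(fragment in name for fragment in ignore_if_contains):
            return False
        # Only convert layers big enough to benefit from factorization.
        return sum(p.numel() for p in module.parameters()) >= parameter_threshold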
@@ -810,6 +823,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    dropout=self.network_config.dropout,
                    use_text_encoder_1=self.model_config.use_text_encoder_1,
                    use_text_encoder_2=self.model_config.use_text_encoder_2,
                    use_bias=is_lorm,
                    is_lorm=is_lorm,
                    network_config=self.network_config,
                    **network_kwargs
                )

                self.network.force_to(self.device_torch, dtype=dtype)
@@ -824,6 +841,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
                    self.train_config.train_unet
                )

                if is_lorm:
                    self.network.is_lorm = True
                    # make sure it is on the right device
                    self.sd.unet.to(self.sd.device, dtype=dtype)
                    original_unet_param_count = count_parameters(self.sd.unet)
                    self.network.setup_lorm()
                    new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction()

                    print_lorm_extract_details(
                        start_num_params=original_unet_param_count,
                        end_num_params=new_unet_param_count,
                        num_replaced=len(self.network.get_all_modules()),
                    )

                self.network.prepare_grad_etc(text_encoder, unet)
                flush()
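count_parameters and the reduction bookkeeping above presumably reduce to the usual numel sum; a one-liner sketch for completeness.

    import torch

    def count_parameters(model: torch.nn.Module) -> int:
        # Total parameter count, trainable and frozen alike.
        return sum(p.numel() for p in model.parameters())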
@@ -846,9 +877,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
                if self.train_config.gradient_checkpointing:
                    self.network.enable_gradient_checkpointing()

                # set the network to normalize if we are
                self.network.is_normalizing = self.network_config.normalize

                lora_name = self.name
                # need to adapt name so they are not mixed up
                if self.named_lora:
@@ -915,7 +943,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            # set the device state preset before getting params
            self.sd.set_device_state(self.train_device_state_preset)

            # params = self.get_params()
            if len(params) == 0:
                # will only return savable weights and ones with grad
@@ -1050,9 +1077,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            else:
                batch = None

            # turn on normalization if we are using it and it is not on
            if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
                self.network.is_normalizing = True
                # flush()
            ### HOOK ###
            self.timer.start('train_loop')
@@ -1078,11 +1102,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
            self.progress_bar.set_postfix_str(prog_bar_string)

            # apply network normalizer if we are using it, not on regularization steps
            if self.network is not None and self.network.is_normalizing and not is_reg_step:
                with self.timer('apply_normalizer'):
                    self.network.apply_stored_normalizer()

            # if the batch is a DataLoaderBatchDTO, then we need to clean it up
            if isinstance(batch, DataLoaderBatchDTO):
                with self.timer('batch_cleanup'):