Improved lorm extraction and training

This commit is contained in:
Jaret Burkett
2023-10-28 08:21:59 -06:00
parent 0a79ac9604
commit 6f3e0d5af2
10 changed files with 559 additions and 196 deletions

View File

@@ -7,6 +7,7 @@ from typing import Union, List
import numpy as np
from diffusers import T2IAdapter
from safetensors.torch import save_file, load_file
# from lycoris.config import PRESET
from torch.utils.data import DataLoader
import torch
@@ -18,7 +19,8 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import convert_diffusers_unet_to_lorm
from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
@@ -128,6 +130,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
self.do_lorm = self.get_conf('do_lorm', False)
self.lorm_extract_mode = self.get_conf('lorm_extract_mode', 'ratio')
self.lorm_extract_mode_param = self.get_conf('lorm_extract_mode_param', 0.25)
# 'ratio', 0.25)
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
@@ -300,9 +305,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
file_path = os.path.join(self.save_root, filename)
prev_multiplier = self.network.multiplier
self.network.multiplier = 1.0
if self.network_config.normalize:
# apply the normalization
self.network.apply_stored_normalizer()
# if we are doing embedding training as well, add that
embedding_dict = self.embedding.state_dict() if self.embedding else None
@@ -427,6 +429,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
print("load_weights not implemented for non-network models")
return None
# Resume helper: reloads the latest save into the UNet and restores the step counter.
# NOTE(review): indentation is lost in this view; presumably every statement after the
# `if latest_save_path is not None:` line is nested under it -- confirm against the real file.
def load_lorm(self):
# may be None, in which case nothing below is meant to run (see the check on the next line)
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
# hacky way to reload weights for now
# TODO: replace this direct state-dict reload with a proper loading path
state_dict = load_file(latest_save_path, device=self.device)
# the loaded tensors are applied directly to the UNet
self.sd.unet.load_state_dict(state_dict)
# metadata is read from the same file the weights came from
meta = load_metadata_from_safetensors(latest_save_path)
# only restore the step when the metadata has a 'training_info' entry containing 'step'
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
# start_step is set to the restored step as well
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
# def get_sigmas(self, timesteps, n_dim=4, dtype=torch.float32):
# self.sd.noise_scheduler.set_timesteps(1000, device=self.device_torch)
# sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
@@ -610,7 +627,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
batch.control_tensor = double_up_tensor(batch.control_tensor)
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()
@@ -712,15 +728,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# run base sd process run
self.sd.load_model()
if self.do_lorm:
train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
for module in train_modules:
p = module.parameters()
for param in p:
param.requires_grad_(True)
params.append(param)
dtype = get_torch_dtype(self.train_config.dtype)
# model is loaded from BaseSDProcess
@@ -783,14 +790,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
network_kwargs = {}
is_lycoris = False
is_lorm = self.network_config.type.lower() == 'lorm'
# default to LoCON if there are any conv layers or if it is named
NetworkClass = LoRASpecialNetwork
if self.network_config.type.lower() == 'locon' or self.network_config.type.lower() == 'lycoris':
NetworkClass = LycorisSpecialNetwork
is_lycoris = True
if is_lorm:
network_kwargs['ignore_if_contains'] = lorm_ignore_if_contains
network_kwargs['parameter_threshold'] = lorm_parameter_threshold
network_kwargs['target_lin_modules'] = LORM_TARGET_REPLACE_MODULE
# if is_lycoris:
# preset = PRESET['full']
# NetworkClass.apply_preset(preset)
@@ -810,6 +823,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
dropout=self.network_config.dropout,
use_text_encoder_1=self.model_config.use_text_encoder_1,
use_text_encoder_2=self.model_config.use_text_encoder_2,
use_bias=is_lorm,
is_lorm=is_lorm,
network_config=self.network_config,
**network_kwargs
)
self.network.force_to(self.device_torch, dtype=dtype)
@@ -824,6 +841,20 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.train_config.train_unet
)
if is_lorm:
self.network.is_lorm = True
# make sure it is on the right device
self.sd.unet.to(self.sd.device, dtype=dtype)
original_unet_param_count = count_parameters(self.sd.unet)
self.network.setup_lorm()
new_unet_param_count = original_unet_param_count - self.network.calculate_lorem_parameter_reduction()
print_lorm_extract_details(
start_num_params=original_unet_param_count,
end_num_params=new_unet_param_count,
num_replaced=len(self.network.get_all_modules()),
)
self.network.prepare_grad_etc(text_encoder, unet)
flush()
@@ -846,9 +877,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.train_config.gradient_checkpointing:
self.network.enable_gradient_checkpointing()
# set the network to normalize if we are
self.network.is_normalizing = self.network_config.normalize
lora_name = self.name
# need to adapt name so they are not mixed up
if self.named_lora:
@@ -915,7 +943,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
# set the device state preset before getting params
self.sd.set_device_state(self.train_device_state_preset)
# params = self.get_params()
if len(params) == 0:
# will only return savable weights and ones with grad
@@ -1050,9 +1077,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
else:
batch = None
# turn on normalization if we are using it and it is not on
if self.network is not None and self.network_config.normalize and not self.network.is_normalizing:
self.network.is_normalizing = True
# flush()
### HOOK ###
self.timer.start('train_loop')
@@ -1078,11 +1102,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar.set_postfix_str(prog_bar_string)
# apply network normalizer if we are using it, not on regularization steps
if self.network is not None and self.network.is_normalizing and not is_reg_step:
with self.timer('apply_normalizer'):
self.network.apply_stored_normalizer()
# if the batch is a DataLoaderBatchDTO, then we need to clean it up
if isinstance(batch, DataLoaderBatchDTO):
with self.timer('batch_cleanup'):