Added lorm. WIP

Wires an initial LoRM path into BaseSDTrainProcess: a new do_lorm config flag,
conversion of the loaded UNet via convert_diffusers_unet_to_lorm ('ratio', 0.27),
and collection of the converted modules' parameters as trainable. The params list
is now created before model load so those parameters survive until optimizer
setup, and the empty-params fallback uses an explicit length check.

commit 0a79ac9604
parent 9636194c09
Author: Jaret Burkett
Date: 2023-10-26 18:23:51 -06:00
3 changed files with 441 additions and 4 deletions

@@ -18,6 +18,7 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
 from toolkit.embedding import Embedding
 from toolkit.ip_adapter import IPAdapter
 from toolkit.lora_special import LoRASpecialNetwork
+from toolkit.lorm import convert_diffusers_unet_to_lorm
 from toolkit.lycoris_special import LycorisSpecialNetwork
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
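
The diff doesn't show toolkit/lorm.py itself, but the import here and the call site below (convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)) suggest it replaces the UNet's linear layers with low-rank factorizations sized by a parameter ratio. A minimal sketch of that idea, assuming the standard truncated-SVD construction; every name in it (LowRankLinear, convert_linears_to_low_rank, the rank formula) is hypothetical, not the toolkit's actual API:

import torch
import torch.nn as nn

class LowRankLinear(nn.Module):
    """Approximate a Linear layer as two smaller ones (W ~= up @ down)."""
    def __init__(self, linear: nn.Linear, ratio: float):
        super().__init__()
        in_f, out_f = linear.in_features, linear.out_features
        # pick a rank so the factor pair keeps roughly `ratio` of the original params
        rank = max(1, int(ratio * (in_f * out_f) / (in_f + out_f)))
        # a truncated SVD gives the best rank-r approximation of the weight
        u, s, vh = torch.linalg.svd(linear.weight.detach().float(), full_matrices=False)
        self.down = nn.Linear(in_f, rank, bias=False)
        self.up = nn.Linear(rank, out_f, bias=linear.bias is not None)
        self.down.weight.data = (s[:rank].sqrt().unsqueeze(1) * vh[:rank]).to(linear.weight.dtype)
        self.up.weight.data = (u[:, :rank] * s[:rank].sqrt()).to(linear.weight.dtype)
        if linear.bias is not None:
            self.up.bias.data = linear.bias.detach().clone()

    def forward(self, x):
        return self.up(self.down(x))

def convert_linears_to_low_rank(model: nn.Module, ratio: float):
    """Replace every nn.Linear under `model` in place; return the new modules."""
    converted = []
    for module in list(model.modules()):
        for name, child in list(module.named_children()):
            if isinstance(child, nn.Linear):
                replacement = LowRankLinear(child, ratio)
                setattr(module, name, replacement)
                converted.append(replacement)
    return converted

Under that reading, the 0.27 passed below would mean each converted layer keeps roughly 27% of its original weight parameters.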
@@ -126,6 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         is_training_adapter = self.adapter_config is not None and self.adapter_config.train
+        self.do_lorm = self.get_conf('do_lorm', False)
         # get the device state preset based on what we are training
         self.train_device_state_preset = get_train_sd_device_state_preset(
             device=self.device_torch,
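
get_conf reads a value from the job config with a default, so the LoRM branch stays off unless the config opts in. The lookup pattern, sketched standalone (the toolkit's real get_conf isn't shown in this diff):

config = {'do_lorm': True}              # hypothetical job-config entry
do_lorm = config.get('do_lorm', False)  # mirrors self.get_conf('do_lorm', False)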
@@ -675,6 +678,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # torch.autograd.set_detect_anomaly(True)
         # run base process run
         BaseTrainProcess.run(self)
+        params = []
         ### HOOK ###
         self.hook_before_model_load()
@@ -708,6 +712,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # run base sd process run
         self.sd.load_model()
+        if self.do_lorm:
+            train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
+            for module in train_modules:
+                p = module.parameters()
+                for param in p:
+                    param.requires_grad_(True)
+                    params.append(param)
         dtype = get_torch_dtype(self.train_config.dtype)
         # model is loaded from BaseSDProcess
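
The new branch enables gradients only on the converted modules and appends their parameters to the params list created before model load. The same pattern as a self-contained helper, reusing the hypothetical convert_linears_to_low_rank sketched above (freezing the rest of the UNet first is an assumption; the diff only shows grads being enabled on converted modules):

import torch.nn as nn

def collect_lorm_params(unet: nn.Module, ratio: float = 0.27):
    unet.requires_grad_(False)  # assumption: freeze the base model first
    train_modules = convert_linears_to_low_rank(unet, ratio)  # hypothetical, see above
    params = []
    for module in train_modules:
        for param in module.parameters():
            param.requires_grad_(True)  # only the converted layers stay trainable
            params.append(param)
    return params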
@@ -767,7 +780,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.datasets_reg is not None:
             self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
                                                                 self.sd)
-        params = []
         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
@@ -903,8 +915,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # set the device state preset before getting params
             self.sd.set_device_state(self.train_device_state_preset)
-            params = self.get_params()
-            if not params:
+            # params = self.get_params()
+            if len(params) == 0:
                 # will only return savable weights and ones with grad
                 params = self.sd.prepare_optimizer_params(
                     unet=self.train_config.train_unet,