Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-30 19:21:39 +00:00)
Added lorm. WIP
This commit is contained in:
@@ -18,6 +18,7 @@ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatc
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.lorm import convert_diffusers_unet_to_lorm
|
||||
from toolkit.lycoris_special import LycorisSpecialNetwork
|
||||
from toolkit.network_mixins import Network
|
||||
from toolkit.optimizer import get_optimizer
|
||||
@@ -126,6 +127,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
|
||||
|
||||
self.do_lorm = self.get_conf('do_lorm', False)
|
||||
|
||||
# get the device state preset based on what we are training
|
||||
self.train_device_state_preset = get_train_sd_device_state_preset(
|
||||
device=self.device_torch,
|
||||
@@ -675,6 +678,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
# run base process run
|
||||
BaseTrainProcess.run(self)
|
||||
params = []
|
||||
|
||||
### HOOK ###
|
||||
self.hook_before_model_load()
|
||||
@@ -708,6 +712,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# run base sd process run
|
||||
self.sd.load_model()
|
||||
|
||||
if self.do_lorm:
|
||||
train_modules = convert_diffusers_unet_to_lorm(self.sd.unet, 'ratio', 0.27)
|
||||
for module in train_modules:
|
||||
p = module.parameters()
|
||||
for param in p:
|
||||
param.requires_grad_(True)
|
||||
params.append(param)
|
||||
|
||||
|
||||
dtype = get_torch_dtype(self.train_config.dtype)
|
||||
|
||||
# model is loaded from BaseSDProcess
|
||||
@@ -767,7 +780,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.datasets_reg is not None:
|
||||
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
|
||||
self.sd)
|
||||
params = []
|
||||
if not self.is_fine_tuning:
|
||||
if self.network_config is not None:
|
||||
# TODO should we completely switch to LycorisSpecialNetwork?
|
||||
@@ -903,8 +915,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# set the device state preset before getting params
|
||||
self.sd.set_device_state(self.train_device_state_preset)
|
||||
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
|
||||
# params = self.get_params()
|
||||
if len(params) == 0:
|
||||
# will only return savable weights and ones with grad
|
||||
params = self.sd.prepare_optimizer_params(
|
||||
unet=self.train_config.train_unet,
|
||||
|
||||
Reference in New Issue
Block a user