Many bug fixes. IP adapter bug fixes. Added noise to the unconditional embedding; it works better. Added an iLoRA adapter for one-shotting LoRAs.

Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions
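The unconditional-noise change mentioned in the commit message lives in one of the changed files not shown in this excerpt. A generic sketch of the idea, with the function name and strength value assumed, not taken from the commit:

```python
import torch

def noise_unconditional(uncond_embeds: torch.Tensor, strength: float = 0.01) -> torch.Tensor:
    # Perturb the unconditional (empty-prompt) text embedding with a small
    # amount of Gaussian noise so training does not overfit to one fixed
    # embedding. Names and the default strength are assumptions.
    return uncond_embeds + torch.randn_like(uncond_embeds) * strength
```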

@@ -2,6 +2,7 @@ import copy
 import glob
 import inspect
 import json
+import random
 import shutil
 from collections import OrderedDict
 import os
@@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 adapter_name += '_t2i'
             elif self.adapter_config.type == 'clip':
                 adapter_name += '_clip'
-            elif self.adapter_config.type == 'ip':
+            elif self.adapter_config.type.startswith('ip'):
                 adapter_name += '_ip'
             else:
                 adapter_name += '_adapter'
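The old equality check only matched the plain `ip` type; `startswith('ip')` also routes variants (such as an `ip+` type, an assumed name here) to the same `_ip` filename suffix. A quick illustration:

```python
# Both the base type and variants now hit the ip branch.
for adapter_type in ['ip', 'ip+']:  # 'ip+' is an assumed variant name
    suffix = '_ip' if adapter_type.startswith('ip') else '_adapter'
    assert suffix == '_ip'
```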
@@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     state_dict,
                     output_file=file_path,
                     meta=save_meta,
-                    dtype=get_torch_dtype(self.save_config.dtype)
+                    dtype=get_torch_dtype(self.save_config.dtype),
+                    direct_save=self.adapter_config.train_only_image_encoder
                 )
             else:
                 if self.save_config.save_format == "diffusers":
@@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 loaded_state_dict = load_ip_adapter_model(
                     latest_save_path,
                     self.device,
-                    dtype=dtype,
+                    dtype=dtype,
+                    direct_load=self.adapter_config.train_only_image_encoder
                 )
                 self.adapter.load_state_dict(loaded_state_dict)
             else:
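The bodies of `direct_save`/`direct_load` are not in this excerpt. A minimal sketch of what the flag plausibly toggles, assuming an IP-adapter checkpoint normally nests its weights under the conventional `image_proj`/`ip_adapter` keys while image-encoder-only training wants the state dict written as-is:

```python
import torch

def save_ip_adapter_model(state_dict, output_file, direct_save=False):
    # Sketch with assumed behavior: direct_save writes the state dict
    # unchanged; otherwise split it into the usual nested sub-dicts.
    if direct_save:
        torch.save(state_dict, output_file)
        return
    nested = {'image_proj': {}, 'ip_adapter': {}}
    for key, value in state_dict.items():
        bucket = 'image_proj' if key.startswith('image_proj') else 'ip_adapter'
        nested[bucket][key] = value
    torch.save(nested, output_file)
```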
@@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
             self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
         # load the adapters before the dataset as they may use the clip encoders
-        if self.adapter_config is not None:
-            self.setup_adapter()
-            flush()
         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
-                network_kwargs = {}
+                network_kwargs = self.network_config.network_kwargs
                 is_lycoris = False
                 is_lorm = self.network_config.type.lower() == 'lorm'
                 # default to LoCON if there are any conv layers or if it is named
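With this change, kwargs declared under the network config reach the network constructor instead of being overwritten with an empty dict. A hypothetical config illustration; the key names inside `network_kwargs` are assumed, not taken from the diff:

```python
# Kwargs a user sets on the network config were previously discarded.
network_config = {
    'type': 'lora',
    'linear': 16,
    'network_kwargs': {
        'ignore_if_contains': ['proj_out'],  # assumed key: skip matching modules
    },
}
network_kwargs = network_config['network_kwargs']  # now forwarded, not replaced
```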
@@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()
         if self.adapter_config is not None:
             # self.setup_adapter()
-            # set trainable params
-            params.append({
-                'params': self.adapter.parameters(),
-                'lr': self.train_config.adapter_lr
-            })
+            self.setup_adapter()
+            if self.adapter_config.train:
+                # set trainable params
+                params.append({
+                    'params': self.adapter.parameters(),
+                    'lr': self.train_config.adapter_lr
+                })
+                if self.train_config.gradient_checkpointing:
+                    self.adapter.enable_gradient_checkpointing()
         flush()
         params = self.load_additional_training_modules(params)
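The adapter's parameters now enter the optimizer only when `adapter_config.train` is set, so a frozen adapter (e.g. one used purely for conditioning) costs no optimizer state. A minimal, self-contained sketch of the per-module param-group pattern the trainer uses; the helper name is ours, not the repo's:

```python
import torch

def build_param_groups(modules_with_lr):
    # One optimizer param group per module, each with its own LR,
    # mirroring the {'params': ..., 'lr': ...} entries the trainer appends.
    groups = []
    for module, lr in modules_with_lr:
        params = [p for p in module.parameters() if p.requires_grad]
        if params:
            groups.append({'params': params, 'lr': lr})
    return groups

adapter = torch.nn.Linear(8, 8)  # stand-in for the trained adapter
optimizer = torch.optim.AdamW(build_param_groups([(adapter, 1e-4)]))
```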
@@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             refiner_lr=self.train_config.refiner_lr,
         )
         # we may be using it for prompt injections
-        if self.adapter_config is not None:
+        if self.adapter_config is not None and self.adapter is None:
             self.setup_adapter()
         flush()
         ### HOOK ###
@@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # sample first
         if self.train_config.skip_first_sample:
             self.print("Skipping first sample due to config setting")
-        elif self.step_num <= 1:
+        elif self.step_num <= 1 or self.train_config.force_first_sample:
             self.print("Generating baseline samples before training")
             self.sample(self.step_num)
@@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            if self.train_config.do_random_cfg:
+                self.train_config.do_cfg = True
+                self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
             self.step_num = step
             # default to true so various things can turn it off
             self.is_grad_accumulation_step = True
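With `do_random_cfg` enabled, each training step draws a fresh CFG scale between 1.0 and `max_cfg_scale`. A runnable sketch of the assumed behavior of the toolkit's `value_map` helper, a standard linear range remap:

```python
import random

def value_map(x, in_min, in_max, out_min, out_max):
    # Assumed implementation: linearly remap x from [in_min, in_max]
    # to [out_min, out_max].
    return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

max_cfg_scale = 7.0  # assumed config value
cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)
print(cfg_scale)  # uniform in [1.0, 7.0)
```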