Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 19:21:39 +00:00
Many bug fixes. IP adapter bug fixes. Added noise to the unconditional, which works better. Added an ilora adapter for one-shotting LoRAs.
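As a rough illustration of the "noise to the unconditional" idea (a minimal sketch; the function and variable names below are assumptions for illustration, not the toolkit's actual implementation):

    import torch

    def noisy_unconditional(uncond_embeds: torch.Tensor, noise_scale: float = 0.01) -> torch.Tensor:
        # Perturb the unconditional (empty-prompt) embeddings with a small amount of
        # Gaussian noise so training does not overfit to one fixed unconditional input.
        return uncond_embeds + torch.randn_like(uncond_embeds) * noise_scale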
@@ -2,6 +2,7 @@ import copy
 import glob
 import inspect
 import json
+import random
 import shutil
 from collections import OrderedDict
 import os
@@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             adapter_name += '_t2i'
         elif self.adapter_config.type == 'clip':
             adapter_name += '_clip'
-        elif self.adapter_config.type == 'ip':
+        elif self.adapter_config.type.startswith('ip'):
             adapter_name += '_ip'
         else:
             adapter_name += '_adapter'
@@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 state_dict,
                 output_file=file_path,
                 meta=save_meta,
-                dtype=get_torch_dtype(self.save_config.dtype)
+                dtype=get_torch_dtype(self.save_config.dtype),
+                direct_save=self.adapter_config.train_only_image_encoder
             )
         else:
             if self.save_config.save_format == "diffusers":
@@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 loaded_state_dict = load_ip_adapter_model(
                     latest_save_path,
                     self.device,
-                    dtype=dtype
+                    dtype=dtype,
+                    direct_load=self.adapter_config.train_only_image_encoder
                 )
                 self.adapter.load_state_dict(loaded_state_dict)
             else:
@@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
             self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

-        # load the adapters before the dataset as they may use the clip encoders
-        if self.adapter_config is not None:
-            self.setup_adapter()
         flush()
         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
-                network_kwargs = {}
+                network_kwargs = self.network_config.network_kwargs
                 is_lycoris = False
                 is_lorm = self.network_config.type.lower() == 'lorm'
                 # default to LoCON if there are any conv layers or if it is named
@@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
             flush()

         if self.adapter_config is not None:
-            # self.setup_adapter()
-            # set trainable params
-            params.append({
-                'params': self.adapter.parameters(),
-                'lr': self.train_config.adapter_lr
-            })
+            self.setup_adapter()
+            if self.adapter_config.train:
+                # set trainable params
+                params.append({
+                    'params': self.adapter.parameters(),
+                    'lr': self.train_config.adapter_lr
+                })

+            if self.train_config.gradient_checkpointing:
+                self.adapter.enable_gradient_checkpointing()
+            flush()

         params = self.load_additional_training_modules(params)
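For context on the param-group pattern above (a generic PyTorch sketch, not this repo's optimizer setup): each dict appended to params becomes a separate optimizer parameter group, which is how the adapter can get its own adapter_lr.

    import torch

    unet = torch.nn.Linear(8, 8)      # stand-in for the main trainable module
    adapter = torch.nn.Linear(8, 8)   # stand-in for the adapter
    params = [
        {'params': unet.parameters(), 'lr': 1e-4},
        {'params': adapter.parameters(), 'lr': 1e-5},  # separate adapter learning rate
    ]
    optimizer = torch.optim.AdamW(params)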
@@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 refiner_lr=self.train_config.refiner_lr,
             )
             # we may be using it for prompt injections
-            if self.adapter_config is not None:
+            if self.adapter_config is not None and self.adapter is None:
                 self.setup_adapter()
                 flush()
         ### HOOK ###
@@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # sample first
         if self.train_config.skip_first_sample:
             self.print("Skipping first sample due to config setting")
-        elif self.step_num <= 1:
+        elif self.step_num <= 1 or self.train_config.force_first_sample:
             self.print("Generating baseline samples before training")
             self.sample(self.step_num)

@@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            if self.train_config.do_random_cfg:
+                self.train_config.do_cfg = True
+                self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
             self.step_num = step
             # default to true so various things can turn it off
             self.is_grad_accumulation_step = True
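The last hunk draws a fresh CFG scale each step via value_map. Assuming value_map is a plain linear remap from one range to another (an assumption about the helper, sketched below), the call is equivalent to sampling a scale uniformly between 1.0 and max_cfg_scale:

    import random

    def value_map(x, in_min, in_max, out_min, out_max):
        # Linearly remap x from [in_min, in_max] to [out_min, out_max].
        return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

    max_cfg_scale = 7.0  # example value; the real value comes from train_config
    cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)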