mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
reworked samplers. Trying to find what is wrong with diffusers sampling is sdxl
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import glob
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
@@ -11,6 +12,7 @@ from toolkit.embedding import Embedding
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
from toolkit.paths import CONFIG_ROOT
|
||||
from toolkit.sampler import get_sampler
|
||||
|
||||
from toolkit.scheduler import get_lr_scheduler
|
||||
from toolkit.stable_diffusion_model import StableDiffusion
|
||||
@@ -89,6 +91,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if embedding_raw is not None:
|
||||
self.embed_config = EmbeddingConfig(**embedding_raw)
|
||||
|
||||
model_config_to_load = copy.deepcopy(self.model_config)
|
||||
|
||||
if self.embed_config is None and self.network_config is None:
|
||||
# get the latest checkpoint
|
||||
# check to see if we have a latest save
|
||||
@@ -96,7 +100,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
if latest_save_path is not None:
|
||||
print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.model_config.name_or_path = latest_save_path
|
||||
model_config_to_load.name_or_path = latest_save_path
|
||||
meta = load_metadata_from_safetensors(latest_save_path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
@@ -104,11 +108,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
# get the noise scheduler
|
||||
sampler = get_sampler(self.train_config.noise_scheduler)
|
||||
|
||||
self.sd = StableDiffusion(
|
||||
device=self.device,
|
||||
model_config=self.model_config,
|
||||
model_config=model_config_to_load,
|
||||
dtype=self.train_config.dtype,
|
||||
custom_pipeline=self.custom_pipeline,
|
||||
noise_scheduler=sampler,
|
||||
)
|
||||
|
||||
# to hold network if there is one
|
||||
@@ -164,7 +172,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
))
|
||||
|
||||
# send to be generated
|
||||
self.sd.generate_images(gen_img_config_list)
|
||||
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
|
||||
|
||||
def update_training_metadata(self):
|
||||
o_dict = OrderedDict({
|
||||
@@ -216,10 +224,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
for file in files[:-self.save_config.max_step_saves_to_keep]:
|
||||
self.print(f"Removing old save: {file}")
|
||||
os.remove(file)
|
||||
# see if a yaml file with same name exists
|
||||
yaml_file = os.path.splitext(file)[0] + ".yaml"
|
||||
if os.path.exists(yaml_file):
|
||||
os.remove(yaml_file)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
|
||||
def post_save_hook(self, save_path):
|
||||
# override in subclass
|
||||
pass
|
||||
|
||||
def save(self, step=None):
|
||||
if not os.path.exists(self.save_root):
|
||||
os.makedirs(self.save_root, exist_ok=True)
|
||||
@@ -263,6 +279,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
self.clean_up_saves()
|
||||
self.post_save_hook(file_path)
|
||||
|
||||
# Called before the model is loaded
|
||||
def hook_before_model_load(self):
|
||||
@@ -279,6 +296,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
def before_dataset_load(self):
|
||||
pass
|
||||
|
||||
def get_params(self):
|
||||
# you can extend this in subclass to get params
|
||||
# otherwise params will be gathered through normal means
|
||||
return None
|
||||
|
||||
def hook_train_loop(self, batch):
|
||||
# return loss
|
||||
return 0.0
|
||||
@@ -445,11 +467,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
|
||||
self.network.prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
params = self.network.prepare_optimizer_params(
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
params = self.get_params()
|
||||
|
||||
if not params:
|
||||
params = self.network.prepare_optimizer_params(
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
|
||||
if self.train_config.gradient_checkpointing:
|
||||
self.network.enable_gradient_checkpointing()
|
||||
@@ -477,8 +502,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.step_num = self.embedding.step
|
||||
self.start_step = self.step_num
|
||||
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
params = self.get_params()
|
||||
if not params:
|
||||
# set trainable params
|
||||
params = self.embedding.get_trainable_params()
|
||||
|
||||
else:
|
||||
# set them to train or not
|
||||
@@ -506,14 +533,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.text_encoder.requires_grad_(False)
|
||||
self.sd.text_encoder.eval()
|
||||
|
||||
# will only return savable weights and ones with grad
|
||||
params = self.sd.prepare_optimizer_params(
|
||||
unet=self.train_config.train_unet,
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
params = self.get_params()
|
||||
|
||||
if params is None:
|
||||
# will only return savable weights and ones with grad
|
||||
params = self.sd.prepare_optimizer_params(
|
||||
unet=self.train_config.train_unet,
|
||||
text_encoder=self.train_config.train_text_encoder,
|
||||
text_encoder_lr=self.train_config.lr,
|
||||
unet_lr=self.train_config.lr,
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
|
||||
### HOOK ###
|
||||
params = self.hook_add_extra_train_params(params)
|
||||
|
||||
Reference in New Issue
Block a user