reworked samplers. Trying to find what is wrong with diffusers sampling in SDXL

This commit is contained in:
Jaret Burkett
2023-09-03 07:56:09 -06:00
parent 4ca819a05e
commit 2a40937b4f
8 changed files with 517 additions and 63 deletions

jobs/process/BaseSDTrainProcess.py

@@ -1,3 +1,4 @@
+import copy
 import glob
 from collections import OrderedDict
 import os
@@ -11,6 +12,7 @@ from toolkit.embedding import Embedding
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
+from toolkit.sampler import get_sampler
 from toolkit.scheduler import get_lr_scheduler
 from toolkit.stable_diffusion_model import StableDiffusion
@@ -89,6 +91,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if embedding_raw is not None:
             self.embed_config = EmbeddingConfig(**embedding_raw)

+        model_config_to_load = copy.deepcopy(self.model_config)
+
         if self.embed_config is None and self.network_config is None:
             # get the latest checkpoint
             # check to see if we have a latest save
@@ -96,7 +100,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             if latest_save_path is not None:
                 print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
-                self.model_config.name_or_path = latest_save_path
+                model_config_to_load.name_or_path = latest_save_path
                 meta = load_metadata_from_safetensors(latest_save_path)
                 # if 'training_info' in OrderedDict keys
                 if 'training_info' in meta and 'step' in meta['training_info']:
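The switch to model_config_to_load above matters because resuming previously wrote the checkpoint path back into self.model_config, so anything that later read the config saw the resume file instead of the configured base model. A minimal stand-alone sketch of the pattern (the ModelConfig stand-in here is hypothetical, not the toolkit's class):

import copy
from dataclasses import dataclass

@dataclass
class ModelConfig:  # stand-in for the toolkit's model config object
    name_or_path: str

base = ModelConfig(name_or_path="stabilityai/stable-diffusion-xl-base-1.0")

# Point only the copy at the resume checkpoint; the original config is untouched.
to_load = copy.deepcopy(base)
to_load.name_or_path = "/path/to/latest_save.safetensors"

assert base.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"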
@@ -104,11 +108,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     self.start_step = self.step_num
                     print(f"Found step {self.step_num} in metadata, starting from there")

+        # get the noise scheduler
+        sampler = get_sampler(self.train_config.noise_scheduler)
+
         self.sd = StableDiffusion(
             device=self.device,
-            model_config=self.model_config,
+            model_config=model_config_to_load,
             dtype=self.train_config.dtype,
             custom_pipeline=self.custom_pipeline,
+            noise_scheduler=sampler,
         )

         # to hold network if there is one
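get_sampler itself is not shown in this diff; a plausible reading, given the call site, is a lookup from the configured noise_scheduler name to a diffusers scheduler instance. A hedged sketch under that assumption (the name table and defaults are guesses, not the toolkit's actual mapping):

from diffusers import DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler

SAMPLERS = {
    "ddpm": DDPMScheduler,
    "ddim": DDIMScheduler,
    "euler": EulerDiscreteScheduler,
}

def get_sampler(name: str):
    # diffusers schedulers construct with usable defaults
    try:
        return SAMPLERS[name]()
    except KeyError:
        raise ValueError(f"unknown noise scheduler: {name}")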
@@ -164,7 +172,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
             ))

         # send to be generated
-        self.sd.generate_images(gen_img_config_list)
+        self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)

     def update_training_metadata(self):
         o_dict = OrderedDict({
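Passing sampler through to generate_images lets sample generation use a different scheduler than training. With diffusers, swapping a pipeline's scheduler is typically done by rebuilding it from the existing scheduler config, e.g. (a sketch, not this repo's code):

from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0"
)
# Reuse the loaded model's scheduler config so timestep settings carry over.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)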
@@ -216,10 +224,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 for file in files[:-self.save_config.max_step_saves_to_keep]:
                     self.print(f"Removing old save: {file}")
                     os.remove(file)
+                    # see if a yaml file with same name exists
+                    yaml_file = os.path.splitext(file)[0] + ".yaml"
+                    if os.path.exists(yaml_file):
+                        os.remove(yaml_file)
             return latest_file
         else:
             return None

+    def post_save_hook(self, save_path):
+        # override in subclass
+        pass
+
     def save(self, step=None):
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
@@ -263,6 +279,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
+        self.post_save_hook(file_path)

     # Called before the model is loaded
     def hook_before_model_load(self):
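post_save_hook gives subclasses a callback after each checkpoint is written. A minimal sketch of an override (the import path is assumed from this repo's layout):

from jobs.process import BaseSDTrainProcess

class UploadingTrainProcess(BaseSDTrainProcess):
    def post_save_hook(self, save_path):
        # runs after every save, including intermediate step saves
        print(f"checkpoint written: {save_path}")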
@@ -279,6 +296,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
     def before_dataset_load(self):
         pass

+    def get_params(self):
+        # you can extend this in subclass to get params
+        # otherwise params will be gathered through normal means
+        return None
+
     def hook_train_loop(self, batch):
         # return loss
         return 0.0
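get_params returning None keeps the default behavior; returning anything truthy short-circuits the parameter gathering in the three training branches below. A sketch of an override that trains only the text encoder (hypothetical subclass; self.sd.text_encoder is assumed to be a torch module, as elsewhere in this diff):

from jobs.process import BaseSDTrainProcess

class TextEncoderOnlyProcess(BaseSDTrainProcess):
    def get_params(self):
        # anything truthy returned here bypasses the default param setup
        return list(self.sd.text_encoder.parameters())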
@@ -445,11 +467,14 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.network.prepare_grad_etc(text_encoder, unet)

-            params = self.network.prepare_optimizer_params(
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if not params:
+                params = self.network.prepare_optimizer_params(
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )

             if self.train_config.gradient_checkpointing:
                 self.network.enable_gradient_checkpointing()
@@ -477,8 +502,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.step_num = self.embedding.step
                 self.start_step = self.step_num

-            # set trainable params
-            params = self.embedding.get_trainable_params()
+            params = self.get_params()
+            if not params:
+                # set trainable params
+                params = self.embedding.get_trainable_params()

         else:
             # set them to train or not
@@ -506,14 +533,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.sd.text_encoder.requires_grad_(False)
                 self.sd.text_encoder.eval()

-            # will only return savable weights and ones with grad
-            params = self.sd.prepare_optimizer_params(
-                unet=self.train_config.train_unet,
-                text_encoder=self.train_config.train_text_encoder,
-                text_encoder_lr=self.train_config.lr,
-                unet_lr=self.train_config.lr,
-                default_lr=self.train_config.lr
-            )
+            params = self.get_params()
+
+            if params is None:
+                # will only return savable weights and ones with grad
+                params = self.sd.prepare_optimizer_params(
+                    unet=self.train_config.train_unet,
+                    text_encoder=self.train_config.train_text_encoder,
+                    text_encoder_lr=self.train_config.lr,
+                    unet_lr=self.train_config.lr,
+                    default_lr=self.train_config.lr
+                )

         ### HOOK ###
         params = self.hook_add_extra_train_params(params)
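hook_add_extra_train_params runs after whichever branch produced params, so extras are appended regardless of how the base set was gathered. A sketch of an override, assuming params is a list of optimizer param groups as the prepare_optimizer_params calls suggest:

from jobs.process import BaseSDTrainProcess

class ExtraParamsProcess(BaseSDTrainProcess):
    def hook_add_extra_train_params(self, params):
        # hypothetical extra tensor trained alongside the base params
        params.append({"params": [self.my_extra_tensor], "lr": 1e-4})
        return params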