Added rescaling, locon, sdxl, all kinds of stuff. sdxl is still weird

This commit is contained in:
Jaret Burkett
2023-07-26 16:19:50 -06:00
parent 40e60fa021
commit d3ad195b51
11 changed files with 548 additions and 45 deletions

View File

@@ -1,7 +1,10 @@
import glob
import time
from collections import OrderedDict
import os
from safetensors import safe_open
from toolkit.kohya_model_util import load_vae
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.optimizer import get_optimizer
@@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
@@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.model_config = ModelConfig(**self.get_conf('model', {}))
self.save_config = SaveConfig(**self.get_conf('save', {}))
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
self.optimizer = None
self.lr_scheduler = None
@@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# added later
self.network = None
def sample(self, step=None):
def sample(self, step=None, is_first=False):
sample_folder = os.path.join(self.save_root, 'samples')
if not os.path.exists(sample_folder):
os.makedirs(sample_folder, exist_ok=True)
@@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
# disable progress bar
pipeline.set_progress_bar_config(disable=True)
start_seed = self.sample_config.seed
sample_config = self.first_sample_config if is_first else self.sample_config
start_seed = sample_config.seed
start_multiplier = self.network.multiplier
current_seed = start_seed
@@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
'multiplier': self.network.multiplier,
})
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}",
leave=False):
raw_prompt = self.sample_config.prompts[i]
raw_prompt = sample_config.prompts[i]
neg = self.sample_config.neg
multiplier = self.sample_config.network_multiplier
neg = sample_config.neg
multiplier = sample_config.network_multiplier
p_split = raw_prompt.split('--')
prompt = p_split[0].strip()
height = sample_config.height
width = sample_config.width
if len(p_split) > 1:
for split in p_split:
@@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
elif flag == 'm':
# multiplier
multiplier = float(content)
elif flag == 'w':
# multiplier
width = int(content)
elif flag == 'h':
# multiplier
height = int(content)
height = self.sample_config.height
width = self.sample_config.width
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
if self.sample_config.walk_seed:
if sample_config.walk_seed:
current_seed += i
if self.network is not None:
@@ -159,14 +171,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
torch.manual_seed(current_seed)
torch.cuda.manual_seed(current_seed)
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=self.sample_config.sample_steps,
guidance_scale=self.sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
if self.sd.is_xl:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
else:
img = pipeline(
prompt,
height=height,
width=width,
num_inference_steps=sample_config.sample_steps,
guidance_scale=sample_config.guidance_scale,
negative_prompt=neg,
).images[0]
step_num = ''
if step is not None:
@@ -209,6 +231,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
})
return info
def clean_up_saves(self):
# remove old saves
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
pattern = f"{self.job.name}_*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > self.save_config.max_step_saves_to_keep:
# remove all but the latest max_step_saves_to_keep
files.sort(key=os.path.getctime)
for file in files[:-self.save_config.max_step_saves_to_keep]:
self.print(f"Removing old save: {file}")
os.remove(file)
return latest_file
else:
return None
def save(self, step=None):
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -231,9 +271,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
metadata=save_meta
)
else:
# TODO handle dreambooth, fine tuning, etc
# will probably have to convert dict back to LDM
ValueError("Non network training is not currently supported")
self.sd.save(
file_path,
save_meta,
get_torch_dtype(self.save_config.dtype)
)
self.print(f"Saved to {file_path}")
@@ -258,6 +300,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
):
if height is None and pixel_height is None:
raise ValueError("height or pixel_height must be specified")
raise ValueError("height or pixel_height must be specified")
if width is None and pixel_width is None:
raise ValueError("width or pixel_width must be specified")
if height is None:
@@ -316,18 +359,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
if add_time_ids is None:
add_time_ids = self.get_time_ids_from_latents(latents)
# todo LECOs code looks like it is omitting noise_pred
noise_pred = train_util.predict_noise_xl(
self.sd.unet,
self.sd.noise_scheduler,
# noise_pred = train_util.predict_noise_xl(
# self.sd.unet,
# self.sd.noise_scheduler,
# timestep,
# latents,
# text_embeddings.text_embeds,
# text_embeddings.pooled_embeds,
# add_time_ids,
# guidance_scale=guidance_scale,
# guidance_rescale=guidance_rescale
# )
latent_model_input = torch.cat([latents] * 2)
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
# predict the noise residual
noise_pred = self.sd.unet(
latent_model_input,
timestep,
latents,
text_embeddings.text_embeds,
text_embeddings.pooled_embeds,
add_time_ids,
guidance_scale=guidance_scale,
guidance_rescale=guidance_rescale
encoder_hidden_states=text_embeddings.text_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided_target = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
# noise_pred = rescale_noise_cfg(
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
# )
noise_pred = guided_target
else:
noise_pred = train_util.predict_noise(
self.sd.unet,
@@ -366,6 +438,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
# return latents_steps
return latents
def get_latest_save_path(self):
# get latest saved step
if os.path.exists(self.save_root):
latest_file = None
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
pattern = f"{self.job.name}*.safetensors"
files = glob.glob(os.path.join(self.save_root, pattern))
if len(files) > 0:
latest_file = max(files, key=os.path.getctime)
return latest_file
else:
return None
def load_weights(self, path):
if self.network is not None:
self.network.load_weights(path)
meta = load_metadata_from_safetensors(path)
# if 'training_info' in Orderdict keys
if 'training_info' in meta and 'step' in meta['training_info']:
self.step_num = meta['training_info']['step']
self.start_step = self.step_num
print(f"Found step {self.step_num} in metadata, starting from there")
else:
print("load_weights not implemented for non-network models")
def run(self):
super().run()
@@ -407,20 +505,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet.to(self.device_torch, dtype=dtype)
if self.train_config.xformers:
unet.enable_xformers_memory_efficient_attention()
if self.train_config.gradient_checkpointing:
unet.enable_gradient_checkpointing()
unet.requires_grad_(False)
unet.eval()
if self.network_config is not None:
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
self.network = LoRASpecialNetwork(
text_encoder=text_encoder,
unet=unet,
lora_dim=self.network_config.rank,
lora_dim=self.network_config.linear,
multiplier=1.0,
alpha=self.network_config.alpha,
train_unet=self.train_config.train_unet,
train_text_encoder=self.train_config.train_text_encoder,
conv_lora_dim=conv,
conv_alpha=self.network_config.alpha if conv is not None else None,
)
self.network.force_to(self.device_torch, dtype=dtype)
self.network.apply_to(
@@ -438,6 +542,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
default_lr=self.train_config.lr
)
latest_save_path = self.get_latest_save_path()
if latest_save_path is not None:
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
self.print(f"Loading from {latest_save_path}")
self.load_weights(latest_save_path)
self.network.multiplier = 1.0
else:
params = []
# assume dreambooth/finetune
@@ -475,15 +588,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print("Skipping first sample due to config setting")
else:
self.print("Generating baseline samples before training")
self.sample(0)
self.sample(0, is_first=True)
self.progress_bar = tqdm(
total=self.train_config.steps,
desc=self.job.name,
leave=True
)
self.step_num = 0
for step in range(self.train_config.steps):
# set it to our current step in case it was updated from a load
self.progress_bar.update(self.step_num)
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
# todo handle dataloader here maybe, not sure
### HOOK ###