mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added rescaling, locon, sdxl, all kinds of stuff. sdxl is still weird
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import glob
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import os
|
||||
|
||||
from safetensors import safe_open
|
||||
|
||||
from toolkit.kohya_model_util import load_vae
|
||||
from toolkit.lora_special import LoRASpecialNetwork
|
||||
from toolkit.optimizer import get_optimizer
|
||||
@@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||
|
||||
from jobs.process import BaseTrainProcess
|
||||
from toolkit.metadata import get_meta_for_safetensors
|
||||
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||
import gc
|
||||
|
||||
@@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
|
||||
self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
|
||||
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
@@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# added later
|
||||
self.network = None
|
||||
|
||||
def sample(self, step=None):
|
||||
def sample(self, step=None, is_first=False):
|
||||
sample_folder = os.path.join(self.save_root, 'samples')
|
||||
if not os.path.exists(sample_folder):
|
||||
os.makedirs(sample_folder, exist_ok=True)
|
||||
@@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# disable progress bar
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
start_seed = self.sample_config.seed
|
||||
sample_config = self.first_sample_config if is_first else self.sample_config
|
||||
|
||||
start_seed = sample_config.seed
|
||||
start_multiplier = self.network.multiplier
|
||||
current_seed = start_seed
|
||||
|
||||
@@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
'multiplier': self.network.multiplier,
|
||||
})
|
||||
|
||||
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
|
||||
for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}",
|
||||
leave=False):
|
||||
raw_prompt = self.sample_config.prompts[i]
|
||||
raw_prompt = sample_config.prompts[i]
|
||||
|
||||
neg = self.sample_config.neg
|
||||
multiplier = self.sample_config.network_multiplier
|
||||
neg = sample_config.neg
|
||||
multiplier = sample_config.network_multiplier
|
||||
p_split = raw_prompt.split('--')
|
||||
prompt = p_split[0].strip()
|
||||
height = sample_config.height
|
||||
width = sample_config.width
|
||||
|
||||
if len(p_split) > 1:
|
||||
for split in p_split:
|
||||
@@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
elif flag == 'm':
|
||||
# multiplier
|
||||
multiplier = float(content)
|
||||
elif flag == 'w':
|
||||
# multiplier
|
||||
width = int(content)
|
||||
elif flag == 'h':
|
||||
# multiplier
|
||||
height = int(content)
|
||||
|
||||
height = self.sample_config.height
|
||||
width = self.sample_config.width
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
|
||||
if self.sample_config.walk_seed:
|
||||
if sample_config.walk_seed:
|
||||
current_seed += i
|
||||
|
||||
if self.network is not None:
|
||||
@@ -159,14 +171,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
torch.manual_seed(current_seed)
|
||||
torch.cuda.manual_seed(current_seed)
|
||||
|
||||
img = pipeline(
|
||||
prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=self.sample_config.sample_steps,
|
||||
guidance_scale=self.sample_config.guidance_scale,
|
||||
negative_prompt=neg,
|
||||
).images[0]
|
||||
if self.sd.is_xl:
|
||||
img = pipeline(
|
||||
prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
negative_prompt=neg,
|
||||
).images[0]
|
||||
else:
|
||||
img = pipeline(
|
||||
prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_config.sample_steps,
|
||||
guidance_scale=sample_config.guidance_scale,
|
||||
negative_prompt=neg,
|
||||
).images[0]
|
||||
|
||||
step_num = ''
|
||||
if step is not None:
|
||||
@@ -209,6 +231,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
})
|
||||
return info
|
||||
|
||||
def clean_up_saves(self):
|
||||
# remove old saves
|
||||
# get latest saved step
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
|
||||
pattern = f"{self.job.name}_*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > self.save_config.max_step_saves_to_keep:
|
||||
# remove all but the latest max_step_saves_to_keep
|
||||
files.sort(key=os.path.getctime)
|
||||
for file in files[:-self.save_config.max_step_saves_to_keep]:
|
||||
self.print(f"Removing old save: {file}")
|
||||
os.remove(file)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
|
||||
def save(self, step=None):
|
||||
if not os.path.exists(self.save_root):
|
||||
os.makedirs(self.save_root, exist_ok=True)
|
||||
@@ -231,9 +271,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
metadata=save_meta
|
||||
)
|
||||
else:
|
||||
# TODO handle dreambooth, fine tuning, etc
|
||||
# will probably have to convert dict back to LDM
|
||||
ValueError("Non network training is not currently supported")
|
||||
self.sd.save(
|
||||
file_path,
|
||||
save_meta,
|
||||
get_torch_dtype(self.save_config.dtype)
|
||||
)
|
||||
|
||||
self.print(f"Saved to {file_path}")
|
||||
|
||||
@@ -258,6 +300,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
):
|
||||
if height is None and pixel_height is None:
|
||||
raise ValueError("height or pixel_height must be specified")
|
||||
raise ValueError("height or pixel_height must be specified")
|
||||
if width is None and pixel_width is None:
|
||||
raise ValueError("width or pixel_width must be specified")
|
||||
if height is None:
|
||||
@@ -316,18 +359,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if add_time_ids is None:
|
||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||
# todo LECOs code looks like it is omitting noise_pred
|
||||
noise_pred = train_util.predict_noise_xl(
|
||||
self.sd.unet,
|
||||
self.sd.noise_scheduler,
|
||||
# noise_pred = train_util.predict_noise_xl(
|
||||
# self.sd.unet,
|
||||
# self.sd.noise_scheduler,
|
||||
# timestep,
|
||||
# latents,
|
||||
# text_embeddings.text_embeds,
|
||||
# text_embeddings.pooled_embeds,
|
||||
# add_time_ids,
|
||||
# guidance_scale=guidance_scale,
|
||||
# guidance_rescale=guidance_rescale
|
||||
# )
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": text_embeddings.pooled_embeds,
|
||||
"time_ids": add_time_ids,
|
||||
}
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.sd.unet(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
latents,
|
||||
text_embeddings.text_embeds,
|
||||
text_embeddings.pooled_embeds,
|
||||
add_time_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
guidance_rescale=guidance_rescale
|
||||
encoder_hidden_states=text_embeddings.text_embeds,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
guided_target = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||
# noise_pred = rescale_noise_cfg(
|
||||
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
||||
# )
|
||||
|
||||
noise_pred = guided_target
|
||||
|
||||
else:
|
||||
noise_pred = train_util.predict_noise(
|
||||
self.sd.unet,
|
||||
@@ -366,6 +438,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
# return latents_steps
|
||||
return latents
|
||||
|
||||
def get_latest_save_path(self):
|
||||
# get latest saved step
|
||||
if os.path.exists(self.save_root):
|
||||
latest_file = None
|
||||
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||
pattern = f"{self.job.name}*.safetensors"
|
||||
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||
if len(files) > 0:
|
||||
latest_file = max(files, key=os.path.getctime)
|
||||
return latest_file
|
||||
else:
|
||||
return None
|
||||
|
||||
def load_weights(self, path):
|
||||
if self.network is not None:
|
||||
self.network.load_weights(path)
|
||||
meta = load_metadata_from_safetensors(path)
|
||||
# if 'training_info' in Orderdict keys
|
||||
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||
self.step_num = meta['training_info']['step']
|
||||
self.start_step = self.step_num
|
||||
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||
|
||||
else:
|
||||
print("load_weights not implemented for non-network models")
|
||||
|
||||
def run(self):
|
||||
super().run()
|
||||
|
||||
@@ -407,20 +505,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
unet.to(self.device_torch, dtype=dtype)
|
||||
if self.train_config.xformers:
|
||||
unet.enable_xformers_memory_efficient_attention()
|
||||
if self.train_config.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
unet.requires_grad_(False)
|
||||
unet.eval()
|
||||
|
||||
if self.network_config is not None:
|
||||
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
|
||||
self.network = LoRASpecialNetwork(
|
||||
text_encoder=text_encoder,
|
||||
unet=unet,
|
||||
lora_dim=self.network_config.rank,
|
||||
lora_dim=self.network_config.linear,
|
||||
multiplier=1.0,
|
||||
alpha=self.network_config.alpha,
|
||||
train_unet=self.train_config.train_unet,
|
||||
train_text_encoder=self.train_config.train_text_encoder,
|
||||
conv_lora_dim=conv,
|
||||
conv_alpha=self.network_config.alpha if conv is not None else None,
|
||||
)
|
||||
|
||||
|
||||
self.network.force_to(self.device_torch, dtype=dtype)
|
||||
|
||||
self.network.apply_to(
|
||||
@@ -438,6 +542,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
default_lr=self.train_config.lr
|
||||
)
|
||||
|
||||
latest_save_path = self.get_latest_save_path()
|
||||
if latest_save_path is not None:
|
||||
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||
self.print(f"Loading from {latest_save_path}")
|
||||
self.load_weights(latest_save_path)
|
||||
self.network.multiplier = 1.0
|
||||
|
||||
|
||||
|
||||
else:
|
||||
params = []
|
||||
# assume dreambooth/finetune
|
||||
@@ -475,15 +588,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.print("Skipping first sample due to config setting")
|
||||
else:
|
||||
self.print("Generating baseline samples before training")
|
||||
self.sample(0)
|
||||
self.sample(0, is_first=True)
|
||||
|
||||
self.progress_bar = tqdm(
|
||||
total=self.train_config.steps,
|
||||
desc=self.job.name,
|
||||
leave=True
|
||||
)
|
||||
self.step_num = 0
|
||||
for step in range(self.train_config.steps):
|
||||
# set it to our current step in case it was updated from a load
|
||||
self.progress_bar.update(self.step_num)
|
||||
# self.step_num = 0
|
||||
for step in range(self.step_num, self.train_config.steps):
|
||||
# todo handle dataloader here maybe, not sure
|
||||
|
||||
### HOOK ###
|
||||
|
||||
Reference in New Issue
Block a user