mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added rescaling, locon, sdxl, all kinds of stuff. sdxl is still weird
This commit is contained in:
@@ -18,6 +18,7 @@ process_dict = {
|
|||||||
'vae': 'TrainVAEProcess',
|
'vae': 'TrainVAEProcess',
|
||||||
'slider': 'TrainSliderProcess',
|
'slider': 'TrainSliderProcess',
|
||||||
'lora_hack': 'TrainLoRAHack',
|
'lora_hack': 'TrainLoRAHack',
|
||||||
|
'rescale_sd': 'TrainSDRescaleProcess',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
import glob
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
from toolkit.kohya_model_util import load_vae
|
from toolkit.kohya_model_util import load_vae
|
||||||
from toolkit.lora_special import LoRASpecialNetwork
|
from toolkit.lora_special import LoRASpecialNetwork
|
||||||
from toolkit.optimizer import get_optimizer
|
from toolkit.optimizer import get_optimizer
|
||||||
@@ -14,7 +17,7 @@ sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
|||||||
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
||||||
|
|
||||||
from jobs.process import BaseTrainProcess
|
from jobs.process import BaseTrainProcess
|
||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
@@ -48,6 +51,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
self.model_config = ModelConfig(**self.get_conf('model', {}))
|
||||||
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
self.save_config = SaveConfig(**self.get_conf('save', {}))
|
||||||
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
|
self.sample_config = SampleConfig(**self.get_conf('sample', {}))
|
||||||
|
self.first_sample_config = SampleConfig(**self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
|
||||||
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
self.logging_config = LogingConfig(**self.get_conf('logging', {}))
|
||||||
self.optimizer = None
|
self.optimizer = None
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -56,7 +60,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# added later
|
# added later
|
||||||
self.network = None
|
self.network = None
|
||||||
|
|
||||||
def sample(self, step=None):
|
def sample(self, step=None, is_first=False):
|
||||||
sample_folder = os.path.join(self.save_root, 'samples')
|
sample_folder = os.path.join(self.save_root, 'samples')
|
||||||
if not os.path.exists(sample_folder):
|
if not os.path.exists(sample_folder):
|
||||||
os.makedirs(sample_folder, exist_ok=True)
|
os.makedirs(sample_folder, exist_ok=True)
|
||||||
@@ -112,7 +116,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# disable progress bar
|
# disable progress bar
|
||||||
pipeline.set_progress_bar_config(disable=True)
|
pipeline.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
start_seed = self.sample_config.seed
|
sample_config = self.first_sample_config if is_first else self.sample_config
|
||||||
|
|
||||||
|
start_seed = sample_config.seed
|
||||||
start_multiplier = self.network.multiplier
|
start_multiplier = self.network.multiplier
|
||||||
current_seed = start_seed
|
current_seed = start_seed
|
||||||
|
|
||||||
@@ -127,14 +133,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
'multiplier': self.network.multiplier,
|
'multiplier': self.network.multiplier,
|
||||||
})
|
})
|
||||||
|
|
||||||
for i in tqdm(range(len(self.sample_config.prompts)), desc=f"Generating Samples - step: {step}",
|
for i in tqdm(range(len(sample_config.prompts)), desc=f"Generating Samples - step: {step}",
|
||||||
leave=False):
|
leave=False):
|
||||||
raw_prompt = self.sample_config.prompts[i]
|
raw_prompt = sample_config.prompts[i]
|
||||||
|
|
||||||
neg = self.sample_config.neg
|
neg = sample_config.neg
|
||||||
multiplier = self.sample_config.network_multiplier
|
multiplier = sample_config.network_multiplier
|
||||||
p_split = raw_prompt.split('--')
|
p_split = raw_prompt.split('--')
|
||||||
prompt = p_split[0].strip()
|
prompt = p_split[0].strip()
|
||||||
|
height = sample_config.height
|
||||||
|
width = sample_config.width
|
||||||
|
|
||||||
if len(p_split) > 1:
|
if len(p_split) > 1:
|
||||||
for split in p_split:
|
for split in p_split:
|
||||||
@@ -145,13 +153,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
elif flag == 'm':
|
elif flag == 'm':
|
||||||
# multiplier
|
# multiplier
|
||||||
multiplier = float(content)
|
multiplier = float(content)
|
||||||
|
elif flag == 'w':
|
||||||
|
# multiplier
|
||||||
|
width = int(content)
|
||||||
|
elif flag == 'h':
|
||||||
|
# multiplier
|
||||||
|
height = int(content)
|
||||||
|
|
||||||
height = self.sample_config.height
|
|
||||||
width = self.sample_config.width
|
|
||||||
height = max(64, height - height % 8) # round to divisible by 8
|
height = max(64, height - height % 8) # round to divisible by 8
|
||||||
width = max(64, width - width % 8) # round to divisible by 8
|
width = max(64, width - width % 8) # round to divisible by 8
|
||||||
|
|
||||||
if self.sample_config.walk_seed:
|
if sample_config.walk_seed:
|
||||||
current_seed += i
|
current_seed += i
|
||||||
|
|
||||||
if self.network is not None:
|
if self.network is not None:
|
||||||
@@ -159,14 +171,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
torch.manual_seed(current_seed)
|
torch.manual_seed(current_seed)
|
||||||
torch.cuda.manual_seed(current_seed)
|
torch.cuda.manual_seed(current_seed)
|
||||||
|
|
||||||
img = pipeline(
|
if self.sd.is_xl:
|
||||||
prompt,
|
img = pipeline(
|
||||||
height=height,
|
prompt,
|
||||||
width=width,
|
height=height,
|
||||||
num_inference_steps=self.sample_config.sample_steps,
|
width=width,
|
||||||
guidance_scale=self.sample_config.guidance_scale,
|
num_inference_steps=sample_config.sample_steps,
|
||||||
negative_prompt=neg,
|
guidance_scale=sample_config.guidance_scale,
|
||||||
).images[0]
|
negative_prompt=neg,
|
||||||
|
).images[0]
|
||||||
|
else:
|
||||||
|
img = pipeline(
|
||||||
|
prompt,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_inference_steps=sample_config.sample_steps,
|
||||||
|
guidance_scale=sample_config.guidance_scale,
|
||||||
|
negative_prompt=neg,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
step_num = ''
|
step_num = ''
|
||||||
if step is not None:
|
if step is not None:
|
||||||
@@ -209,6 +231,24 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
})
|
})
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
def clean_up_saves(self):
|
||||||
|
# remove old saves
|
||||||
|
# get latest saved step
|
||||||
|
if os.path.exists(self.save_root):
|
||||||
|
latest_file = None
|
||||||
|
# pattern is {job_name}_{zero_filles_step}.safetensors but NOT {job_name}.safetensors
|
||||||
|
pattern = f"{self.job.name}_*.safetensors"
|
||||||
|
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||||
|
if len(files) > self.save_config.max_step_saves_to_keep:
|
||||||
|
# remove all but the latest max_step_saves_to_keep
|
||||||
|
files.sort(key=os.path.getctime)
|
||||||
|
for file in files[:-self.save_config.max_step_saves_to_keep]:
|
||||||
|
self.print(f"Removing old save: {file}")
|
||||||
|
os.remove(file)
|
||||||
|
return latest_file
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def save(self, step=None):
|
def save(self, step=None):
|
||||||
if not os.path.exists(self.save_root):
|
if not os.path.exists(self.save_root):
|
||||||
os.makedirs(self.save_root, exist_ok=True)
|
os.makedirs(self.save_root, exist_ok=True)
|
||||||
@@ -231,9 +271,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
metadata=save_meta
|
metadata=save_meta
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# TODO handle dreambooth, fine tuning, etc
|
self.sd.save(
|
||||||
# will probably have to convert dict back to LDM
|
file_path,
|
||||||
ValueError("Non network training is not currently supported")
|
save_meta,
|
||||||
|
get_torch_dtype(self.save_config.dtype)
|
||||||
|
)
|
||||||
|
|
||||||
self.print(f"Saved to {file_path}")
|
self.print(f"Saved to {file_path}")
|
||||||
|
|
||||||
@@ -258,6 +300,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
):
|
):
|
||||||
if height is None and pixel_height is None:
|
if height is None and pixel_height is None:
|
||||||
raise ValueError("height or pixel_height must be specified")
|
raise ValueError("height or pixel_height must be specified")
|
||||||
|
raise ValueError("height or pixel_height must be specified")
|
||||||
if width is None and pixel_width is None:
|
if width is None and pixel_width is None:
|
||||||
raise ValueError("width or pixel_width must be specified")
|
raise ValueError("width or pixel_width must be specified")
|
||||||
if height is None:
|
if height is None:
|
||||||
@@ -316,18 +359,47 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
if add_time_ids is None:
|
if add_time_ids is None:
|
||||||
add_time_ids = self.get_time_ids_from_latents(latents)
|
add_time_ids = self.get_time_ids_from_latents(latents)
|
||||||
# todo LECOs code looks like it is omitting noise_pred
|
# todo LECOs code looks like it is omitting noise_pred
|
||||||
noise_pred = train_util.predict_noise_xl(
|
# noise_pred = train_util.predict_noise_xl(
|
||||||
self.sd.unet,
|
# self.sd.unet,
|
||||||
self.sd.noise_scheduler,
|
# self.sd.noise_scheduler,
|
||||||
|
# timestep,
|
||||||
|
# latents,
|
||||||
|
# text_embeddings.text_embeds,
|
||||||
|
# text_embeddings.pooled_embeds,
|
||||||
|
# add_time_ids,
|
||||||
|
# guidance_scale=guidance_scale,
|
||||||
|
# guidance_rescale=guidance_rescale
|
||||||
|
# )
|
||||||
|
latent_model_input = torch.cat([latents] * 2)
|
||||||
|
|
||||||
|
latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||||
|
|
||||||
|
added_cond_kwargs = {
|
||||||
|
"text_embeds": text_embeddings.pooled_embeds,
|
||||||
|
"time_ids": add_time_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
# predict the noise residual
|
||||||
|
noise_pred = self.sd.unet(
|
||||||
|
latent_model_input,
|
||||||
timestep,
|
timestep,
|
||||||
latents,
|
encoder_hidden_states=text_embeddings.text_embeds,
|
||||||
text_embeddings.text_embeds,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
text_embeddings.pooled_embeds,
|
).sample
|
||||||
add_time_ids,
|
|
||||||
guidance_scale=guidance_scale,
|
# perform guidance
|
||||||
guidance_rescale=guidance_rescale
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||||
|
guided_target = noise_pred_uncond + guidance_scale * (
|
||||||
|
noise_pred_text - noise_pred_uncond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
||||||
|
# noise_pred = rescale_noise_cfg(
|
||||||
|
# noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
|
||||||
|
# )
|
||||||
|
|
||||||
|
noise_pred = guided_target
|
||||||
|
|
||||||
else:
|
else:
|
||||||
noise_pred = train_util.predict_noise(
|
noise_pred = train_util.predict_noise(
|
||||||
self.sd.unet,
|
self.sd.unet,
|
||||||
@@ -366,6 +438,32 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
# return latents_steps
|
# return latents_steps
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
def get_latest_save_path(self):
|
||||||
|
# get latest saved step
|
||||||
|
if os.path.exists(self.save_root):
|
||||||
|
latest_file = None
|
||||||
|
# pattern is {job_name}_{zero_filles_step}.safetensors or {job_name}.safetensors
|
||||||
|
pattern = f"{self.job.name}*.safetensors"
|
||||||
|
files = glob.glob(os.path.join(self.save_root, pattern))
|
||||||
|
if len(files) > 0:
|
||||||
|
latest_file = max(files, key=os.path.getctime)
|
||||||
|
return latest_file
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_weights(self, path):
|
||||||
|
if self.network is not None:
|
||||||
|
self.network.load_weights(path)
|
||||||
|
meta = load_metadata_from_safetensors(path)
|
||||||
|
# if 'training_info' in Orderdict keys
|
||||||
|
if 'training_info' in meta and 'step' in meta['training_info']:
|
||||||
|
self.step_num = meta['training_info']['step']
|
||||||
|
self.start_step = self.step_num
|
||||||
|
print(f"Found step {self.step_num} in metadata, starting from there")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("load_weights not implemented for non-network models")
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
super().run()
|
super().run()
|
||||||
|
|
||||||
@@ -407,20 +505,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
unet.to(self.device_torch, dtype=dtype)
|
unet.to(self.device_torch, dtype=dtype)
|
||||||
if self.train_config.xformers:
|
if self.train_config.xformers:
|
||||||
unet.enable_xformers_memory_efficient_attention()
|
unet.enable_xformers_memory_efficient_attention()
|
||||||
|
if self.train_config.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
unet.requires_grad_(False)
|
unet.requires_grad_(False)
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
if self.network_config is not None:
|
if self.network_config is not None:
|
||||||
|
conv = self.network_config.conv if self.network_config.conv is not None and self.network_config.conv > 0 else None
|
||||||
self.network = LoRASpecialNetwork(
|
self.network = LoRASpecialNetwork(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
lora_dim=self.network_config.rank,
|
lora_dim=self.network_config.linear,
|
||||||
multiplier=1.0,
|
multiplier=1.0,
|
||||||
alpha=self.network_config.alpha,
|
alpha=self.network_config.alpha,
|
||||||
train_unet=self.train_config.train_unet,
|
train_unet=self.train_config.train_unet,
|
||||||
train_text_encoder=self.train_config.train_text_encoder,
|
train_text_encoder=self.train_config.train_text_encoder,
|
||||||
|
conv_lora_dim=conv,
|
||||||
|
conv_alpha=self.network_config.alpha if conv is not None else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
self.network.force_to(self.device_torch, dtype=dtype)
|
self.network.force_to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
self.network.apply_to(
|
self.network.apply_to(
|
||||||
@@ -438,6 +542,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
default_lr=self.train_config.lr
|
default_lr=self.train_config.lr
|
||||||
)
|
)
|
||||||
|
|
||||||
|
latest_save_path = self.get_latest_save_path()
|
||||||
|
if latest_save_path is not None:
|
||||||
|
self.print(f"#### IMPORTANT RESUMING FROM {latest_save_path} ####")
|
||||||
|
self.print(f"Loading from {latest_save_path}")
|
||||||
|
self.load_weights(latest_save_path)
|
||||||
|
self.network.multiplier = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
params = []
|
params = []
|
||||||
# assume dreambooth/finetune
|
# assume dreambooth/finetune
|
||||||
@@ -475,15 +588,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
self.print("Skipping first sample due to config setting")
|
self.print("Skipping first sample due to config setting")
|
||||||
else:
|
else:
|
||||||
self.print("Generating baseline samples before training")
|
self.print("Generating baseline samples before training")
|
||||||
self.sample(0)
|
self.sample(0, is_first=True)
|
||||||
|
|
||||||
self.progress_bar = tqdm(
|
self.progress_bar = tqdm(
|
||||||
total=self.train_config.steps,
|
total=self.train_config.steps,
|
||||||
desc=self.job.name,
|
desc=self.job.name,
|
||||||
leave=True
|
leave=True
|
||||||
)
|
)
|
||||||
self.step_num = 0
|
# set it to our current step in case it was updated from a load
|
||||||
for step in range(self.train_config.steps):
|
self.progress_bar.update(self.step_num)
|
||||||
|
# self.step_num = 0
|
||||||
|
for step in range(self.step_num, self.train_config.steps):
|
||||||
# todo handle dataloader here maybe, not sure
|
# todo handle dataloader here maybe, not sure
|
||||||
|
|
||||||
### HOOK ###
|
### HOOK ###
|
||||||
|
|||||||
278
jobs/process/TrainSDRescaleProcess.py
Normal file
278
jobs/process/TrainSDRescaleProcess.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
# ref:
|
||||||
|
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from toolkit.config_modules import SliderConfig
|
||||||
|
from toolkit.layers import ReductionKernel
|
||||||
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
|
|
||||||
|
sys.path.append(REPOS_ROOT)
|
||||||
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
|
import gc
|
||||||
|
from toolkit import train_tools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from leco import train_util, model_util
|
||||||
|
from .BaseSDTrainProcess import BaseSDTrainProcess, StableDiffusion
|
||||||
|
|
||||||
|
|
||||||
|
def flush():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
class RescaleConfig:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
self.from_resolution = kwargs.get('from_resolution', 512)
|
||||||
|
self.scale = kwargs.get('scale', 0.5)
|
||||||
|
self.prompt_file = kwargs.get('prompt_file', None)
|
||||||
|
self.prompt_tensors = kwargs.get('prompt_tensors', None)
|
||||||
|
self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
|
||||||
|
|
||||||
|
if self.prompt_file is None:
|
||||||
|
raise ValueError("prompt_file is required")
|
||||||
|
|
||||||
|
|
||||||
|
class PromptEmbedsCache:
|
||||||
|
prompts: dict[str, PromptEmbeds] = {}
|
||||||
|
|
||||||
|
def __setitem__(self, __name: str, __value: PromptEmbeds) -> None:
|
||||||
|
self.prompts[__name] = __value
|
||||||
|
|
||||||
|
def __getitem__(self, __name: str) -> Optional[PromptEmbeds]:
|
||||||
|
if __name in self.prompts:
|
||||||
|
return self.prompts[__name]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class TrainSDRescaleProcess(BaseSDTrainProcess):
|
||||||
|
def __init__(self, process_id: int, job, config: OrderedDict):
|
||||||
|
super().__init__(process_id, job, config)
|
||||||
|
self.step_num = 0
|
||||||
|
self.start_step = 0
|
||||||
|
self.device = self.get_conf('device', self.job.device)
|
||||||
|
self.device_torch = torch.device(self.device)
|
||||||
|
self.prompt_cache = PromptEmbedsCache()
|
||||||
|
self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
|
||||||
|
self.reduce_size_fn = ReductionKernel(
|
||||||
|
in_channels=4,
|
||||||
|
kernel_size=int(self.rescale_config.from_resolution // self.rescale_config.to_resolution),
|
||||||
|
dtype=get_torch_dtype(self.train_config.dtype),
|
||||||
|
device=self.device_torch,
|
||||||
|
)
|
||||||
|
self.prompt_txt_list = []
|
||||||
|
|
||||||
|
def before_model_load(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def hook_before_train_loop(self):
|
||||||
|
self.print(f"Loading prompt file from {self.rescale_config.prompt_file}")
|
||||||
|
|
||||||
|
# read line by line from file
|
||||||
|
with open(self.rescale_config.prompt_file, 'r') as f:
|
||||||
|
self.prompt_txt_list = f.readlines()
|
||||||
|
# clean empty lines
|
||||||
|
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
|
||||||
|
|
||||||
|
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
|
||||||
|
|
||||||
|
cache = PromptEmbedsCache()
|
||||||
|
|
||||||
|
# get encoded latents for our prompts
|
||||||
|
with torch.no_grad():
|
||||||
|
if self.rescale_config.prompt_tensors is not None:
|
||||||
|
# check to see if it exists
|
||||||
|
if os.path.exists(self.rescale_config.prompt_tensors):
|
||||||
|
# load it.
|
||||||
|
self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}")
|
||||||
|
prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu')
|
||||||
|
# add them to the cache
|
||||||
|
for prompt_txt, prompt_tensor in prompt_tensors.items():
|
||||||
|
if prompt_txt.startswith("te:"):
|
||||||
|
prompt = prompt_txt[3:]
|
||||||
|
# text_embeds
|
||||||
|
text_embeds = prompt_tensor
|
||||||
|
pooled_embeds = None
|
||||||
|
# find pool embeds
|
||||||
|
if f"pe:{prompt}" in prompt_tensors:
|
||||||
|
pooled_embeds = prompt_tensors[f"pe:{prompt}"]
|
||||||
|
|
||||||
|
# make it
|
||||||
|
prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
|
||||||
|
cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
|
||||||
|
|
||||||
|
if len(cache.prompts) == 0:
|
||||||
|
print("Prompt tensors not found. Encoding prompts..")
|
||||||
|
neutral = ""
|
||||||
|
# encode neutral
|
||||||
|
cache[neutral] = self.sd.encode_prompt(neutral)
|
||||||
|
for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||||
|
# build the cache
|
||||||
|
if cache[prompt] is None:
|
||||||
|
cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32)
|
||||||
|
|
||||||
|
if self.rescale_config.prompt_tensors:
|
||||||
|
print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
|
||||||
|
state_dict = {}
|
||||||
|
for prompt_txt, prompt_embeds in cache.prompts.items():
|
||||||
|
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
|
||||||
|
if prompt_embeds.pooled_embeds is not None:
|
||||||
|
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
|
||||||
|
save_file(state_dict, self.rescale_config.prompt_tensors)
|
||||||
|
|
||||||
|
self.print("Encoding complete.")
|
||||||
|
|
||||||
|
# move to cpu to save vram
|
||||||
|
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||||
|
# if text encoder is list
|
||||||
|
if isinstance(self.sd.text_encoder, list):
|
||||||
|
for encoder in self.sd.text_encoder:
|
||||||
|
encoder.to("cpu")
|
||||||
|
else:
|
||||||
|
self.sd.text_encoder.to("cpu")
|
||||||
|
self.prompt_cache = cache
|
||||||
|
|
||||||
|
flush()
|
||||||
|
# end hook_before_train_loop
|
||||||
|
|
||||||
|
def hook_train_loop(self):
|
||||||
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
|
|
||||||
|
# get random encoded prompt from cache
|
||||||
|
prompt_txt = self.prompt_txt_list[
|
||||||
|
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
|
||||||
|
]
|
||||||
|
prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
|
||||||
|
neutral = self.prompt_cache[""].to(device=self.device_torch, dtype=dtype)
|
||||||
|
if prompt is None:
|
||||||
|
raise ValueError(f"Prompt {prompt_txt} is not in cache")
|
||||||
|
|
||||||
|
prompt_batch = train_tools.concat_prompt_embeddings(
|
||||||
|
prompt,
|
||||||
|
neutral,
|
||||||
|
self.train_config.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
noise_scheduler = self.sd.noise_scheduler
|
||||||
|
optimizer = self.optimizer
|
||||||
|
lr_scheduler = self.lr_scheduler
|
||||||
|
loss_function = torch.nn.MSELoss()
|
||||||
|
|
||||||
|
def get_noise_pred(p, n, gs, cts, dn):
|
||||||
|
return self.predict_noise(
|
||||||
|
latents=dn,
|
||||||
|
text_embeddings=train_tools.concat_prompt_embeddings(
|
||||||
|
p, # unconditional
|
||||||
|
n, # positive
|
||||||
|
self.train_config.batch_size,
|
||||||
|
),
|
||||||
|
timestep=cts,
|
||||||
|
guidance_scale=gs,
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
self.sd.noise_scheduler.set_timesteps(
|
||||||
|
self.train_config.max_denoising_steps, device=self.device_torch
|
||||||
|
)
|
||||||
|
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
# # ger a random number of steps
|
||||||
|
timesteps_to = torch.randint(
|
||||||
|
1, self.train_config.max_denoising_steps, (1,)
|
||||||
|
).item()
|
||||||
|
|
||||||
|
# get noise
|
||||||
|
noise = self.get_latent_noise(
|
||||||
|
pixel_height=self.rescale_config.from_resolution,
|
||||||
|
pixel_width=self.rescale_config.from_resolution,
|
||||||
|
).to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
# get latents
|
||||||
|
latents = noise * self.sd.noise_scheduler.init_noise_sigma
|
||||||
|
latents = latents.to(self.device_torch, dtype=dtype)
|
||||||
|
#
|
||||||
|
# # predict without network
|
||||||
|
# assert self.network.is_active is False
|
||||||
|
# denoised_latents = self.diffuse_some_steps(
|
||||||
|
# latents, # pass simple noise latents
|
||||||
|
# prompt_batch,
|
||||||
|
# start_timesteps=0,
|
||||||
|
# total_timesteps=timesteps_to,
|
||||||
|
# guidance_scale=3,
|
||||||
|
# )
|
||||||
|
# noise_scheduler.set_timesteps(1000)
|
||||||
|
#
|
||||||
|
# current_timestep = noise_scheduler.timesteps[
|
||||||
|
# int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||||
|
# ]
|
||||||
|
|
||||||
|
current_timestep = 0
|
||||||
|
denoised_latents = latents
|
||||||
|
# get noise prediction at full scale
|
||||||
|
from_prediction = get_noise_pred(
|
||||||
|
prompt, neutral, 1, current_timestep, denoised_latents
|
||||||
|
)
|
||||||
|
|
||||||
|
reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
|
# get noise prediction at reduced scale
|
||||||
|
to_denoised_latents = self.reduce_size_fn(denoised_latents)
|
||||||
|
|
||||||
|
# start gradient
|
||||||
|
optimizer.zero_grad()
|
||||||
|
self.network.multiplier = 1.0
|
||||||
|
with self.network:
|
||||||
|
assert self.network.is_active is True
|
||||||
|
to_prediction = get_noise_pred(
|
||||||
|
prompt, neutral, 1, current_timestep, to_denoised_latents
|
||||||
|
).to("cpu", dtype=torch.float32)
|
||||||
|
|
||||||
|
reduced_from_prediction.requires_grad = False
|
||||||
|
from_prediction.requires_grad = False
|
||||||
|
|
||||||
|
loss = loss_function(
|
||||||
|
reduced_from_prediction,
|
||||||
|
to_prediction,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss_float = loss.item()
|
||||||
|
|
||||||
|
loss = loss.to(self.device_torch)
|
||||||
|
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
del (
|
||||||
|
reduced_from_prediction,
|
||||||
|
from_prediction,
|
||||||
|
to_denoised_latents,
|
||||||
|
to_prediction,
|
||||||
|
latents,
|
||||||
|
)
|
||||||
|
flush()
|
||||||
|
|
||||||
|
# reset network
|
||||||
|
self.network.multiplier = 1.0
|
||||||
|
|
||||||
|
loss_dict = OrderedDict(
|
||||||
|
{'loss': loss_float},
|
||||||
|
)
|
||||||
|
|
||||||
|
return loss_dict
|
||||||
|
# end hook_train_loop
|
||||||
@@ -669,7 +669,7 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
# get avg loss
|
# get avg loss
|
||||||
for key in log_losses:
|
for key in log_losses:
|
||||||
log_losses[key] = sum(log_losses[key]) / len(log_losses[key])
|
log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6)
|
||||||
# if log_losses[key] > 0:
|
# if log_losses[key] > 0:
|
||||||
self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
|
self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num)
|
||||||
# reset log losses
|
# reset log losses
|
||||||
@@ -678,9 +678,10 @@ class TrainVAEProcess(BaseTrainProcess):
|
|||||||
self.step_num += 1
|
self.step_num += 1
|
||||||
# end epoch
|
# end epoch
|
||||||
if self.writer is not None:
|
if self.writer is not None:
|
||||||
|
eps = 1e-6
|
||||||
# get avg loss
|
# get avg loss
|
||||||
for key in epoch_losses:
|
for key in epoch_losses:
|
||||||
epoch_losses[key] = sum(log_losses[key]) / len(log_losses[key])
|
epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps)
|
||||||
if epoch_losses[key] > 0:
|
if epoch_losses[key] > 0:
|
||||||
self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
|
self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch)
|
||||||
# reset epoch losses
|
# reset epoch losses
|
||||||
|
|||||||
@@ -7,3 +7,4 @@ from .TrainVAEProcess import TrainVAEProcess
|
|||||||
from .BaseMergeProcess import BaseMergeProcess
|
from .BaseMergeProcess import BaseMergeProcess
|
||||||
from .TrainSliderProcess import TrainSliderProcess
|
from .TrainSliderProcess import TrainSliderProcess
|
||||||
from .TrainLoRAHack import TrainLoRAHack
|
from .TrainLoRAHack import TrainLoRAHack
|
||||||
|
from .TrainSDRescaleProcess import TrainSDRescaleProcess
|
||||||
@@ -10,4 +10,6 @@ pyyaml
|
|||||||
oyaml
|
oyaml
|
||||||
tensorboard
|
tensorboard
|
||||||
kornia
|
kornia
|
||||||
invisible-watermark
|
invisible-watermark
|
||||||
|
einops
|
||||||
|
accelerate
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ class SaveConfig:
|
|||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.save_every: int = kwargs.get('save_every', 1000)
|
self.save_every: int = kwargs.get('save_every', 1000)
|
||||||
self.dtype: str = kwargs.get('save_dtype', 'float16')
|
self.dtype: str = kwargs.get('save_dtype', 'float16')
|
||||||
|
self.max_step_saves_to_keep: int = kwargs.get('max_step_saves_to_keep', 5)
|
||||||
|
|
||||||
|
|
||||||
class LogingConfig:
|
class LogingConfig:
|
||||||
@@ -30,8 +31,16 @@ class SampleConfig:
|
|||||||
|
|
||||||
class NetworkConfig:
|
class NetworkConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.type: str = kwargs.get('type', 'lierla')
|
self.type: str = kwargs.get('type', 'lora')
|
||||||
self.rank: int = kwargs.get('rank', 4)
|
rank = kwargs.get('rank', None)
|
||||||
|
linear = kwargs.get('linear', None)
|
||||||
|
if rank is not None:
|
||||||
|
self.rank: int = rank # rank for backward compatibility
|
||||||
|
self.linear: int = rank
|
||||||
|
elif linear is not None:
|
||||||
|
self.rank: int = linear
|
||||||
|
self.linear: int = linear
|
||||||
|
self.conv: int = kwargs.get('conv', None)
|
||||||
self.alpha: float = kwargs.get('alpha', 1.0)
|
self.alpha: float = kwargs.get('alpha', 1.0)
|
||||||
|
|
||||||
|
|
||||||
@@ -51,6 +60,7 @@ class TrainConfig:
|
|||||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
|
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', False)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig:
|
class ModelConfig:
|
||||||
|
|||||||
31
toolkit/layers.py
Normal file
31
toolkit/layers.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class ReductionKernel(nn.Module):
|
||||||
|
# Tensorflow
|
||||||
|
def __init__(self, in_channels, kernel_size=2, dtype=torch.float32, device=None):
|
||||||
|
if device is None:
|
||||||
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||||
|
super(ReductionKernel, self).__init__()
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.in_channels = in_channels
|
||||||
|
numpy_kernel = self.build_kernel()
|
||||||
|
self.kernel = torch.from_numpy(numpy_kernel).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def build_kernel(self):
|
||||||
|
# tensorflow kernel is (height, width, in_channels, out_channels)
|
||||||
|
# pytorch kernel is (out_channels, in_channels, height, width)
|
||||||
|
kernel_size = self.kernel_size
|
||||||
|
channels = self.in_channels
|
||||||
|
kernel_shape = [channels, channels, kernel_size, kernel_size]
|
||||||
|
kernel = np.zeros(kernel_shape, np.float32)
|
||||||
|
|
||||||
|
kernel_value = 1.0 / (kernel_size * kernel_size)
|
||||||
|
for i in range(0, channels):
|
||||||
|
kernel[i, i, :, :] = kernel_value
|
||||||
|
return kernel
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return nn.functional.conv2d(x, self.kernel, stride=self.kernel_size, padding=0, groups=1)
|
||||||
@@ -1,5 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
from info import software_meta
|
from info import software_meta
|
||||||
|
|
||||||
|
|
||||||
@@ -25,4 +28,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
|
|||||||
parsed_meta[key] = json.loads(value)
|
parsed_meta[key] = json.loads(value)
|
||||||
except json.decoder.JSONDecodeError:
|
except json.decoder.JSONDecodeError:
|
||||||
parsed_meta[key] = value
|
parsed_meta[key] = value
|
||||||
return meta
|
return parsed_meta
|
||||||
|
|
||||||
|
|
||||||
|
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
|
||||||
|
with safe_open(file_path, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
return parse_metadata_from_safetensors(metadata)
|
||||||
|
|||||||
@@ -1,11 +1,18 @@
|
|||||||
from typing import Union
|
from typing import Union, OrderedDict
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
||||||
from leco import train_util
|
from leco import train_util
|
||||||
import torch
|
import torch
|
||||||
|
from library import model_util
|
||||||
|
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||||
|
|
||||||
|
|
||||||
class PromptEmbeds:
|
class PromptEmbeds:
|
||||||
@@ -22,6 +29,12 @@ class PromptEmbeds:
|
|||||||
self.text_embeds = args
|
self.text_embeds = args
|
||||||
self.pooled_embeds = None
|
self.pooled_embeds = None
|
||||||
|
|
||||||
|
def to(self, **kwargs):
|
||||||
|
self.text_embeds = self.text_embeds.to(**kwargs)
|
||||||
|
if self.pooled_embeds is not None:
|
||||||
|
self.pooled_embeds = self.pooled_embeds.to(**kwargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
class StableDiffusion:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -61,3 +74,41 @@ class StableDiffusion:
|
|||||||
self.tokenizer, self.text_encoder, prompt
|
self.tokenizer, self.text_encoder, prompt
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
|
||||||
|
# todo see what logit scale is
|
||||||
|
if self.is_xl:
|
||||||
|
|
||||||
|
state_dict = {}
|
||||||
|
|
||||||
|
def update_sd(prefix, sd):
|
||||||
|
for k, v in sd.items():
|
||||||
|
key = prefix + k
|
||||||
|
v = v.detach().clone().to("cpu").to(get_torch_dtype(save_dtype))
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
|
# Convert the UNet model
|
||||||
|
update_sd("model.diffusion_model.", self.unet.state_dict())
|
||||||
|
|
||||||
|
# Convert the text encoders
|
||||||
|
update_sd("conditioner.embedders.0.transformer.", self.text_encoder[0].state_dict())
|
||||||
|
|
||||||
|
text_enc2_dict = convert_text_encoder_2_state_dict_to_sdxl(self.text_encoder[1].state_dict(), logit_scale)
|
||||||
|
update_sd("conditioner.embedders.1.model.", text_enc2_dict)
|
||||||
|
|
||||||
|
# Convert the VAE
|
||||||
|
vae_dict = model_util.convert_vae_state_dict(self.vae.state_dict())
|
||||||
|
update_sd("first_stage_model.", vae_dict)
|
||||||
|
|
||||||
|
# Put together new checkpoint
|
||||||
|
key_count = len(state_dict.keys())
|
||||||
|
new_ckpt = {"state_dict": state_dict}
|
||||||
|
|
||||||
|
if model_util.is_safetensors(output_file):
|
||||||
|
save_file(state_dict, output_file)
|
||||||
|
else:
|
||||||
|
torch.save(new_ckpt, output_file, meta)
|
||||||
|
|
||||||
|
return key_count
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("sdv1.x, sdv2.x is not implemented yet")
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
@@ -21,8 +22,6 @@ from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipel
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
|
||||||
|
|
||||||
SCHEDULER_LINEAR_START = 0.00085
|
SCHEDULER_LINEAR_START = 0.00085
|
||||||
SCHEDULER_LINEAR_END = 0.0120
|
SCHEDULER_LINEAR_END = 0.0120
|
||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
@@ -381,11 +380,16 @@ def apply_noise_offset(noise, noise_offset):
|
|||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
|
|
||||||
|
|
||||||
def concat_prompt_embeddings(
|
def concat_prompt_embeddings(
|
||||||
unconditional: PromptEmbeds,
|
unconditional: 'PromptEmbeds',
|
||||||
conditional: PromptEmbeds,
|
conditional: 'PromptEmbeds',
|
||||||
n_imgs: int,
|
n_imgs: int,
|
||||||
):
|
):
|
||||||
|
from toolkit.stable_diffusion_model import PromptEmbeds
|
||||||
text_embeds = torch.cat(
|
text_embeds = torch.cat(
|
||||||
[unconditional.text_embeds, conditional.text_embeds]
|
[unconditional.text_embeds, conditional.text_embeds]
|
||||||
).repeat_interleave(n_imgs, dim=0)
|
).repeat_interleave(n_imgs, dim=0)
|
||||||
|
|||||||
Reference in New Issue
Block a user