Reworked the sd rescaler script
@@ -35,7 +35,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.network_config = NetworkConfig(**self.get_conf('network', None))
        network_config = self.get_conf('network', None)
        if network_config is not None:
            self.network_config = NetworkConfig(**network_config)
        else:
            self.network_config = None
        self.training_folder = self.get_conf('training_folder', self.job.training_folder)
        self.train_config = TrainConfig(**self.get_conf('train', {}))
        self.model_config = ModelConfig(**self.get_conf('model', {}))
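
The hunk above replaces the direct `NetworkConfig(**self.get_conf('network', None))` call: `**` unpacking requires a mapping, so a config with no 'network' section used to crash instead of leaving `self.network_config` as None. A minimal standalone sketch of that failure mode, using a hypothetical get_conf stand-in rather than the real config lookup:

    # Illustrative only; get_conf here stands in for the real config lookup.
    def get_conf(key, default=None):
        return {}.get(key, default)  # no 'network' section configured

    try:
        network_config = dict(**get_conf('network', None))
    except TypeError as e:
        print(e)  # argument after ** must be a mapping, not NoneType

    # The reworked pattern from the commit:
    network_config = get_conf('network', None)
    network_config = dict(**network_config) if network_config is not None else None
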
@@ -1,24 +1,14 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import glob
import os
from typing import Optional
from collections import OrderedDict
import random
from typing import Optional, List

import numpy as np
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file, load_file
from tqdm import tqdm

from toolkit.config_modules import SliderConfig
from toolkit.layers import ReductionKernel
from toolkit.paths import REPOS_ROOT
import sys

from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_pipelines import TransferStableDiffusionXLPipeline

sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools
@@ -40,14 +30,11 @@ class RescaleConfig:
    ):
        self.from_resolution = kwargs.get('from_resolution', 512)
        self.scale = kwargs.get('scale', 0.5)
        self.prompt_file = kwargs.get('prompt_file', None)
        self.prompt_tensors = kwargs.get('prompt_tensors', None)
        self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None)
        self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000)
        self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
        self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)

        if self.prompt_file is None:
            raise ValueError("prompt_file is required")


class PromptEmbedsCache:
    prompts: dict[str, PromptEmbeds] = {}
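
For reference, a hypothetical usage sketch (not part of the commit) showing how the kwargs above resolve, including the derived to_resolution default:

    # Illustrative only: how RescaleConfig resolves its values.
    cfg = RescaleConfig(
        from_resolution=512,
        scale=0.5,
        prompt_file='prompts.txt',       # required; omitting it raises ValueError
        latent_tensor_dir='latents',     # new: where cached target tensors live
        num_latent_tensors=1000,         # new: how many cached tensors to keep
    )
    assert cfg.to_resolution == 256      # defaults to int(from_resolution * scale)
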
@@ -70,7 +57,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
        self.start_step = 0
        self.device = self.get_conf('device', self.job.device)
        self.device_torch = torch.device(self.device)
        self.prompt_cache = PromptEmbedsCache()
        self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
        self.reduce_size_fn = ReductionKernel(
            in_channels=4,
@@ -78,80 +64,148 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
            dtype=get_torch_dtype(self.train_config.dtype),
            device=self.device_torch,
        )
        self.prompt_txt_list = []

        self.latent_paths: List[str] = []
        self.empty_embedding: PromptEmbeds = None

    def before_model_load(self):
        pass

    def get_latent_tensors(self):
        dtype = get_torch_dtype(self.train_config.dtype)

        num_to_generate = 0
        # check if dir exists
        if not os.path.exists(self.rescale_config.latent_tensor_dir):
            os.makedirs(self.rescale_config.latent_tensor_dir)
            num_to_generate = self.rescale_config.num_latent_tensors
        else:
            # find existing
            current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors"))
            num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list)
            self.latent_paths = current_tensor_list

        if num_to_generate > 0:
            print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors")

            # unload other model
            self.sd.unet.to('cpu')

            # load aux network
            self.sd_parent = StableDiffusion(
                self.device_torch,
                model_config=self.model_config,
                dtype=self.train_config.dtype,
            )
            self.sd_parent.load_model()
            self.sd_parent.unet.to(self.device_torch, dtype=dtype)
            # we don't need text encoder for this

            del self.sd_parent.text_encoder
            del self.sd_parent.tokenizer

            self.sd_parent.unet.eval()
            self.sd_parent.unet.requires_grad_(False)

            # save current seed state for training
            rng_state = torch.get_rng_state()
            cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )
            torch.set_default_device(self.device_torch)

            for i in tqdm(range(num_to_generate)):
                dtype = get_torch_dtype(self.train_config.dtype)
                # get a random seed
                seed = torch.randint(0, 2 ** 32, (1,)).item()
                # zero pad seed string to max length
                seed_string = str(seed).zfill(10)
                # set seed
                torch.manual_seed(seed)
                if torch.cuda.is_available():
                    torch.cuda.manual_seed(seed)

                # get a random number of steps
                timesteps_to = self.train_config.max_denoising_steps

                # set the scheduler to the number of steps
                self.sd.noise_scheduler.set_timesteps(
                    timesteps_to, device=self.device_torch
                )

                noise = self.sd.get_latent_noise(
                    pixel_height=self.rescale_config.from_resolution,
                    pixel_width=self.rescale_config.from_resolution,
                    batch_size=self.train_config.batch_size,
                    noise_offset=self.train_config.noise_offset,
                ).to(self.device_torch, dtype=dtype)

                # get latents
                latents = noise * self.sd.noise_scheduler.init_noise_sigma
                latents = latents.to(self.device_torch, dtype=dtype)

                # get random guidance scale from 1.0 to 10.0 (CFG)
                guidance_scale = torch.rand(1).item() * 9.0 + 1.0

                # do a timestep of 1
                timestep = 1

                noise_pred_target = self.sd_parent.predict_noise(
                    latents,
                    text_embeddings=text_embeddings,
                    timestep=timestep,
                    guidance_scale=guidance_scale
                )

                # build state dict
                state_dict = OrderedDict()
                state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16)
                state_dict['latents'] = latents.to('cpu', dtype=torch.float16)
                state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16)
                state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16)
                state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16)
                state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32)  # must be float32 to prevent overflow

                file_name = f"{seed_string}_{i}.safetensors"
                file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name)
                save_file(state_dict, file_path)
                self.latent_paths.append(file_path)

            print("Removing parent model")
            # delete parent
            del self.sd_parent
            flush()

            torch.set_rng_state(rng_state)
            if cuda_rng_state is not None:
                torch.cuda.set_rng_state(cuda_rng_state)
            self.sd.unet.to(self.device_torch, dtype=dtype)

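
The cached files written above can be inspected directly; a hypothetical sketch (not part of the commit), assuming a 512px from_resolution and the standard SD VAE downscale factor of 8. The filename here is invented for illustration:

    # Illustrative only: read back one cached target file from get_latent_tensors().
    from safetensors.torch import load_file

    tensors = load_file('latents/0001234567_0.safetensors')
    # keys: noise_pred_target, latents, guidance_scale, timestep, timesteps_to, seed
    print(tensors['latents'].shape)          # e.g. (batch_size, 4, 64, 64) at 512px
    print(tensors['guidance_scale'].item())  # random CFG drawn from [1.0, 10.0)
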
    def hook_before_train_loop(self):
        self.print(f"Loading prompt file from {self.rescale_config.prompt_file}")
        # encode our empty prompt
        self.empty_embedding = self.sd.encode_prompt("")
        self.empty_embedding = self.empty_embedding.to(self.device_torch,
                                                       dtype=get_torch_dtype(self.train_config.dtype))

        # read line by line from file
        with open(self.rescale_config.prompt_file, 'r', encoding='utf-8') as f:
            self.prompt_txt_list = f.readlines()
            # clean empty lines
            self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]

        self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")

        cache = PromptEmbedsCache()

        # get encoded latents for our prompts
        with torch.no_grad():
            if self.rescale_config.prompt_tensors is not None:
                # check to see if it exists
                if os.path.exists(self.rescale_config.prompt_tensors):
                    # load it.
                    self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}")
                    prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu')
                    # add them to the cache
                    for prompt_txt, prompt_tensor in prompt_tensors.items():
                        if prompt_txt.startswith("te:"):
                            prompt = prompt_txt[3:]
                            # text_embeds
                            text_embeds = prompt_tensor
                            pooled_embeds = None
                            # find pool embeds
                            if f"pe:{prompt}" in prompt_tensors:
                                pooled_embeds = prompt_tensors[f"pe:{prompt}"]

                            # make it
                            prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
                            cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)

|
||||
print("Prompt tensors not found. Encoding prompts..")
|
||||
neutral = ""
|
||||
# encode neutral
|
||||
cache[neutral] = self.sd.encode_prompt(neutral)
|
||||
for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
|
||||
# build the cache
|
||||
if cache[prompt] is None:
|
||||
cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32)
|
||||
|
||||
if self.rescale_config.prompt_tensors:
|
||||
print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
|
||||
state_dict = {}
|
||||
for prompt_txt, prompt_embeds in cache.prompts.items():
|
||||
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu",
|
||||
dtype=get_torch_dtype('fp16'))
|
||||
if prompt_embeds.pooled_embeds is not None:
|
||||
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu",
|
||||
dtype=get_torch_dtype(
|
||||
'fp16'))
|
||||
save_file(state_dict, self.rescale_config.prompt_tensors)
|
||||
|
||||
self.print("Encoding complete.")
|
||||
|
||||
# move to cpu to save vram
|
||||
# We don't need text encoder anymore, but keep it on cpu for sampling
|
||||
# if text encoder is list
|
||||
# Move train model encoder to cpu
|
||||
if isinstance(self.sd.text_encoder, list):
|
||||
for encoder in self.sd.text_encoder:
|
||||
encoder.to("cpu")
|
||||
encoder.to('cpu')
|
||||
encoder.eval()
|
||||
encoder.requires_grad_(False)
|
||||
else:
|
||||
self.sd.text_encoder.to("cpu")
|
||||
self.prompt_cache = cache
|
||||
self.sd.text_encoder.to('cpu')
|
||||
self.sd.text_encoder.eval()
|
||||
self.sd.text_encoder.requires_grad_(False)
|
||||
|
||||
# self.sd.unet.to('cpu')
|
||||
flush()
|
||||
|
||||
self.get_latent_tensors()
|
||||
|
||||
flush()
|
||||
# end hook_before_train_loop
|
||||
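
The "te:"/"pe:" key scheme above stores the whole prompt cache as one flat safetensors file; a minimal sketch of that layout (illustrative only, with hypothetical SD1.5-style embedding shapes):

    # Illustrative only: the flat key scheme used by the prompt-tensor cache.
    import torch
    from safetensors.torch import load_file, save_file

    state_dict = {
        "te:a photo of a cat": torch.zeros(1, 77, 768, dtype=torch.float16),  # text embeds
        "pe:a photo of a cat": torch.zeros(1, 768, dtype=torch.float16),      # pooled embeds (optional)
    }
    save_file(state_dict, "prompts.safetensors")

    loaded = load_file("prompts.safetensors", device="cpu")
    prompts = {k[3:]: v for k, v in loaded.items() if k.startswith("te:")}
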
@@ -159,142 +213,64 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
    def hook_train_loop(self):
        dtype = get_torch_dtype(self.train_config.dtype)

        do_dropout = False

        # see if we should dropout
        if self.rescale_config.prompt_dropout > 0.0:
            thresh = int(self.rescale_config.prompt_dropout * 100)
            if torch.randint(0, 100, (1,)).item() < thresh:
                do_dropout = True

        # get random encoded prompt from cache
        positive_prompt_txt = self.prompt_txt_list[
            torch.randint(0, len(self.prompt_txt_list), (1,)).item()
        ]
        negative_prompt_txt = self.prompt_txt_list[
            torch.randint(0, len(self.prompt_txt_list), (1,)).item()
        ]
        if do_dropout:
            positive_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
            negative_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
        else:
            positive_prompt = self.prompt_cache[positive_prompt_txt].to(device=self.device_torch, dtype=dtype)
            negative_prompt = self.prompt_cache[negative_prompt_txt].to(device=self.device_torch, dtype=dtype)

        if positive_prompt is None:
            raise ValueError(f"Prompt {positive_prompt_txt} is not in cache")
        if negative_prompt is None:
            raise ValueError(f"Prompt {negative_prompt_txt} is not in cache")

        loss_function = torch.nn.MSELoss()

        # train it
        # Begin gradient accumulation
        self.sd.unet.train()
        self.sd.unet.requires_grad_(True)
        self.sd.unet.to(self.device_torch, dtype=dtype)

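
The integer-percent draw above is equivalent to comparing a uniform sample against prompt_dropout directly; an illustrative one-liner (not the commit's code):

    # Illustrative equivalent of the dropout check above.
    import random

    def should_dropout(p: float) -> bool:
        return random.random() < p  # prompt_dropout=0.1 -> empty prompt ~10% of steps
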
        with torch.no_grad():
            self.optimizer.zero_grad()

            # get a random number of steps
            timesteps_to = torch.randint(
                1, self.train_config.max_denoising_steps, (1,)
            ).item()
            # pick random latent tensor
            latent_path = random.choice(self.latent_paths)
            latent_tensor = load_file(latent_path)

            # set the scheduler to the number of steps
            noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype)
            latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype)
            guidance_scale = (latent_tensor['guidance_scale']).item()
            timestep = int((latent_tensor['timestep']).item())
            timesteps_to = int((latent_tensor['timesteps_to']).item())
            # seed = int((latent_tensor['seed']).item())

            text_embeddings = train_tools.concat_prompt_embeddings(
                self.empty_embedding,  # unconditional (negative prompt)
                self.empty_embedding,  # conditional (positive prompt)
                self.train_config.batch_size,
            )
            self.sd.noise_scheduler.set_timesteps(
                timesteps_to, device=self.device_torch
            )

            # get noise
            noise = self.sd.get_latent_noise(
                pixel_height=self.rescale_config.from_resolution,
                pixel_width=self.rescale_config.from_resolution,
                batch_size=self.train_config.batch_size,
                noise_offset=self.train_config.noise_offset,
            ).to(self.device_torch, dtype=dtype)
            denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample

            torch.set_default_device(self.device_torch)
            # get the reduced latents
            # reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
            denoised_target = self.reduce_size_fn(denoised_target.detach())
            reduced_latents = self.reduce_size_fn(latents.detach())

            # get latents
            latents = noise * self.sd.noise_scheduler.init_noise_sigma
            latents = latents.to(self.device_torch, dtype=dtype)

            # get random guidance scale from 1.0 to 10.0 (CFG)
            guidance_scale = torch.rand(1).item() * 9.0 + 1.0

            loss_arr = []

            max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
            # pad with spaces
            timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
            new_description = f"{self.job.name} ts: {timestep_str}"
            self.progress_bar.set_description(new_description)

        # Begin gradient accumulation
        denoised_target.requires_grad = False
        self.optimizer.zero_grad()
        noise_pred_train = self.sd.predict_noise(
            reduced_latents,
            text_embeddings=text_embeddings,
            timestep=timestep,
            guidance_scale=guidance_scale
        )
        denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample
        loss = loss_function(denoised_pred, denoised_target)
        loss_float = loss.item()
        loss.backward()
        self.optimizer.step()
        self.lr_scheduler.step()
        self.optimizer.zero_grad()

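
The reduce_size_fn calls above come from toolkit.layers.ReductionKernel, whose implementation is not shown in this diff; a hypothetical stand-in, assuming it behaves like an area (average-pool) downscale over 4-channel latents:

    # Illustrative stand-in for ReductionKernel; assumes area-style reduction.
    import torch
    import torch.nn.functional as F

    def reduce_latents(latents: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
        # (B, 4, H, W) -> (B, 4, H*scale, W*scale); each 2x2 block is averaged at scale=0.5
        k = int(round(1 / scale))
        return F.avg_pool2d(latents, kernel_size=k, stride=k)

    reduced = reduce_latents(torch.randn(1, 4, 64, 64))  # -> (1, 4, 32, 32)
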
        # perform the diffusion
        for timestep in tqdm(self.sd.noise_scheduler.timesteps, leave=False):
            assert not self.network.is_active

            text_embeddings = train_tools.concat_prompt_embeddings(
                negative_prompt,  # unconditional (negative prompt)
                positive_prompt,  # conditional (positive prompt)
                self.train_config.batch_size,
            )

            with torch.no_grad():
                noise_pred_target = self.sd.predict_noise(
                    latents,
                    text_embeddings=text_embeddings,
                    timestep=timestep,
                    guidance_scale=guidance_scale
                )

            # todo should we do every step?
            do_train_cycle = True

            if do_train_cycle:
                # get the reduced latents
                with torch.no_grad():
                    reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
                    reduced_latents = self.reduce_size_fn(latents.detach())
                with self.network:
                    assert self.network.is_active
                    self.network.multiplier = 1.0
                    noise_pred_train = self.sd.predict_noise(
                        reduced_latents,
                        text_embeddings=text_embeddings,
                        timestep=timestep,
                        guidance_scale=guidance_scale
                    )

                reduced_pred.requires_grad = False
                loss = loss_function(noise_pred_train, reduced_pred)
                loss_arr.append(loss.item())
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

            # get next latents
            # todo allow to show latent here
            latents = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample

        # reset prompt embeds
        positive_prompt.to(device="cpu")
        negative_prompt.to(device="cpu")

        flush()

        # reset network
        self.network.multiplier = 1.0

        # average losses
        s = 0
        for num in loss_arr:
            s += num

        avg_loss = s / len(loss_arr)

        loss_dict = OrderedDict(
            {'loss': avg_loss},
            {'loss': loss_float},
        )

        return loss_dict