Reworked the sd rescaler script

Jaret Burkett
2023-08-09 08:57:27 -06:00
parent bf90740b59
commit fbc8a87a05
2 changed files with 186 additions and 206 deletions


@@ -35,7 +35,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.start_step = 0
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.network_config = NetworkConfig(**self.get_conf('network', None))
network_config = self.get_conf('network', None)
if network_config is not None:
self.network_config = NetworkConfig(**network_config)
else:
self.network_config = None
self.training_folder = self.get_conf('training_folder', self.job.training_folder)
self.train_config = TrainConfig(**self.get_conf('train', {}))
self.model_config = ModelConfig(**self.get_conf('model', {}))


@@ -1,24 +1,14 @@
# ref:
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
import time
from collections import OrderedDict
import glob
import os
from typing import Optional
from collections import OrderedDict
import random
from typing import Optional, List
import numpy as np
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file, load_file
from tqdm import tqdm
from toolkit.config_modules import SliderConfig
from toolkit.layers import ReductionKernel
from toolkit.paths import REPOS_ROOT
import sys
from toolkit.stable_diffusion_model import PromptEmbeds
from toolkit.train_pipelines import TransferStableDiffusionXLPipeline
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
import gc
from toolkit import train_tools
@@ -40,14 +30,11 @@ class RescaleConfig:
):
self.from_resolution = kwargs.get('from_resolution', 512)
self.scale = kwargs.get('scale', 0.5)
self.prompt_file = kwargs.get('prompt_file', None)
self.prompt_tensors = kwargs.get('prompt_tensors', None)
self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None)
self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000)
self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale))
self.prompt_dropout = kwargs.get('prompt_dropout', 0.1)
if self.prompt_file is None:
raise ValueError("prompt_file is required")
class PromptEmbedsCache:
prompts: dict[str, PromptEmbeds] = {}
@@ -70,7 +57,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
self.start_step = 0
self.device = self.get_conf('device', self.job.device)
self.device_torch = torch.device(self.device)
self.prompt_cache = PromptEmbedsCache()
self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True))
self.reduce_size_fn = ReductionKernel(
in_channels=4,
@@ -78,80 +64,148 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
dtype=get_torch_dtype(self.train_config.dtype),
device=self.device_torch,
)
self.prompt_txt_list = []
self.latent_paths: List[str] = []
self.empty_embedding: PromptEmbeds = None
def before_model_load(self):
pass
def get_latent_tensors(self):
dtype = get_torch_dtype(self.train_config.dtype)
num_to_generate = 0
# check if dir exists
if not os.path.exists(self.rescale_config.latent_tensor_dir):
os.makedirs(self.rescale_config.latent_tensor_dir)
num_to_generate = self.rescale_config.num_latent_tensors
else:
# find existing
current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors"))
num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list)
self.latent_paths = current_tensor_list
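# only the shortfall is generated below; any .safetensors files already in
# the directory are picked up and reused as-is.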
if num_to_generate > 0:
print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors")
# move the training unet to cpu to free vram for the parent model
self.sd.unet.to('cpu')
# load the parent (full-resolution) model
self.sd_parent = StableDiffusion(
self.device_torch,
model_config=self.model_config,
dtype=self.train_config.dtype,
)
self.sd_parent.load_model()
self.sd_parent.unet.to(self.device_torch, dtype=dtype)
# we don't need the text encoder for this
del self.sd_parent.text_encoder
del self.sd_parent.tokenizer
self.sd_parent.unet.eval()
self.sd_parent.unet.requires_grad_(False)
# save the current RNG state so it can be restored for training
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
text_embeddings = train_tools.concat_prompt_embeddings(
self.empty_embedding, # unconditional (negative prompt)
self.empty_embedding, # conditional (positive prompt)
self.train_config.batch_size,
)
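# note: both the negative and positive slots get the empty embedding, so the
# cached noise-prediction targets are effectively unconditional (the CFG
# branches see identical conditioning).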
torch.set_default_device(self.device_torch)
for i in tqdm(range(num_to_generate)):
dtype = get_torch_dtype(self.train_config.dtype)
# get a random seed
seed = torch.randint(0, 2 ** 32, (1,)).item()
# zero pad seed string to max length
seed_string = str(seed).zfill(10)
# set seed
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# use the maximum number of denoising steps
timesteps_to = self.train_config.max_denoising_steps
# set the scheduler to the number of steps
self.sd.noise_scheduler.set_timesteps(
timesteps_to, device=self.device_torch
)
noise = self.sd.get_latent_noise(
pixel_height=self.rescale_config.from_resolution,
pixel_width=self.rescale_config.from_resolution,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
# get latents
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
# get random guidance scale from 1.0 to 10.0 (CFG)
guidance_scale = torch.rand(1).item() * 9.0 + 1.0
# do a timestep of 1
timestep = 1
noise_pred_target = self.sd_parent.predict_noise(
latents,
text_embeddings=text_embeddings,
timestep=timestep,
guidance_scale=guidance_scale
)
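# the full-resolution parent only ever predicts noise for this first step;
# that single prediction, plus the starting latents, is cached to disk below
# so the parent never has to be loaded during the actual train loop.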
# build state dict
state_dict = OrderedDict()
state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16)
state_dict['latents'] = latents.to('cpu', dtype=torch.float16)
state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16)
state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16)
state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16)
state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32) # must be float 32 to prevent overflow
file_name = f"{seed_string}_{i}.safetensors"
file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name)
save_file(state_dict, file_path)
self.latent_paths.append(file_path)
print("Removing parent model")
# delete parent
del self.sd_parent
flush()
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
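# restoring the saved RNG state keeps this generation pass from perturbing
# whatever seeding the surrounding training run depends on.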
self.sd.unet.to(self.device_torch, dtype=dtype)
def hook_before_train_loop(self):
self.print(f"Loading prompt file from {self.rescale_config.prompt_file}")
# encode our empty prompt
self.empty_embedding = self.sd.encode_prompt("")
self.empty_embedding = self.empty_embedding.to(self.device_torch,
dtype=get_torch_dtype(self.train_config.dtype))
# read line by line from file
with open(self.rescale_config.prompt_file, 'r', encoding='utf-8') as f:
self.prompt_txt_list = f.readlines()
# clean empty lines
self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0]
self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..")
cache = PromptEmbedsCache()
# get encoded embeddings for our prompts
with torch.no_grad():
if self.rescale_config.prompt_tensors is not None:
# check to see if it exists
if os.path.exists(self.rescale_config.prompt_tensors):
# load it.
self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}")
prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu')
# add them to the cache
for prompt_txt, prompt_tensor in prompt_tensors.items():
if prompt_txt.startswith("te:"):
prompt = prompt_txt[3:]
# text_embeds
text_embeds = prompt_tensor
pooled_embeds = None
# find pooled embeds
if f"pe:{prompt}" in prompt_tensors:
pooled_embeds = prompt_tensors[f"pe:{prompt}"]
# build the PromptEmbeds object
prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds])
cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32)
if len(cache.prompts) == 0:
print("Prompt tensors not found. Encoding prompts..")
neutral = ""
# encode neutral
cache[neutral] = self.sd.encode_prompt(neutral)
for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False):
# build the cache
if cache[prompt] is None:
cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32)
if self.rescale_config.prompt_tensors:
print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
state_dict = {}
for prompt_txt, prompt_embeds in cache.prompts.items():
state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu",
dtype=get_torch_dtype('fp16'))
if prompt_embeds.pooled_embeds is not None:
state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu",
dtype=get_torch_dtype(
'fp16'))
save_file(state_dict, self.rescale_config.prompt_tensors)
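# cache layout: a single flat safetensors file with "te:<prompt>" for text
# embeds and "pe:<prompt>" for pooled embeds (fp16), matching the loader above.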
self.print("Encoding complete.")
# move to cpu to save vram
# We don't need text encoder anymore, but keep it on cpu for sampling
# if text encoder is list
# Move train model encoder to cpu
if isinstance(self.sd.text_encoder, list):
for encoder in self.sd.text_encoder:
encoder.to("cpu")
encoder.to('cpu')
encoder.eval()
encoder.requires_grad_(False)
else:
self.sd.text_encoder.to("cpu")
self.prompt_cache = cache
self.sd.text_encoder.to('cpu')
self.sd.text_encoder.eval()
self.sd.text_encoder.requires_grad_(False)
# self.sd.unet.to('cpu')
flush()
self.get_latent_tensors()
flush()
# end hook_before_train_loop
@@ -159,142 +213,64 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
def hook_train_loop(self):
dtype = get_torch_dtype(self.train_config.dtype)
do_dropout = False
# see if we should dropout
if self.rescale_config.prompt_dropout > 0.0:
thresh = int(self.rescale_config.prompt_dropout * 100)
if torch.randint(0, 100, (1,)).item() < thresh:
do_dropout = True
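# with probability prompt_dropout both prompts are swapped for the empty
# embedding below, presumably so the unconditional branch keeps getting trained.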
# get random encoded prompt from cache
positive_prompt_txt = self.prompt_txt_list[
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
]
negative_prompt_txt = self.prompt_txt_list[
torch.randint(0, len(self.prompt_txt_list), (1,)).item()
]
if do_dropout:
positive_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
negative_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype)
else:
positive_prompt = self.prompt_cache[positive_prompt_txt].to(device=self.device_torch, dtype=dtype)
negative_prompt = self.prompt_cache[negative_prompt_txt].to(device=self.device_torch, dtype=dtype)
if positive_prompt is None:
raise ValueError(f"Prompt {positive_prompt_txt} is not in cache")
if negative_prompt is None:
raise ValueError(f"Prompt {negative_prompt_txt} is not in cache")
loss_function = torch.nn.MSELoss()
# train it
# Begin gradient accumulation
self.sd.unet.train()
self.sd.unet.requires_grad_(True)
self.sd.unet.to(self.device_torch, dtype=dtype)
with torch.no_grad():
self.optimizer.zero_grad()
# get a random number of steps
timesteps_to = torch.randint(
1, self.train_config.max_denoising_steps, (1,)
).item()
# pick random latent tensor
latent_path = random.choice(self.latent_paths)
latent_tensor = load_file(latent_path)
# unpack the cached target tensors
noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype)
latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype)
guidance_scale = (latent_tensor['guidance_scale']).item()
timestep = int((latent_tensor['timestep']).item())
timesteps_to = int((latent_tensor['timesteps_to']).item())
# seed = int((latent_tensor['seed']).item())
text_embeddings = train_tools.concat_prompt_embeddings(
self.empty_embedding, # unconditional (negative prompt)
self.empty_embedding, # conditional (positive prompt)
self.train_config.batch_size,
)
self.sd.noise_scheduler.set_timesteps(
timesteps_to, device=self.device_torch
)
# get noise
noise = self.sd.get_latent_noise(
pixel_height=self.rescale_config.from_resolution,
pixel_width=self.rescale_config.from_resolution,
batch_size=self.train_config.batch_size,
noise_offset=self.train_config.noise_offset,
).to(self.device_torch, dtype=dtype)
denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
torch.set_default_device(self.device_torch)
# get the reduced latents
# reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
denoised_target = self.reduce_size_fn(denoised_target.detach())
reduced_latents = self.reduce_size_fn(latents.detach())
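# the reduction kernel shrinks the 4-channel latents (presumably by the
# configured scale), so the student unet works at the lower resolution while
# the parent's denoised output, reduced the same way, serves as its target.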
# get latents
latents = noise * self.sd.noise_scheduler.init_noise_sigma
latents = latents.to(self.device_torch, dtype=dtype)
# get random guidance scale from 1.0 to 10.0 (CFG)
guidance_scale = torch.rand(1).item() * 9.0 + 1.0
loss_arr = []
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
# pad with spaces
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
new_description = f"{self.job.name} ts: {timestep_str}"
self.progress_bar.set_description(new_description)
# Begin gradient accumulation
denoised_target.requires_grad = False
self.optimizer.zero_grad()
noise_pred_train = self.sd.predict_noise(
reduced_latents,
text_embeddings=text_embeddings,
timestep=timestep,
guidance_scale=guidance_scale
)
denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample
loss = loss_function(denoised_pred, denoised_target)
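# training objective: MSE between the student's one-step denoised latent at
# the reduced size and the downscaled parent target, i.e. the unet learns to
# reproduce the full-resolution model's behaviour at the new resolution.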
loss_float = loss.item()
loss.backward()
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# perform the diffusion
for timestep in tqdm(self.sd.noise_scheduler.timesteps, leave=False):
assert not self.network.is_active
text_embeddings = train_tools.concat_prompt_embeddings(
negative_prompt, # unconditional (negative prompt)
positive_prompt, # conditional (positive prompt)
self.train_config.batch_size,
)
with torch.no_grad():
noise_pred_target = self.sd.predict_noise(
latents,
text_embeddings=text_embeddings,
timestep=timestep,
guidance_scale=guidance_scale
)
# todo should we do every step?
do_train_cycle = True
if do_train_cycle:
# get the reduced latents
with torch.no_grad():
reduced_pred = self.reduce_size_fn(noise_pred_target.detach())
reduced_latents = self.reduce_size_fn(latents.detach())
with self.network:
assert self.network.is_active
self.network.multiplier = 1.0
noise_pred_train = self.sd.predict_noise(
reduced_latents,
text_embeddings=text_embeddings,
timestep=timestep,
guidance_scale=guidance_scale
)
reduced_pred.requires_grad = False
loss = loss_function(noise_pred_train, reduced_pred)
loss_arr.append(loss.item())
loss.backward()
self.optimizer.step()
self.lr_scheduler.step()
self.optimizer.zero_grad()
# get next latents
# todo allow to show latent here
latents = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample
# reset prompt embeds
positive_prompt.to(device="cpu")
negative_prompt.to(device="cpu")
flush()
# reset network
self.network.multiplier = 1.0
# average losses
s = 0
for num in loss_arr:
s += num
avg_loss = s / len(loss_arr)
loss_dict = OrderedDict(
{'loss': avg_loss},
{'loss': loss_float},
)
return loss_dict