diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 58bfba02..8fd652a8 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -35,7 +35,11 @@ class BaseSDTrainProcess(BaseTrainProcess): self.start_step = 0 self.device = self.get_conf('device', self.job.device) self.device_torch = torch.device(self.device) - self.network_config = NetworkConfig(**self.get_conf('network', None)) + network_config = self.get_conf('network', None) + if network_config is not None: + self.network_config = NetworkConfig(**network_config) + else: + self.network_config = None self.training_folder = self.get_conf('training_folder', self.job.training_folder) self.train_config = TrainConfig(**self.get_conf('train', {})) self.model_config = ModelConfig(**self.get_conf('model', {})) diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py index bf77db11..d7464d5c 100644 --- a/jobs/process/TrainSDRescaleProcess.py +++ b/jobs/process/TrainSDRescaleProcess.py @@ -1,24 +1,14 @@ -# ref: -# - https://github.com/p1atdev/LECO/blob/main/train_lora.py -import time -from collections import OrderedDict +import glob import os -from typing import Optional +from collections import OrderedDict +import random +from typing import Optional, List -import numpy as np -from safetensors.torch import load_file, save_file +from safetensors.torch import save_file, load_file from tqdm import tqdm -from toolkit.config_modules import SliderConfig from toolkit.layers import ReductionKernel -from toolkit.paths import REPOS_ROOT -import sys - from toolkit.stable_diffusion_model import PromptEmbeds -from toolkit.train_pipelines import TransferStableDiffusionXLPipeline - -sys.path.append(REPOS_ROOT) -sys.path.append(os.path.join(REPOS_ROOT, 'leco')) from toolkit.train_tools import get_torch_dtype, apply_noise_offset import gc from toolkit import train_tools @@ -40,14 +30,11 @@ class RescaleConfig: ): self.from_resolution = kwargs.get('from_resolution', 512) self.scale = kwargs.get('scale', 0.5) - self.prompt_file = kwargs.get('prompt_file', None) - self.prompt_tensors = kwargs.get('prompt_tensors', None) + self.latent_tensor_dir = kwargs.get('latent_tensor_dir', None) + self.num_latent_tensors = kwargs.get('num_latent_tensors', 1000) self.to_resolution = kwargs.get('to_resolution', int(self.from_resolution * self.scale)) self.prompt_dropout = kwargs.get('prompt_dropout', 0.1) - if self.prompt_file is None: - raise ValueError("prompt_file is required") - class PromptEmbedsCache: prompts: dict[str, PromptEmbeds] = {} @@ -70,7 +57,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): self.start_step = 0 self.device = self.get_conf('device', self.job.device) self.device_torch = torch.device(self.device) - self.prompt_cache = PromptEmbedsCache() self.rescale_config = RescaleConfig(**self.get_conf('rescale', required=True)) self.reduce_size_fn = ReductionKernel( in_channels=4, @@ -78,80 +64,148 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): dtype=get_torch_dtype(self.train_config.dtype), device=self.device_torch, ) - self.prompt_txt_list = [] + + self.latent_paths: List[str] = [] + self.empty_embedding: PromptEmbeds = None def before_model_load(self): pass + def get_latent_tensors(self): + dtype = get_torch_dtype(self.train_config.dtype) + + num_to_generate = 0 + # check if dir exists + if not os.path.exists(self.rescale_config.latent_tensor_dir): + os.makedirs(self.rescale_config.latent_tensor_dir) + num_to_generate = self.rescale_config.num_latent_tensors + else: + # find existing + current_tensor_list = glob.glob(os.path.join(self.rescale_config.latent_tensor_dir, "*.safetensors")) + num_to_generate = self.rescale_config.num_latent_tensors - len(current_tensor_list) + self.latent_paths = current_tensor_list + + if num_to_generate > 0: + print(f"Generating {num_to_generate}/{self.rescale_config.num_latent_tensors} latent tensors") + + # unload other model + self.sd.unet.to('cpu') + + # load aux network + self.sd_parent = StableDiffusion( + self.device_torch, + model_config=self.model_config, + dtype=self.train_config.dtype, + ) + self.sd_parent.load_model() + self.sd_parent.unet.to(self.device_torch, dtype=dtype) + # we dont need text encoder for this + + del self.sd_parent.text_encoder + del self.sd_parent.tokenizer + + self.sd_parent.unet.eval() + self.sd_parent.unet.requires_grad_(False) + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + text_embeddings = train_tools.concat_prompt_embeddings( + self.empty_embedding, # unconditional (negative prompt) + self.empty_embedding, # conditional (positive prompt) + self.train_config.batch_size, + ) + torch.set_default_device(self.device_torch) + + for i in tqdm(range(num_to_generate)): + dtype = get_torch_dtype(self.train_config.dtype) + # get a random seed + seed = torch.randint(0, 2 ** 32, (1,)).item() + # zero pad seed string to max length + seed_string = str(seed).zfill(10) + # set seed + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # # ger a random number of steps + timesteps_to = self.train_config.max_denoising_steps + + # set the scheduler to the number of steps + self.sd.noise_scheduler.set_timesteps( + timesteps_to, device=self.device_torch + ) + + noise = self.sd.get_latent_noise( + pixel_height=self.rescale_config.from_resolution, + pixel_width=self.rescale_config.from_resolution, + batch_size=self.train_config.batch_size, + noise_offset=self.train_config.noise_offset, + ).to(self.device_torch, dtype=dtype) + + # get latents + latents = noise * self.sd.noise_scheduler.init_noise_sigma + latents = latents.to(self.device_torch, dtype=dtype) + + # get random guidance scale from 1.0 to 10.0 (CFG) + guidance_scale = torch.rand(1).item() * 9.0 + 1.0 + + # do a timestep of 1 + timestep = 1 + + noise_pred_target = self.sd_parent.predict_noise( + latents, + text_embeddings=text_embeddings, + timestep=timestep, + guidance_scale=guidance_scale + ) + + # build state dict + state_dict = OrderedDict() + state_dict['noise_pred_target'] = noise_pred_target.to('cpu', dtype=torch.float16) + state_dict['latents'] = latents.to('cpu', dtype=torch.float16) + state_dict['guidance_scale'] = torch.tensor(guidance_scale).to('cpu', dtype=torch.float16) + state_dict['timestep'] = torch.tensor(timestep).to('cpu', dtype=torch.float16) + state_dict['timesteps_to'] = torch.tensor(timesteps_to).to('cpu', dtype=torch.float16) + state_dict['seed'] = torch.tensor(seed).to('cpu', dtype=torch.float32) # must be float 32 to prevent overflow + + file_name = f"{seed_string}_{i}.safetensors" + file_path = os.path.join(self.rescale_config.latent_tensor_dir, file_name) + save_file(state_dict, file_path) + self.latent_paths.append(file_path) + + print("Removing parent model") + # delete parent + del self.sd_parent + flush() + + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + self.sd.unet.to(self.device_torch, dtype=dtype) + def hook_before_train_loop(self): - self.print(f"Loading prompt file from {self.rescale_config.prompt_file}") + # encode our empty prompt + self.empty_embedding = self.sd.encode_prompt("") + self.empty_embedding = self.empty_embedding.to(self.device_torch, + dtype=get_torch_dtype(self.train_config.dtype)) - # read line by line from file - with open(self.rescale_config.prompt_file, 'r', encoding='utf-8') as f: - self.prompt_txt_list = f.readlines() - # clean empty lines - self.prompt_txt_list = [line.strip() for line in self.prompt_txt_list if len(line.strip()) > 0] - - self.print(f"Loaded {len(self.prompt_txt_list)} prompts. Encoding them..") - - cache = PromptEmbedsCache() - - # get encoded latents for our prompts - with torch.no_grad(): - if self.rescale_config.prompt_tensors is not None: - # check to see if it exists - if os.path.exists(self.rescale_config.prompt_tensors): - # load it. - self.print(f"Loading prompt tensors from {self.rescale_config.prompt_tensors}") - prompt_tensors = load_file(self.rescale_config.prompt_tensors, device='cpu') - # add them to the cache - for prompt_txt, prompt_tensor in prompt_tensors.items(): - if prompt_txt.startswith("te:"): - prompt = prompt_txt[3:] - # text_embeds - text_embeds = prompt_tensor - pooled_embeds = None - # find pool embeds - if f"pe:{prompt}" in prompt_tensors: - pooled_embeds = prompt_tensors[f"pe:{prompt}"] - - # make it - prompt_embeds = PromptEmbeds([text_embeds, pooled_embeds]) - cache[prompt] = prompt_embeds.to(device='cpu', dtype=torch.float32) - - if len(cache.prompts) == 0: - print("Prompt tensors not found. Encoding prompts..") - neutral = "" - # encode neutral - cache[neutral] = self.sd.encode_prompt(neutral) - for prompt in tqdm(self.prompt_txt_list, desc="Encoding prompts", leave=False): - # build the cache - if cache[prompt] is None: - cache[prompt] = self.sd.encode_prompt(prompt).to(device="cpu", dtype=torch.float32) - - if self.rescale_config.prompt_tensors: - print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}") - state_dict = {} - for prompt_txt, prompt_embeds in cache.prompts.items(): - state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", - dtype=get_torch_dtype('fp16')) - if prompt_embeds.pooled_embeds is not None: - state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", - dtype=get_torch_dtype( - 'fp16')) - save_file(state_dict, self.rescale_config.prompt_tensors) - - self.print("Encoding complete.") - - # move to cpu to save vram - # We don't need text encoder anymore, but keep it on cpu for sampling - # if text encoder is list + # Move train model encoder to cpu if isinstance(self.sd.text_encoder, list): for encoder in self.sd.text_encoder: - encoder.to("cpu") + encoder.to('cpu') + encoder.eval() + encoder.requires_grad_(False) else: - self.sd.text_encoder.to("cpu") - self.prompt_cache = cache + self.sd.text_encoder.to('cpu') + self.sd.text_encoder.eval() + self.sd.text_encoder.requires_grad_(False) + + # self.sd.unet.to('cpu') + flush() + + self.get_latent_tensors() flush() # end hook_before_train_loop @@ -159,142 +213,64 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): def hook_train_loop(self): dtype = get_torch_dtype(self.train_config.dtype) - do_dropout = False - - # see if we should dropout - if self.rescale_config.prompt_dropout > 0.0: - thresh = int(self.rescale_config.prompt_dropout * 100) - if torch.randint(0, 100, (1,)).item() < thresh: - do_dropout = True - - # get random encoded prompt from cache - positive_prompt_txt = self.prompt_txt_list[ - torch.randint(0, len(self.prompt_txt_list), (1,)).item() - ] - negative_prompt_txt = self.prompt_txt_list[ - torch.randint(0, len(self.prompt_txt_list), (1,)).item() - ] - if do_dropout: - positive_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype) - negative_prompt = self.prompt_cache[''].to(device=self.device_torch, dtype=dtype) - else: - positive_prompt = self.prompt_cache[positive_prompt_txt].to(device=self.device_torch, dtype=dtype) - negative_prompt = self.prompt_cache[negative_prompt_txt].to(device=self.device_torch, dtype=dtype) - - if positive_prompt is None: - raise ValueError(f"Prompt {positive_prompt_txt} is not in cache") - if negative_prompt is None: - raise ValueError(f"Prompt {negative_prompt_txt} is not in cache") - loss_function = torch.nn.MSELoss() + # train it + # Begin gradient accumulation + self.sd.unet.train() + self.sd.unet.requires_grad_(True) + self.sd.unet.to(self.device_torch, dtype=dtype) + with torch.no_grad(): self.optimizer.zero_grad() - # # ger a random number of steps - timesteps_to = torch.randint( - 1, self.train_config.max_denoising_steps, (1,) - ).item() + # pick random latent tensor + latent_path = random.choice(self.latent_paths) + latent_tensor = load_file(latent_path) - # set the scheduler to the number of steps + noise_pred_target = (latent_tensor['noise_pred_target']).to(self.device_torch, dtype=dtype) + latents = (latent_tensor['latents']).to(self.device_torch, dtype=dtype) + guidance_scale = (latent_tensor['guidance_scale']).item() + timestep = int((latent_tensor['timestep']).item()) + timesteps_to = int((latent_tensor['timesteps_to']).item()) + # seed = int((latent_tensor['seed']).item()) + + text_embeddings = train_tools.concat_prompt_embeddings( + self.empty_embedding, # unconditional (negative prompt) + self.empty_embedding, # conditional (positive prompt) + self.train_config.batch_size, + ) self.sd.noise_scheduler.set_timesteps( timesteps_to, device=self.device_torch ) - # get noise - noise = self.sd.get_latent_noise( - pixel_height=self.rescale_config.from_resolution, - pixel_width=self.rescale_config.from_resolution, - batch_size=self.train_config.batch_size, - noise_offset=self.train_config.noise_offset, - ).to(self.device_torch, dtype=dtype) + denoised_target = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample - torch.set_default_device(self.device_torch) + # get the reduced latents + # reduced_pred = self.reduce_size_fn(noise_pred_target.detach()) + denoised_target = self.reduce_size_fn(denoised_target.detach()) + reduced_latents = self.reduce_size_fn(latents.detach()) - # get latents - latents = noise * self.sd.noise_scheduler.init_noise_sigma - latents = latents.to(self.device_torch, dtype=dtype) - - # get random guidance scale from 1.0 to 10.0 (CFG) - guidance_scale = torch.rand(1).item() * 9.0 + 1.0 - - loss_arr = [] - - max_len_timestep_str = len(str(self.train_config.max_denoising_steps)) - # pad with spaces - timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ") - new_description = f"{self.job.name} ts: {timestep_str}" - self.progress_bar.set_description(new_description) - - # Begin gradient accumulation + denoised_target.requires_grad = False + self.optimizer.zero_grad() + noise_pred_train = self.sd.predict_noise( + reduced_latents, + text_embeddings=text_embeddings, + timestep=timestep, + guidance_scale=guidance_scale + ) + denoised_pred = self.sd.noise_scheduler.step(noise_pred_train, timestep, reduced_latents).prev_sample + loss = loss_function(denoised_pred, denoised_target) + loss_float = loss.item() + loss.backward() + self.optimizer.step() + self.lr_scheduler.step() self.optimizer.zero_grad() - - # perform the diffusion - for timestep in tqdm(self.sd.noise_scheduler.timesteps, leave=False): - assert not self.network.is_active - - text_embeddings = train_tools.concat_prompt_embeddings( - negative_prompt, # unconditional (negative prompt) - positive_prompt, # conditional (positive prompt) - self.train_config.batch_size, - ) - - with torch.no_grad(): - noise_pred_target = self.sd.predict_noise( - latents, - text_embeddings=text_embeddings, - timestep=timestep, - guidance_scale=guidance_scale - ) - - # todo should we do every step? - do_train_cycle = True - - if do_train_cycle: - # get the reduced latents - with torch.no_grad(): - reduced_pred = self.reduce_size_fn(noise_pred_target.detach()) - reduced_latents = self.reduce_size_fn(latents.detach()) - with self.network: - assert self.network.is_active - self.network.multiplier = 1.0 - noise_pred_train = self.sd.predict_noise( - reduced_latents, - text_embeddings=text_embeddings, - timestep=timestep, - guidance_scale=guidance_scale - ) - - reduced_pred.requires_grad = False - loss = loss_function(noise_pred_train, reduced_pred) - loss_arr.append(loss.item()) - loss.backward() - self.optimizer.step() - self.lr_scheduler.step() - self.optimizer.zero_grad() - - # get next latents - # todo allow to show latent here - latents = self.sd.noise_scheduler.step(noise_pred_target, timestep, latents).prev_sample - - # reset prompt embeds - positive_prompt.to(device="cpu") - negative_prompt.to(device="cpu") flush() - # reset network - self.network.multiplier = 1.0 - - # average losses - s = 0 - for num in loss_arr: - s += num - - avg_loss = s / len(loss_arr) - loss_dict = OrderedDict( - {'loss': avg_loss}, + {'loss': loss_float}, ) return loss_dict