From 90eedb78bfd433e51bfae90a7b572a1b55f9546b Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 19 Aug 2023 05:54:22 -0600
Subject: [PATCH] Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun

---
 config/examples/train_slider.example.yml           |  10 +-
 .../ImageReferenceSliderTrainerProcess.py          |  16 +-
 jobs/process/GenerateProcess.py                    |   5 +
 jobs/process/TrainSliderProcess.py                 |  33 ++--
 toolkit/config_modules.py                          |   5 +
 toolkit/stable_diffusion_model.py                  |  19 +-
 toolkit/train_tools.py                             | 186 +++++++++++++++++-
 7 files changed, 239 insertions(+), 35 deletions(-)

diff --git a/config/examples/train_slider.example.yml b/config/examples/train_slider.example.yml
index c92955ec..4c00dc9c 100644
--- a/config/examples/train_slider.example.yml
+++ b/config/examples/train_slider.example.yml
@@ -23,9 +23,8 @@ config:
       # network type lierla is traditional LoRA that works everywhere, only linear layers
       type: "lierla"
       # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
-      rank: 8
-      alpha: 4 # Do about half of rank
-
+      linear: 8
+      linear_alpha: 4 # Do about half of rank
     # training config
     train:
       # this is also used in sampling. Stick with ddpm unless you know what you are doing
@@ -42,8 +41,8 @@ config:
       # for sliders we are adjusting representation of the concept (unet),
       # not the description of it (text encoder)
       train_text_encoder: false
-
-
+      # same as from sd-scripts, not fully tested but should speed up training
+      min_snr_gamma: 5.0
       # just leave unless you know what you are doing
       # also supports "dadaptation" but set lr to 1 if you use that,
       # but it learns too fast and I don't recommend it
@@ -64,6 +63,7 @@ config:
       # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
       # although, the way we train sliders is comparative, so it probably won't work anyway
       noise_offset: 0.0
+#      noise_offset: 0.0357  # SDXL was trained with offset of 0.0357. So use that when training on SDXL

     # the model to train the LoRA network on
     model:
diff --git a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
index 568ee9ae..688e172b 100644
--- a/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
+++ b/extensions_built_in/image_reference_slider_trainer/ImageReferenceSliderTrainerProcess.py
@@ -8,11 +8,12 @@ from torch.utils.data import ConcatDataset, DataLoader
 from toolkit.data_loader import PairedImageDataset
 from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
 from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight
 import gc
 from toolkit import train_tools
 import torch
 from jobs.process import BaseSDTrainProcess
+import random


 def flush():
@@ -41,6 +42,7 @@ class DatasetConfig:

 class ReferenceSliderConfig:
     def __init__(self, **kwargs):
         self.additional_losses: List[str] = kwargs.get('additional_losses', [])
+        self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
         self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]

@@ -98,10 +100,19 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
         with torch.no_grad():
             imgs, prompts, network_weights = batch
             network_pos_weight, network_neg_weight = network_weights
+
             if isinstance(network_pos_weight, torch.Tensor):
                 network_pos_weight = network_pos_weight.item()
             if isinstance(network_neg_weight, torch.Tensor):
                 network_neg_weight = network_neg_weight.item()
+
+            # apply the same random jitter, a single float in [-weight_jitter, weight_jitter], to both weights
+            weight_jitter = self.slider_config.weight_jitter
+            if weight_jitter > 0.0:
+                jitter = random.uniform(-weight_jitter, weight_jitter)
+                network_pos_weight += jitter
+                network_neg_weight += jitter
+
             # if items in network_weight list are tensors, convert them to floats
             dtype = get_torch_dtype(self.train_config.dtype)

@@ -211,6 +222,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
             loss = loss.mean([1, 2, 3])

             # todo add snr gamma here
+            if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+                # add min_snr_gamma
+                loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
             loss = loss.mean()

             loss_slide_float = loss.item()
diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py
index 696d6afa..89eb5ee7 100644
--- a/jobs/process/GenerateProcess.py
+++ b/jobs/process/GenerateProcess.py
@@ -12,6 +12,7 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, \
     add_base_model_info_to_meta
 from toolkit.stable_diffusion_model import StableDiffusion
 from toolkit.train_tools import get_torch_dtype
+import random


 class GenerateConfig:
@@ -41,6 +42,10 @@ class GenerateConfig:
         else:
             raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")

+        if kwargs.get('shuffle', False):
+            # shuffle the prompts
+            random.shuffle(self.prompts)
+

 class GenerateProcess(BaseProcess):
     process_id: int
diff --git a/jobs/process/TrainSliderProcess.py b/jobs/process/TrainSliderProcess.py
index 46cef6b7..8aa5bd11 100644
--- a/jobs/process/TrainSliderProcess.py
+++ b/jobs/process/TrainSliderProcess.py
@@ -1,21 +1,9 @@
-# ref:
-# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
 import random
 from collections import OrderedDict
-import os
-from typing import Optional, Union
-
-from safetensors.torch import save_file, load_file
-import torch.utils.checkpoint as cp
 from tqdm import tqdm

 from toolkit.config_modules import SliderConfig
-from toolkit.layers import CheckpointGradients
-from toolkit.paths import REPOS_ROOT
-import sys
-
-from toolkit.stable_diffusion_model import PromptEmbeds
-from toolkit.train_tools import get_torch_dtype
+from toolkit.train_tools import get_torch_dtype, apply_snr_weight
 import gc
 from toolkit import train_tools
 from toolkit.prompt_utils import \
@@ -256,9 +244,8 @@ class TrainSliderProcess(BaseSDTrainProcess):

                 noise_scheduler.set_timesteps(1000)

-                current_timestep = noise_scheduler.timesteps[
-                    int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
-                ]
+                current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
+                current_timestep = noise_scheduler.timesteps[current_timestep_index]

                 # flush() # 4.2GB to 3GB on 512x512

@@ -401,10 +388,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
                     offset_neutral += offset

                     # 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
-                    loss = loss_function(
-                        target_latents,
-                        offset_neutral,
-                    ) * prompt_pair_chunk.weight
+                    loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
+                    loss = loss.mean([1, 2, 3])
+
+                    if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
+                        # match batch size
+                        timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
+                        # add min_snr_gamma
+                        loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
+
+                    loss = loss.mean() * prompt_pair_chunk.weight

                     loss.backward()
                     loss_list.append(loss.item())
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index f84273d9..6b7b8a20 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -63,6 +63,7 @@ class TrainConfig:
         self.xformers = kwargs.get('xformers', False)
         self.train_unet = kwargs.get('train_unet', True)
         self.train_text_encoder = kwargs.get('train_text_encoder', True)
+        self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.optimizer_params = kwargs.get('optimizer_params', {})
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
@@ -77,6 +78,10 @@ class ModelConfig:
         self.is_v_pred: bool = kwargs.get('is_v_pred', False)
         self.dtype: str = kwargs.get('dtype', 'float16')

+        # only for SDXL models for now
+        self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
+        self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
+
         if self.name_or_path is None:
             raise ValueError('name_or_path must be specified')

diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index f0122a21..1e42fcc1 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -16,10 +16,6 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.paths import REPOS_ROOT
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
-
-sys.path.append(REPOS_ROOT)
-sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
-from leco import train_util
 import torch
 from library import model_util
 from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
@@ -124,6 +120,9 @@ class StableDiffusion:
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2

+        self.use_text_encoder_1 = model_config.use_text_encoder_1
+        self.use_text_encoder_2 = model_config.use_text_encoder_2
+
     def load_model(self):
         if self.is_loaded:
             return
@@ -309,6 +308,7 @@ class StableDiffusion:
                 torch.manual_seed(gen_config.seed)
                 torch.cuda.manual_seed(gen_config.seed)

+                # todo do we disable text encoder here as well if disabled for model, or only do that for training?
                 if self.is_xl:
                     img = pipeline(
                         prompt=gen_config.prompt,
@@ -393,7 +393,7 @@ class StableDiffusion:
         dtype = latents.dtype

         if self.is_xl:
-            prompt_ids = train_util.get_add_time_ids(
+            prompt_ids = train_tools.get_add_time_ids(
                 height,
                 width,
                 dynamic_crops=False,  # look into this
@@ -444,7 +444,7 @@ class StableDiffusion:

             if do_classifier_free_guidance:
                 # todo check this with larget batches
-                add_time_ids = train_util.concat_embeddings(
+                add_time_ids = train_tools.concat_embeddings(
                     add_time_ids, add_time_ids, int(latents.shape[0])
                 )
             else:
@@ -459,6 +459,7 @@ class StableDiffusion:
             latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)

             added_cond_kwargs = {
+                # todo can we zero here the second text encoder? or match a blank string?
                 "text_embeds": text_embeddings.pooled_embeds,
                 "time_ids": add_time_ids,
             }
@@ -541,16 +542,18 @@ class StableDiffusion:
             prompt = [prompt]
         if self.is_xl:
             return PromptEmbeds(
-                train_util.encode_prompts_xl(
+                train_tools.encode_prompts_xl(
                     self.tokenizer,
                     self.text_encoder,
                     prompt,
                     num_images_per_prompt=num_images_per_prompt,
+                    use_text_encoder_1=self.use_text_encoder_1,
+                    use_text_encoder_2=self.use_text_encoder_2,
                 )
             )
         else:
             return PromptEmbeds(
-                train_util.encode_prompts(
+                train_tools.encode_prompts(
                     self.tokenizer, self.text_encoder, prompt
                 )
             )
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 4ad723d8..8bcea8a9 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -3,7 +3,7 @@ import hashlib
 import json
 import os
 import time
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Union
 import sys

 from toolkit.paths import SD_SCRIPTS_ROOT
@@ -32,6 +32,10 @@ SCHEDULER_LINEAR_END = 0.0120
 SCHEDULER_TIMESTEPS = 1000
 SCHEDLER_SCHEDULE = "scaled_linear"

+UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
+TEXT_ENCODER_2_PROJECTION_DIM = 1280
+UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
+

 def get_torch_dtype(dtype_str):
     # if it is a torch dtype, return it
@@ -433,3 +437,183 @@ def addnet_hash_legacy(b):
     b.seek(0x100000)
     m.update(b.read(0x10000))
     return m.hexdigest()[0:8]
+
+
+if TYPE_CHECKING:
+    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+
+
+def text_tokenize(
+    tokenizer: 'CLIPTokenizer',  # normally one, two for SDXL
+    prompts: list[str],
+):
+    return tokenizer(
+        prompts,
+        padding="max_length",
+        max_length=tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    ).input_ids
+
+
+# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
+def text_encode_xl(
+    text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
+    tokens: torch.FloatTensor,
+    num_images_per_prompt: int = 1,
+):
+    prompt_embeds = text_encoder(
+        tokens.to(text_encoder.device), output_hidden_states=True
+    )
+    pooled_prompt_embeds = prompt_embeds[0]
+    prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
+
+    bs_embed, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+    return prompt_embeds, pooled_prompt_embeds
+
+
+def encode_prompts_xl(
+    tokenizers: list['CLIPTokenizer'],
+    text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
+    prompts: list[str],
+    num_images_per_prompt: int = 1,
+    use_text_encoder_1: bool = True,  # sdxl
+    use_text_encoder_2: bool = True  # sdxl
+) -> tuple[torch.FloatTensor, torch.FloatTensor]:
+    # text_encoder and text_encoder_2's penultimate layer's output
+    text_embeds_list = []
+    pooled_text_embeds = None  # always text_encoder_2's pool
+
+    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
+        # todo, we are using a blank string to ignore that encoder for now.
+        # find a better way to do this (zeroing?, removing it from the unet?)
+        prompt_list_to_use = prompts
+        if idx == 0 and not use_text_encoder_1:
+            prompt_list_to_use = ["" for _ in prompts]
+        if idx == 1 and not use_text_encoder_2:
+            prompt_list_to_use = ["" for _ in prompts]
+
+        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use)
+        text_embeds, pooled_text_embeds = text_encode_xl(
+            text_encoder, text_tokens_input_ids, num_images_per_prompt
+        )
+
+        text_embeds_list.append(text_embeds)
+
+    bs_embed = pooled_text_embeds.shape[0]
+    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
+        bs_embed * num_images_per_prompt, -1
+    )
+
+    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
+
+
+def text_encode(text_encoder: 'CLIPTextModel', tokens):
+    return text_encoder(tokens.to(text_encoder.device))[0]
+
+
+def encode_prompts(
+    tokenizer: 'CLIPTokenizer',
+    text_encoder: 'CLIPTextModel',
+    prompts: list[str],
+):
+    text_tokens = text_tokenize(tokenizer, prompts)
+    text_embeddings = text_encode(text_encoder, text_tokens)
+
+    return text_embeddings
+
+
+# for XL
+def get_add_time_ids(
+    height: int,
+    width: int,
+    dynamic_crops: bool = False,
+    dtype: torch.dtype = torch.float32,
+):
+    if dynamic_crops:
+        # random float scale between 1 and 3
+        random_scale = torch.rand(1).item() * 2 + 1
+        original_size = (int(height * random_scale), int(width * random_scale))
+        # random position
+        crops_coords_top_left = (
+            torch.randint(0, original_size[0] - height, (1,)).item(),
+            torch.randint(0, original_size[1] - width, (1,)).item(),
+        )
+        target_size = (height, width)
+    else:
+        original_size = (height, width)
+        crops_coords_top_left = (0, 0)
+        target_size = (height, width)
+
+    # this is expected as 6
+    add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+    # this is expected as 2816
+    passed_add_embed_dim = (
+        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
+        + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
+    )
+    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
+        raise ValueError(
+            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+        )
+
+    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+    return add_time_ids
+
+
+def concat_embeddings(
+    unconditional: torch.FloatTensor,
+    conditional: torch.FloatTensor,
+    n_imgs: int,
+):
+    return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
+
+
+def add_all_snr_to_noise_scheduler(noise_scheduler, device):
+    if hasattr(noise_scheduler, "all_snr"):
+        return
+    # compute it
+    with torch.no_grad():
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+        alpha = sqrt_alphas_cumprod
+        sigma = sqrt_one_minus_alphas_cumprod
+        all_snr = (alpha / sigma) ** 2
+        all_snr.requires_grad = False
+        noise_scheduler.all_snr = all_snr.to(device)
+
+
+def get_all_snr(noise_scheduler, device):
+    if hasattr(noise_scheduler, "all_snr"):
+        return noise_scheduler.all_snr.to(device)
+    # compute it
+    with torch.no_grad():
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+        alpha = sqrt_alphas_cumprod
+        sigma = sqrt_one_minus_alphas_cumprod
+        all_snr = (alpha / sigma) ** 2
+        all_snr.requires_grad = False
+        return all_snr.to(device)
+
+
+def apply_snr_weight(
+    loss,
+    timesteps,
+    noise_scheduler: 'DDPMScheduler',
+    gamma
+):
+    # will get it from the noise scheduler if it exists, or will calculate it if not
+    all_snr = get_all_snr(noise_scheduler, loss.device)
+
+    snr = torch.stack([all_snr[t] for t in timesteps])
+    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
+    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)  # from paper
+    loss = loss * snr_weight
+    return loss
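A note on the min_snr_gamma path added above: apply_snr_weight scales each sample's loss by min(gamma / SNR(t), 1) before the final mean, following the sd-scripts / min-SNR-gamma formulation. A minimal sketch of how it is meant to slot into a training step, assuming a diffusers DDPMScheduler and a per-sample MSE; the surrounding names here are illustrative, not part of this patch:

import torch
from diffusers import DDPMScheduler
from toolkit.train_tools import apply_snr_weight

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

def min_snr_mse(noise_pred, noise, timesteps, gamma=5.0):
    # per-sample MSE, reduced over everything except the batch dimension
    loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none")
    loss = loss.mean(dim=[1, 2, 3])
    # down-weight low-noise timesteps: each sample is scaled by min(gamma / SNR(t), 1)
    loss = apply_snr_weight(loss, timesteps, noise_scheduler, gamma)
    return loss.mean()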
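The use_text_encoder_1 / use_text_encoder_2 flags on ModelConfig currently "disable" an SDXL encoder by feeding it an empty prompt inside encode_prompts_xl (see the todo above about zeroing instead). A hedged sketch of calling the helper directly with an SDXL checkpoint loaded via diffusers/transformers; the model id and variable names are assumptions for illustration:

from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from toolkit.train_tools import encode_prompts_xl

base = "stabilityai/stable-diffusion-xl-base-1.0"
tokenizer_1 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tokenizer_2 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2")
text_encoder_1 = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2")

# only text_encoder_2 sees the prompt; encoder 1 is fed "" instead
text_embeds, pooled_embeds = encode_prompts_xl(
    [tokenizer_1, tokenizer_2],
    [text_encoder_1, text_encoder_2],
    ["a photo of a dog"],
    num_images_per_prompt=1,
    use_text_encoder_1=False,
    use_text_encoder_2=True,
)
print(text_embeds.shape, pooled_embeds.shape)  # (1, 77, 2048) and (1, 1280)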
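weight_jitter in ReferenceSliderConfig nudges both slider multipliers by the same random offset each step, so the network is not always trained at exactly the same strength. The core of it pulled out as a standalone sketch; the concrete values are illustrative, not defaults from this patch:

import random

weight_jitter = 0.05          # from slider_config.weight_jitter
network_pos_weight = 1.0      # e.g. +1 multiplier for the positive side
network_neg_weight = -1.0     # e.g. -1 multiplier for the negative side

if weight_jitter > 0.0:
    # a single float in [-weight_jitter, weight_jitter], applied to both sides
    jitter = random.uniform(-weight_jitter, weight_jitter)
    network_pos_weight += jitter
    network_neg_weight += jitter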