Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun

This commit is contained in:
Jaret Burkett
2023-08-19 05:54:22 -06:00
parent 80e2f4a2a4
commit 90eedb78bf
7 changed files with 239 additions and 35 deletions

View File

@@ -63,6 +63,7 @@ class TrainConfig:
self.xformers = kwargs.get('xformers', False)
self.train_unet = kwargs.get('train_unet', True)
self.train_text_encoder = kwargs.get('train_text_encoder', True)
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.optimizer_params = kwargs.get('optimizer_params', {})
self.skip_first_sample = kwargs.get('skip_first_sample', False)
@@ -77,6 +78,10 @@ class ModelConfig:
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
self.dtype: str = kwargs.get('dtype', 'float16')
# only for SDXL models for now
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
if self.name_or_path is None:
raise ValueError('name_or_path must be specified')

View File

@@ -16,10 +16,6 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from leco import train_util
import torch
from library import model_util
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
@@ -124,6 +120,9 @@ class StableDiffusion:
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
self.use_text_encoder_1 = model_config.use_text_encoder_1
self.use_text_encoder_2 = model_config.use_text_encoder_2
def load_model(self):
if self.is_loaded:
return
@@ -309,6 +308,7 @@ class StableDiffusion:
torch.manual_seed(gen_config.seed)
torch.cuda.manual_seed(gen_config.seed)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
img = pipeline(
prompt=gen_config.prompt,
@@ -393,7 +393,7 @@ class StableDiffusion:
dtype = latents.dtype
if self.is_xl:
prompt_ids = train_util.get_add_time_ids(
prompt_ids = train_tools.get_add_time_ids(
height,
width,
dynamic_crops=False, # look into this
@@ -444,7 +444,7 @@ class StableDiffusion:
if do_classifier_free_guidance:
# todo check this with larger batches
add_time_ids = train_util.concat_embeddings(
add_time_ids = train_tools.concat_embeddings(
add_time_ids, add_time_ids, int(latents.shape[0])
)
else:
@@ -459,6 +459,7 @@ class StableDiffusion:
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
added_cond_kwargs = {
# todo can we zero here the second text encoder? or match a blank string?
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": add_time_ids,
}
@@ -541,16 +542,18 @@ class StableDiffusion:
prompt = [prompt]
if self.is_xl:
return PromptEmbeds(
train_util.encode_prompts_xl(
train_tools.encode_prompts_xl(
self.tokenizer,
self.text_encoder,
prompt,
num_images_per_prompt=num_images_per_prompt,
use_text_encoder_1=self.use_text_encoder_1,
use_text_encoder_2=self.use_text_encoder_2,
)
)
else:
return PromptEmbeds(
train_util.encode_prompts(
train_tools.encode_prompts(
self.tokenizer, self.text_encoder, prompt
)
)

View File

@@ -3,7 +3,7 @@ import hashlib
import json
import os
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union
import sys
from toolkit.paths import SD_SCRIPTS_ROOT
@@ -32,6 +32,10 @@ SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
TEXT_ENCODER_2_PROJECTION_DIM = 1280
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
def get_torch_dtype(dtype_str):
# if it is a torch dtype, return it
@@ -433,3 +437,183 @@ def addnet_hash_legacy(b):
b.seek(0x100000)
m.update(b.read(0x10000))
return m.hexdigest()[0:8]
if TYPE_CHECKING:
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
def text_tokenize(
        tokenizer: 'CLIPTokenizer',  # one tokenizer for SD, two for SDXL (this is called once per tokenizer)
        prompts: list[str],
):
    """Tokenize a batch of prompts to a padded/truncated tensor of input ids ("pt" tensors)."""
    encoded = tokenizer(
        prompts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
        text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
        tokens: torch.FloatTensor,
        num_images_per_prompt: int = 1,
):
    """Run one SDXL text encoder over token ids.

    Returns (prompt_embeds, pooled_prompt_embeds): the penultimate hidden
    layer tiled num_images_per_prompt times along the batch dim, plus the
    encoder's pooled output (untiled).
    """
    encoder_out = text_encoder(tokens.to(text_encoder.device), output_hidden_states=True)
    pooled = encoder_out[0]
    # always the penultimate layer, matching the diffusers SDXL pipeline
    hidden = encoder_out.hidden_states[-2]
    batch, seq_len, _ = hidden.shape
    tiled = hidden.repeat(1, num_images_per_prompt, 1).view(
        batch * num_images_per_prompt, seq_len, -1
    )
    return tiled, pooled
def encode_prompts_xl(
        tokenizers: list['CLIPTokenizer'],
        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
        prompts: list[str],
        num_images_per_prompt: int = 1,
        use_text_encoder_1: bool = True,  # sdxl
        use_text_encoder_2: bool = True  # sdxl
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode prompts with both SDXL text encoders.

    Returns the two encoders' penultimate-layer outputs concatenated on the
    feature dim, plus the pooled embedding (always from the last encoder,
    i.e. text_encoder_2), tiled num_images_per_prompt times.
    """
    blank_batch = ["" for _ in prompts]
    embeds_per_encoder = []
    pooled_text_embeds = None  # always ends up as text_encoder_2's pool
    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
        # A disabled encoder is fed blank strings for now.
        # TODO: find a better way to do this (zeroing? removing it from the unet?)
        disabled = (idx == 0 and not use_text_encoder_1) or \
                   (idx == 1 and not use_text_encoder_2)
        batch = blank_batch if disabled else prompts
        token_ids = text_tokenize(tokenizer, batch)
        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, token_ids, num_images_per_prompt
        )
        embeds_per_encoder.append(text_embeds)

    # tile the (last encoder's) pooled embedding per generated image
    bs_embed = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    return torch.concat(embeds_per_encoder, dim=-1), pooled_text_embeds
def text_encode(text_encoder: 'CLIPTextModel', tokens):
    """Encode token ids with a (non-XL) CLIP text encoder; returns the last hidden state."""
    device_tokens = tokens.to(text_encoder.device)
    return text_encoder(device_tokens)[0]
def encode_prompts(
        tokenizer: 'CLIPTokenizer',
        text_encoder: 'CLIPTextModel',  # was mis-annotated as 'CLIPTokenizer'
        prompts: list[str],
):
    """Tokenize and encode prompts for SD1.x/2.x.

    Returns the text encoder's last hidden state for the tokenized prompts.
    """
    text_tokens = text_tokenize(tokenizer, prompts)
    text_embeddings = text_encode(text_encoder, text_tokens)
    return text_embeddings
# for XL
def get_add_time_ids(
        height: int,
        width: int,
        dynamic_crops: bool = False,
        dtype: torch.dtype = torch.float32,
):
    """Build SDXL's added time-embedding ids.

    Returns a (1, 6) tensor of
    (orig_h, orig_w, crop_top, crop_left, target_h, target_w). With
    dynamic_crops=True the "original" size is a random 1x-3x upscale and the
    crop position is sampled inside it; otherwise crops are (0, 0).

    Raises ValueError if the derived embedding dim does not match the SDXL
    unet's expected class-embedding input dim.
    """
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # random crop position; the high bound is clamped to >= 1 because
        # torch.randint requires high > 0, and int(height * random_scale) can
        # round down to exactly `height` when the scale is close to 1
        crops_coords_top_left = (
            torch.randint(0, max(1, original_size[0] - height), (1,)).item(),
            torch.randint(0, max(1, original_size[1] - width), (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    # this is expected as 6
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # this is expected as 2816
    passed_add_embed_dim = (
        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
        + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids
def concat_embeddings(
        unconditional: torch.FloatTensor,
        conditional: torch.FloatTensor,
        n_imgs: int,
):
    """Stack [unconditional; conditional] on the batch dim, repeating each row n_imgs times."""
    stacked = torch.cat([unconditional, conditional], dim=0)
    return stacked.repeat_interleave(n_imgs, dim=0)
def add_all_snr_to_noise_scheduler(noise_scheduler, device):
    """Cache the per-timestep SNR table on the scheduler as `all_snr`.

    No-op if the scheduler already carries an `all_snr` attribute. Delegates
    the actual computation to get_all_snr so the two code paths cannot drift
    apart (they previously duplicated the same alphas_cumprod math verbatim).
    """
    if hasattr(noise_scheduler, "all_snr"):
        return
    noise_scheduler.all_snr = get_all_snr(noise_scheduler, device)
def get_all_snr(noise_scheduler, device):
    """Return the per-timestep SNR table for a scheduler, on `device`.

    Uses the scheduler's cached `all_snr` attribute when present; otherwise
    computes SNR(t) = (alpha_t / sigma_t)^2 from alphas_cumprod.
    """
    if hasattr(noise_scheduler, "all_snr"):
        return noise_scheduler.all_snr.to(device)
    # compute it: alpha = sqrt(abar), sigma = sqrt(1 - abar)
    with torch.no_grad():
        abar = noise_scheduler.alphas_cumprod
        all_snr = (torch.sqrt(abar) / torch.sqrt(1.0 - abar)) ** 2
        all_snr.requires_grad = False
    return all_snr.to(device)
def apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler: 'DDPMScheduler',
        gamma
):
    """Scale per-sample loss by the min-SNR-gamma weight min(gamma / SNR(t), 1).

    `loss` must be per-sample (not already reduced) so it broadcasts against
    the per-timestep weights.
    """
    # will get it from the noise scheduler if it exists, or calculate it if not
    all_snr = get_all_snr(noise_scheduler, loss.device)
    # vectorized gather; was a python-level `torch.stack([all_snr[t] ...])` loop
    if not torch.is_tensor(timesteps):
        timesteps = torch.as_tensor(timesteps, device=all_snr.device)
    snr = all_snr[timesteps]
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    # clamp to 1 as in the min-SNR-gamma paper
    snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
    loss = loss * snr_weight
    return loss