mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-05-01 03:31:35 +00:00
Added multiplier jitter, min_snr, ability to choose sdxl encoders to use, shuffle generator, and other fun
This commit is contained in:
@@ -23,9 +23,8 @@ config:
|
|||||||
# network type lierla is traditional LoRA that works everywhere, only linear layers
|
# network type lierla is traditional LoRA that works everywhere, only linear layers
|
||||||
type: "lierla"
|
type: "lierla"
|
||||||
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
|
# rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
|
||||||
rank: 8
|
linear: 8
|
||||||
alpha: 4 # Do about half of rank
|
linear_alpha: 4 # Do about half of rank
|
||||||
|
|
||||||
# training config
|
# training config
|
||||||
train:
|
train:
|
||||||
# this is also used in sampling. Stick with ddpm unless you know what you are doing
|
# this is also used in sampling. Stick with ddpm unless you know what you are doing
|
||||||
@@ -42,8 +41,8 @@ config:
|
|||||||
# for sliders we are adjusting representation of the concept (unet),
|
# for sliders we are adjusting representation of the concept (unet),
|
||||||
# not the description of it (text encoder)
|
# not the description of it (text encoder)
|
||||||
train_text_encoder: false
|
train_text_encoder: false
|
||||||
|
# same as from sd-scripts, not fully tested but should speed up training
|
||||||
|
min_snr_gamma: 5.0
|
||||||
# just leave unless you know what you are doing
|
# just leave unless you know what you are doing
|
||||||
# also supports "dadaptation" but set lr to 1 if you use that,
|
# also supports "dadaptation" but set lr to 1 if you use that,
|
||||||
# but it learns too fast and I don't recommend it
|
# but it learns too fast and I don't recommend it
|
||||||
@@ -64,6 +63,7 @@ config:
|
|||||||
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
# I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
|
||||||
# although, the way we train sliders is comparative, so it probably won't work anyway
|
# although, the way we train sliders is comparative, so it probably won't work anyway
|
||||||
noise_offset: 0.0
|
noise_offset: 0.0
|
||||||
|
# noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
|
||||||
|
|
||||||
# the model to train the LoRA network on
|
# the model to train the LoRA network on
|
||||||
model:
|
model:
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ from torch.utils.data import ConcatDataset, DataLoader
|
|||||||
from toolkit.data_loader import PairedImageDataset
|
from toolkit.data_loader import PairedImageDataset
|
||||||
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
from toolkit.stable_diffusion_model import StableDiffusion, PromptEmbeds
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||||
import gc
|
import gc
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
import torch
|
import torch
|
||||||
from jobs.process import BaseSDTrainProcess
|
from jobs.process import BaseSDTrainProcess
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
def flush():
|
def flush():
|
||||||
@@ -41,6 +42,7 @@ class DatasetConfig:
|
|||||||
class ReferenceSliderConfig:
|
class ReferenceSliderConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
self.additional_losses: List[str] = kwargs.get('additional_losses', [])
|
||||||
|
self.weight_jitter: float = kwargs.get('weight_jitter', 0.0)
|
||||||
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
self.datasets: List[DatasetConfig] = [DatasetConfig(**d) for d in kwargs.get('datasets', [])]
|
||||||
|
|
||||||
|
|
||||||
@@ -98,10 +100,19 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
imgs, prompts, network_weights = batch
|
imgs, prompts, network_weights = batch
|
||||||
network_pos_weight, network_neg_weight = network_weights
|
network_pos_weight, network_neg_weight = network_weights
|
||||||
|
|
||||||
if isinstance(network_pos_weight, torch.Tensor):
|
if isinstance(network_pos_weight, torch.Tensor):
|
||||||
network_pos_weight = network_pos_weight.item()
|
network_pos_weight = network_pos_weight.item()
|
||||||
if isinstance(network_neg_weight, torch.Tensor):
|
if isinstance(network_neg_weight, torch.Tensor):
|
||||||
network_neg_weight = network_neg_weight.item()
|
network_neg_weight = network_neg_weight.item()
|
||||||
|
|
||||||
|
# get an array of random floats between -weight_jitter and weight_jitter
|
||||||
|
weight_jitter = self.slider_config.weight_jitter
|
||||||
|
if weight_jitter > 0.0:
|
||||||
|
jitter_list = random.uniform(-weight_jitter, weight_jitter)
|
||||||
|
network_pos_weight += jitter_list
|
||||||
|
network_neg_weight += jitter_list
|
||||||
|
|
||||||
# if items in network_weight list are tensors, convert them to floats
|
# if items in network_weight list are tensors, convert them to floats
|
||||||
|
|
||||||
dtype = get_torch_dtype(self.train_config.dtype)
|
dtype = get_torch_dtype(self.train_config.dtype)
|
||||||
@@ -211,6 +222,9 @@ class ImageReferenceSliderTrainerProcess(BaseSDTrainProcess):
|
|||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
# todo add snr gamma here
|
# todo add snr gamma here
|
||||||
|
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||||
|
# add min_snr_gamma
|
||||||
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, self.train_config.min_snr_gamma)
|
||||||
|
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
loss_slide_float = loss.item()
|
loss_slide_float = loss.item()
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safete
|
|||||||
add_base_model_info_to_meta
|
add_base_model_info_to_meta
|
||||||
from toolkit.stable_diffusion_model import StableDiffusion
|
from toolkit.stable_diffusion_model import StableDiffusion
|
||||||
from toolkit.train_tools import get_torch_dtype
|
from toolkit.train_tools import get_torch_dtype
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
class GenerateConfig:
|
class GenerateConfig:
|
||||||
@@ -41,6 +42,10 @@ class GenerateConfig:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
|
raise ValueError("Prompts file does not exist, put in list if you want to use a list of prompts")
|
||||||
|
|
||||||
|
if kwargs.get('shuffle', False):
|
||||||
|
# shuffle the prompts
|
||||||
|
random.shuffle(self.prompts)
|
||||||
|
|
||||||
|
|
||||||
class GenerateProcess(BaseProcess):
|
class GenerateProcess(BaseProcess):
|
||||||
process_id: int
|
process_id: int
|
||||||
|
|||||||
@@ -1,21 +1,9 @@
|
|||||||
# ref:
|
|
||||||
# - https://github.com/p1atdev/LECO/blob/main/train_lora.py
|
|
||||||
import random
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import os
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
from safetensors.torch import save_file, load_file
|
|
||||||
import torch.utils.checkpoint as cp
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from toolkit.config_modules import SliderConfig
|
from toolkit.config_modules import SliderConfig
|
||||||
from toolkit.layers import CheckpointGradients
|
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
|
||||||
from toolkit.paths import REPOS_ROOT
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from toolkit.stable_diffusion_model import PromptEmbeds
|
|
||||||
from toolkit.train_tools import get_torch_dtype
|
|
||||||
import gc
|
import gc
|
||||||
from toolkit import train_tools
|
from toolkit import train_tools
|
||||||
from toolkit.prompt_utils import \
|
from toolkit.prompt_utils import \
|
||||||
@@ -256,9 +244,8 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
|
|
||||||
noise_scheduler.set_timesteps(1000)
|
noise_scheduler.set_timesteps(1000)
|
||||||
|
|
||||||
current_timestep = noise_scheduler.timesteps[
|
current_timestep_index = int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
||||||
int(timesteps_to * 1000 / self.train_config.max_denoising_steps)
|
current_timestep = noise_scheduler.timesteps[current_timestep_index]
|
||||||
]
|
|
||||||
|
|
||||||
# flush() # 4.2GB to 3GB on 512x512
|
# flush() # 4.2GB to 3GB on 512x512
|
||||||
|
|
||||||
@@ -401,10 +388,16 @@ class TrainSliderProcess(BaseSDTrainProcess):
|
|||||||
offset_neutral += offset
|
offset_neutral += offset
|
||||||
|
|
||||||
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
# 16.15 GB RAM for 512x512 -> 4.20GB RAM for 512x512 with new grad_checkpointing
|
||||||
loss = loss_function(
|
loss = torch.nn.functional.mse_loss(target_latents.float(), offset_neutral.float(), reduction="none")
|
||||||
target_latents,
|
loss = loss.mean([1, 2, 3])
|
||||||
offset_neutral,
|
|
||||||
) * prompt_pair_chunk.weight
|
if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
|
||||||
|
# match batch size
|
||||||
|
timesteps_index_list = [current_timestep_index for _ in range(target_latents.shape[0])]
|
||||||
|
# add min_snr_gamma
|
||||||
|
loss = apply_snr_weight(loss, timesteps_index_list, noise_scheduler, self.train_config.min_snr_gamma)
|
||||||
|
|
||||||
|
loss = loss.mean() * prompt_pair_chunk.weight
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
loss_list.append(loss.item())
|
loss_list.append(loss.item())
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ class TrainConfig:
|
|||||||
self.xformers = kwargs.get('xformers', False)
|
self.xformers = kwargs.get('xformers', False)
|
||||||
self.train_unet = kwargs.get('train_unet', True)
|
self.train_unet = kwargs.get('train_unet', True)
|
||||||
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
self.train_text_encoder = kwargs.get('train_text_encoder', True)
|
||||||
|
self.min_snr_gamma = kwargs.get('min_snr_gamma', None)
|
||||||
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
self.noise_offset = kwargs.get('noise_offset', 0.0)
|
||||||
self.optimizer_params = kwargs.get('optimizer_params', {})
|
self.optimizer_params = kwargs.get('optimizer_params', {})
|
||||||
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
self.skip_first_sample = kwargs.get('skip_first_sample', False)
|
||||||
@@ -77,6 +78,10 @@ class ModelConfig:
|
|||||||
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
self.is_v_pred: bool = kwargs.get('is_v_pred', False)
|
||||||
self.dtype: str = kwargs.get('dtype', 'float16')
|
self.dtype: str = kwargs.get('dtype', 'float16')
|
||||||
|
|
||||||
|
# only for SDXL models for now
|
||||||
|
self.use_text_encoder_1: bool = kwargs.get('use_text_encoder_1', True)
|
||||||
|
self.use_text_encoder_2: bool = kwargs.get('use_text_encoder_2', True)
|
||||||
|
|
||||||
if self.name_or_path is None:
|
if self.name_or_path is None:
|
||||||
raise ValueError('name_or_path must be specified')
|
raise ValueError('name_or_path must be specified')
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,6 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig
|
|||||||
from toolkit.metadata import get_meta_for_safetensors
|
from toolkit.metadata import get_meta_for_safetensors
|
||||||
from toolkit.paths import REPOS_ROOT
|
from toolkit.paths import REPOS_ROOT
|
||||||
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
|
||||||
|
|
||||||
sys.path.append(REPOS_ROOT)
|
|
||||||
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
|
|
||||||
from leco import train_util
|
|
||||||
import torch
|
import torch
|
||||||
from library import model_util
|
from library import model_util
|
||||||
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
from library.sdxl_model_util import convert_text_encoder_2_state_dict_to_sdxl
|
||||||
@@ -124,6 +120,9 @@ class StableDiffusion:
|
|||||||
self.is_xl = model_config.is_xl
|
self.is_xl = model_config.is_xl
|
||||||
self.is_v2 = model_config.is_v2
|
self.is_v2 = model_config.is_v2
|
||||||
|
|
||||||
|
self.use_text_encoder_1 = model_config.use_text_encoder_1
|
||||||
|
self.use_text_encoder_2 = model_config.use_text_encoder_2
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
if self.is_loaded:
|
if self.is_loaded:
|
||||||
return
|
return
|
||||||
@@ -309,6 +308,7 @@ class StableDiffusion:
|
|||||||
torch.manual_seed(gen_config.seed)
|
torch.manual_seed(gen_config.seed)
|
||||||
torch.cuda.manual_seed(gen_config.seed)
|
torch.cuda.manual_seed(gen_config.seed)
|
||||||
|
|
||||||
|
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
img = pipeline(
|
img = pipeline(
|
||||||
prompt=gen_config.prompt,
|
prompt=gen_config.prompt,
|
||||||
@@ -393,7 +393,7 @@ class StableDiffusion:
|
|||||||
dtype = latents.dtype
|
dtype = latents.dtype
|
||||||
|
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
prompt_ids = train_util.get_add_time_ids(
|
prompt_ids = train_tools.get_add_time_ids(
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
dynamic_crops=False, # look into this
|
dynamic_crops=False, # look into this
|
||||||
@@ -444,7 +444,7 @@ class StableDiffusion:
|
|||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
# todo check this with larget batches
|
# todo check this with larget batches
|
||||||
add_time_ids = train_util.concat_embeddings(
|
add_time_ids = train_tools.concat_embeddings(
|
||||||
add_time_ids, add_time_ids, int(latents.shape[0])
|
add_time_ids, add_time_ids, int(latents.shape[0])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -459,6 +459,7 @@ class StableDiffusion:
|
|||||||
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep)
|
||||||
|
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
|
# todo can we zero here the second text encoder? or match a blank string?
|
||||||
"text_embeds": text_embeddings.pooled_embeds,
|
"text_embeds": text_embeddings.pooled_embeds,
|
||||||
"time_ids": add_time_ids,
|
"time_ids": add_time_ids,
|
||||||
}
|
}
|
||||||
@@ -541,16 +542,18 @@ class StableDiffusion:
|
|||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
if self.is_xl:
|
if self.is_xl:
|
||||||
return PromptEmbeds(
|
return PromptEmbeds(
|
||||||
train_util.encode_prompts_xl(
|
train_tools.encode_prompts_xl(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
self.text_encoder,
|
self.text_encoder,
|
||||||
prompt,
|
prompt,
|
||||||
num_images_per_prompt=num_images_per_prompt,
|
num_images_per_prompt=num_images_per_prompt,
|
||||||
|
use_text_encoder_1=self.use_text_encoder_1,
|
||||||
|
use_text_encoder_2=self.use_text_encoder_2,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return PromptEmbeds(
|
return PromptEmbeds(
|
||||||
train_util.encode_prompts(
|
train_tools.encode_prompts(
|
||||||
self.tokenizer, self.text_encoder, prompt
|
self.tokenizer, self.text_encoder, prompt
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Union
|
||||||
import sys
|
import sys
|
||||||
from toolkit.paths import SD_SCRIPTS_ROOT
|
from toolkit.paths import SD_SCRIPTS_ROOT
|
||||||
|
|
||||||
@@ -32,6 +32,10 @@ SCHEDULER_LINEAR_END = 0.0120
|
|||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||||
|
|
||||||
|
UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
|
||||||
|
TEXT_ENCODER_2_PROJECTION_DIM = 1280
|
||||||
|
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
|
||||||
|
|
||||||
|
|
||||||
def get_torch_dtype(dtype_str):
|
def get_torch_dtype(dtype_str):
|
||||||
# if it is a torch dtype, return it
|
# if it is a torch dtype, return it
|
||||||
@@ -433,3 +437,183 @@ def addnet_hash_legacy(b):
|
|||||||
b.seek(0x100000)
|
b.seek(0x100000)
|
||||||
m.update(b.read(0x10000))
|
m.update(b.read(0x10000))
|
||||||
return m.hexdigest()[0:8]
|
return m.hexdigest()[0:8]
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
||||||
|
|
||||||
|
|
||||||
|
def text_tokenize(
        tokenizer: 'CLIPTokenizer',  # one for a regular model, two for XL!
        prompts: list[str],
):
    """Tokenize ``prompts`` and return the ``input_ids`` of the tokenizer output.

    The tokenizer is invoked with padding to ``tokenizer.model_max_length``,
    truncation enabled and ``return_tensors="pt"``.

    :param tokenizer: callable tokenizer exposing ``model_max_length``
    :param prompts: prompts to tokenize
    :return: the ``input_ids`` attribute of whatever the tokenizer returns
    """
    tokenizer_options = dict(
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    encoded = tokenizer(prompts, **tokenizer_options)
    return encoded.input_ids
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
|
||||||
|
def text_encode_xl(
        text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
        tokens: torch.FloatTensor,
        num_images_per_prompt: int = 1,
):
    """Encode ``tokens`` with one SDXL text encoder.

    The tokens are moved to ``text_encoder.device`` and the encoder is called with
    ``output_hidden_states=True``.

    :param text_encoder: encoder to run
    :param tokens: token ids produced by the matching tokenizer
    :param num_images_per_prompt: how many times each prompt embedding is duplicated
    :return: ``(prompt_embeds, pooled_prompt_embeds)`` where ``prompt_embeds`` is the
        penultimate hidden state repeated ``num_images_per_prompt`` times per prompt,
        and ``pooled_prompt_embeds`` is element 0 of the encoder output (not repeated here)
    """
    encoder_output = text_encoder(
        tokens.to(text_encoder.device), output_hidden_states=True
    )
    pooled_prompt_embeds = encoder_output[0]
    penultimate = encoder_output.hidden_states[-2]  # always penultimate layer

    # duplicate along the sequence axis, then fold the copies into the batch axis so
    # the copies of a prompt end up next to each other
    batch_size, seq_len, _ = penultimate.shape
    tiled = penultimate.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = tiled.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def encode_prompts_xl(
        tokenizers: list['CLIPTokenizer'],
        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
        prompts: list[str],
        num_images_per_prompt: int = 1,
        use_text_encoder_1: bool = True,  # sdxl
        use_text_encoder_2: bool = True  # sdxl
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    """Encode ``prompts`` with every (tokenizer, text encoder) pair and join the results.

    :param tokenizers: tokenizers, paired positionally with ``text_encoders``
    :param text_encoders: text encoders, paired positionally with ``tokenizers``
    :param prompts: prompts to encode
    :param num_images_per_prompt: how many times each prompt embedding is duplicated
    :param use_text_encoder_1: when False, the first encoder is fed blank prompts
    :param use_text_encoder_2: when False, the second encoder is fed blank prompts
    :return: per-encoder embeddings concatenated on the last dim, and the pooled
        embeds of the last encoder in the list repeated ``num_images_per_prompt`` times
    """
    # index -> whether the encoder at that position should see the real prompts
    encoder_enabled = (use_text_encoder_1, use_text_encoder_2)

    # text_encoder and text_encoder_2's penultimate layer's output
    per_encoder_embeds = []
    pooled_text_embeds = None  # overwritten every iteration, so the last encoder's pool wins

    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
        # todo, we are using a blank string to ignore that encoder for now.
        # find a better way to do this (zeroing?, removing it from the unet?)
        is_disabled = idx < len(encoder_enabled) and not encoder_enabled[idx]
        encoder_prompts = ["" for _ in prompts] if is_disabled else prompts

        token_ids = text_tokenize(tokenizer, encoder_prompts)
        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, token_ids, num_images_per_prompt
        )
        per_encoder_embeds.append(text_embeds)

    pooled_batch = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        pooled_batch * num_images_per_prompt, -1
    )

    return torch.concat(per_encoder_embeds, dim=-1), pooled_text_embeds
|
||||||
|
|
||||||
|
|
||||||
|
def text_encode(text_encoder: 'CLIPTextModel', tokens):
    """Run ``tokens`` through ``text_encoder`` and return element 0 of its output.

    The tokens are moved to ``text_encoder.device`` before the call.
    """
    device_tokens = tokens.to(text_encoder.device)
    encoder_output = text_encoder(device_tokens)
    return encoder_output[0]
|
||||||
|
|
||||||
|
|
||||||
|
def encode_prompts(
        tokenizer: 'CLIPTokenizer',
        text_encoder: 'CLIPTextModel',
        prompts: list[str],
):
    """Tokenize ``prompts`` and encode them with a single text encoder.

    :param tokenizer: tokenizer handed to ``text_tokenize``
    :param text_encoder: encoder handed to ``text_encode`` (annotated as the text model,
        matching ``text_encode``; it was previously annotated as a tokenizer)
    :param prompts: prompts to encode
    :return: whatever ``text_encode`` returns for the tokenized prompts
    """
    text_tokens = text_tokenize(tokenizer, prompts)
    text_embeddings = text_encode(text_encoder, text_tokens)

    return text_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# for XL
|
||||||
|
def get_add_time_ids(
        height: int,
        width: int,
        dynamic_crops: bool = False,
        dtype: torch.dtype = torch.float32,
):
    """Build the SDXL ``add_time_ids`` tensor for an image of ``height`` x ``width``.

    The ids are ``original_size + crops_coords_top_left + target_size`` (six values).

    :param height: target image height
    :param width: target image width
    :param dynamic_crops: when True, pretend the image was cropped from a randomly
        larger original (scale in [1, 3)) at a random top-left position; when False
        the original size equals the target size and the crop offset is (0, 0)
    :param dtype: dtype of the returned tensor
    :return: tensor of shape ``(1, 6)``
    :raises ValueError: if the embedding width implied by the module constants does
        not match ``UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM``
    """
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # random position. torch.randint's upper bound is exclusive and must be greater
        # than the lower bound; a scale close to 1 can leave no room to crop, so clamp
        # the bound to 1 (which yields offset 0) instead of raising.
        max_top = max(original_size[0] - height, 1)
        max_left = max(original_size[1] - width, 1)
        crops_coords_top_left = (
            torch.randint(0, max_top, (1,)).item(),
            torch.randint(0, max_left, (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    # this is expected as 6
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # this is expected as 2816
    passed_add_embed_dim = (
            UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
            + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids
|
||||||
|
|
||||||
|
|
||||||
|
def concat_embeddings(
        unconditional: torch.FloatTensor,
        conditional: torch.FloatTensor,
        n_imgs: int,
):
    """Stack unconditional and conditional embeddings and repeat each row ``n_imgs`` times.

    :param unconditional: embeddings placed first in the result
    :param conditional: embeddings placed after the unconditional ones
    :param n_imgs: number of consecutive copies of every row along dim 0
    :return: the concatenated tensor with every row repeated in place
    """
    stacked = torch.cat((unconditional, conditional))
    return torch.repeat_interleave(stacked, n_imgs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def add_all_snr_to_noise_scheduler(noise_scheduler, device):
    """Attach an ``all_snr`` tensor to ``noise_scheduler`` unless it already has one.

    The value is ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2``, computed
    without gradients and moved to ``device``.

    :param noise_scheduler: object exposing ``alphas_cumprod``; gains ``all_snr``
    :param device: device the stored tensor is moved to
    :return: None
    """
    if not hasattr(noise_scheduler, "all_snr"):
        # compute it
        with torch.no_grad():
            cumprod = noise_scheduler.alphas_cumprod
            signal = torch.sqrt(cumprod)
            noise = torch.sqrt(1.0 - cumprod)
            snr_table = (signal / noise).pow(2)
            snr_table.requires_grad = False
            noise_scheduler.all_snr = snr_table.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_snr(noise_scheduler, device):
    """Return the scheduler's SNR table on ``device``.

    If ``noise_scheduler`` already has an ``all_snr`` attribute it is used as is;
    otherwise the table is computed from ``alphas_cumprod`` as
    ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2`` without gradients.
    The computed table is not stored back on the scheduler.

    :param noise_scheduler: object exposing ``all_snr`` or ``alphas_cumprod``
    :param device: device the returned tensor is moved to
    :return: the SNR tensor on ``device``
    """
    if hasattr(noise_scheduler, "all_snr"):
        snr_table = noise_scheduler.all_snr
    else:
        # compute it
        with torch.no_grad():
            cumprod = noise_scheduler.alphas_cumprod
            signal = torch.sqrt(cumprod)
            noise = torch.sqrt(1.0 - cumprod)
            snr_table = (signal / noise).pow(2)
            snr_table.requires_grad = False
    return snr_table.to(device)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler: Union['DDPMScheduler'],
        gamma
):
    """Scale ``loss`` by the min-SNR weight ``min(gamma / snr, 1)`` for each timestep.

    :param loss: loss tensor, multiplied by the stacked per-timestep weights
    :param timesteps: iterable of indices into the scheduler's SNR table
    :param noise_scheduler: scheduler handed to ``get_all_snr``
    :param gamma: min-SNR gamma value
    :return: the weighted loss
    """
    # taken from the noise scheduler if it already has it, otherwise calculated
    snr_table = get_all_snr(noise_scheduler, loss.device)

    per_step_snr = torch.stack([snr_table[step] for step in timesteps])
    gamma_ratio = torch.div(torch.ones_like(per_step_snr) * gamma, per_step_snr)
    ceiling = torch.ones_like(gamma_ratio)
    snr_weight = torch.minimum(gamma_ratio, ceiling).float().to(loss.device)  # from paper
    return loss * snr_weight
|
||||||
|
|||||||
Reference in New Issue
Block a user