import argparse
import hashlib
import json
import os
import re
import sys
import time
from typing import TYPE_CHECKING, Union

import torch
from torch.cuda.amp import GradScaler

from toolkit.paths import SD_SCRIPTS_ROOT

sys.path.append(SD_SCRIPTS_ROOT)

from diffusers import (
    StableDiffusionPipeline,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    DDIMScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2DiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

SCHEDULER_LINEAR_START = 0.00085
SCHEDULER_LINEAR_END = 0.0120
SCHEDULER_TIMESTEPS = 1000
SCHEDULER_SCHEDULE = "scaled_linear"

UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
TEXT_ENCODER_2_PROJECTION_DIM = 1280
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816


def get_torch_dtype(dtype_str):
    # if it is already a torch dtype, return it
    if isinstance(dtype_str, torch.dtype):
        return dtype_str
    if dtype_str in ("float", "fp32", "single", "float32"):
        return torch.float
    if dtype_str in ("fp16", "half", "float16"):
        return torch.float16
    if dtype_str in ("bf16", "bfloat16"):
        return torch.bfloat16
    # unrecognized values are returned as-is
    return dtype_str
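

# Illustrative usage sketch (not part of the original module; the helper name
# `_example_get_torch_dtype` is hypothetical): get_torch_dtype accepts either a
# torch.dtype or a string alias and returns unknown values unchanged.
def _example_get_torch_dtype():
    assert get_torch_dtype("fp16") is torch.float16
    assert get_torch_dtype("bf16") is torch.bfloat16
    assert get_torch_dtype(torch.float32) is torch.float32
    # unrecognized strings fall through for the caller to handle
    assert get_torch_dtype("not-a-dtype") == "not-a-dtype"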


def replace_filewords_prompt(prompt, args: argparse.Namespace):
    # args may not have a name_replace attribute
    if hasattr(args, "name_replace") and args.name_replace is not None:
        # replace [name] with args.name_replace
        prompt = prompt.replace("[name]", args.name_replace)
    if hasattr(args, "prepend") and args.prepend is not None:
        # prepend to every item in the prompt file
        prompt = args.prepend + ' ' + prompt
    if hasattr(args, "append") and args.append is not None:
        # append to every item in the prompt file
        prompt = prompt + ' ' + args.append
    return prompt


def replace_filewords_in_dataset_group(dataset_group, args: argparse.Namespace):
    # args may not have a name_replace attribute
    if hasattr(args, "name_replace") and args.name_replace is not None:
        if len(dataset_group.image_data) == 0:
            raise ValueError("dataset_group.image_data is empty")
        for key in dataset_group.image_data:
            dataset_group.image_data[key].caption = dataset_group.image_data[key].caption.replace(
                "[name]", args.name_replace)

    return dataset_group


def get_seeds_from_latents(latents):
    # latents shape = (batch_size, 4, height, width)
    # for speed we only use an 8x8 slice of the first channel
    seeds = []

    # split the batch up
    for i in range(latents.shape[0]):
        # use only the first channel, multiply by 255
        tensor = latents[i, 0, :, :] * 255.0  # shape = (height, width)
        # slice 8x8
        tensor = tensor[:8, :8]
        # clip to 0-255
        tensor = torch.clamp(tensor, 0, 255)
        # convert to 8-bit int
        tensor = tensor.to(torch.uint8)
        # convert to bytes
        tensor_bytes = tensor.cpu().numpy().tobytes()
        # hash the bytes
        hash_object = hashlib.sha256(tensor_bytes)
        # get the hex digest
        hex_dig = hash_object.hexdigest()
        # reduce to a 32-bit seed
        seed = int(hex_dig, 16) % (2 ** 32)
        seeds.append(seed)
    return seeds


def get_noise_from_latents(latents):
    seed_list = get_seeds_from_latents(latents)
    noise = []
    for seed in seed_list:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        noise.append(torch.randn_like(latents[0]))
    return torch.stack(noise)
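

# Illustrative sketch (not part of the original module; `_example_deterministic_noise`
# is a hypothetical helper): because the seed is derived from the latents
# themselves, the same latents always regenerate the same noise tensor, so the
# noise can be reproduced without storing it.
def _example_deterministic_noise():
    latents = torch.randn(2, 4, 64, 64)
    noise_a = get_noise_from_latents(latents)
    noise_b = get_noise_from_latents(latents)
    assert torch.equal(noise_a, noise_b)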


# mix 0 keeps the noise mean entirely, mix 1 adopts the target mean entirely
def match_noise_to_target_mean_offset(noise, target, mix=0.5, dim=None):
    dim = dim or (1, 2, 3)
    # per-sample means over the given dims (default: every dim but the batch dim)
    noise_mean = noise.mean(dim=dim, keepdim=True)
    target_mean = target.mean(dim=dim, keepdim=True)

    new_noise_mean = mix * target_mean + (1 - mix) * noise_mean

    noise = noise - noise_mean + new_noise_mean
    return noise
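

# Illustrative sketch (not part of the original module; `_example_match_noise_mean`
# is hypothetical): with mix=1.0 the returned noise adopts the target's
# per-sample mean exactly while keeping its own structure around that mean.
def _example_match_noise_mean():
    noise = torch.randn(1, 4, 8, 8)
    target = torch.randn(1, 4, 8, 8) + 5.0
    shifted = match_noise_to_target_mean_offset(noise, target, mix=1.0)
    assert torch.allclose(
        shifted.mean(dim=(1, 2, 3)), target.mean(dim=(1, 2, 3)), atol=1e-5
    )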


def sample_images(
        accelerator,
        args: argparse.Namespace,
        epoch,
        steps,
        device,
        vae,
        tokenizer,
        text_encoder,
        unet,
        prompt_replacement=None,
        force_sample=False
):
    """
    Uses a modified StableDiffusionLongPromptWeightingPipeline, so clip skip
    and prompt weighting are supported.
    """
    if not force_sample:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # sample_every_n_steps is ignored when sampling per epoch
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            # skip when steps is not divisible, or at the end of an epoch
            if steps % args.sample_every_n_steps != 0 or epoch is not None:
                return

    is_sample_only = args.sample_only
    is_generating_only = hasattr(args, "is_generating_only") and args.is_generating_only

    print(f"\ngenerating sample images at step: {steps}")
    if not os.path.isfile(args.sample_prompts):
        print(f"No prompt file: {args.sample_prompts}")
        return

    org_vae_device = vae.device  # should be on the CPU
    vae.to(device)

    # read prompts
    if args.sample_prompts.endswith(".txt"):
        with open(args.sample_prompts, "r", encoding="utf-8") as f:
            lines = f.readlines()
        prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
    elif args.sample_prompts.endswith(".json"):
        with open(args.sample_prompts, "r", encoding="utf-8") as f:
            prompts = json.load(f)
    else:
        raise ValueError(f"Unsupported prompt file format: {args.sample_prompts}")

    # prepare the scheduler
    sched_init_args = {}
    if args.sample_sampler == "ddim":
        scheduler_cls = DDIMScheduler
    elif args.sample_sampler == "ddpm":  # ddpm misbehaves, so it is left out of the options
        scheduler_cls = DDPMScheduler
    elif args.sample_sampler == "pndm":
        scheduler_cls = PNDMScheduler
    elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
        scheduler_cls = LMSDiscreteScheduler
    elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
        scheduler_cls = EulerDiscreteScheduler
    elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
        scheduler_cls = EulerAncestralDiscreteScheduler
    elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
        sched_init_args["algorithm_type"] = args.sample_sampler
    elif args.sample_sampler == "dpmsingle":
        scheduler_cls = DPMSolverSinglestepScheduler
    elif args.sample_sampler == "heun":
        scheduler_cls = HeunDiscreteScheduler
    elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
        scheduler_cls = KDPM2DiscreteScheduler
    elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
        scheduler_cls = KDPM2AncestralDiscreteScheduler
    else:
        scheduler_cls = DDIMScheduler

    if args.v_parameterization:
        sched_init_args["prediction_type"] = "v_prediction"

    scheduler = scheduler_cls(
        num_train_timesteps=SCHEDULER_TIMESTEPS,
        beta_start=SCHEDULER_LINEAR_START,
        beta_end=SCHEDULER_LINEAR_END,
        beta_schedule=SCHEDULER_SCHEDULE,
        **sched_init_args,
    )

    # force clip_sample=True
    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
        scheduler.config.clip_sample = True

    pipeline = StableDiffusionLongPromptWeightingPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=scheduler,
        clip_skip=args.clip_skip,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker=False,
    )
    pipeline.to(device)

    if is_generating_only:
        save_dir = args.output_dir
    else:
        save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    rng_state = torch.get_rng_state()
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    with torch.no_grad():
        with accelerator.autocast():
            for i, prompt in enumerate(prompts):
                if not accelerator.is_main_process:
                    continue

                if isinstance(prompt, dict):
                    negative_prompt = prompt.get("negative_prompt")
                    sample_steps = prompt.get("sample_steps", 30)
                    width = prompt.get("width", 512)
                    height = prompt.get("height", 512)
                    scale = prompt.get("scale", 7.5)
                    seed = prompt.get("seed")
                    prompt = prompt.get("prompt")

                    prompt = replace_filewords_prompt(prompt, args)
                    if negative_prompt is not None:
                        negative_prompt = replace_filewords_prompt(negative_prompt, args)
                else:
                    prompt = replace_filewords_prompt(prompt, args)
                    # prompt = prompt.strip()
                    # if len(prompt) == 0 or prompt[0] == "#":
                    #     continue

                    # subset of gen_img_diffusers
                    prompt_args = prompt.split(" --")
                    prompt = prompt_args[0]
                    negative_prompt = None
                    sample_steps = 30
                    width = height = 512
                    scale = 7.5
                    seed = None
                    for parg in prompt_args:
                        try:
                            m = re.match(r"w (\d+)", parg, re.IGNORECASE)
                            if m:  # width
                                width = int(m.group(1))
                                continue

                            m = re.match(r"h (\d+)", parg, re.IGNORECASE)
                            if m:  # height
                                height = int(m.group(1))
                                continue

                            m = re.match(r"d (\d+)", parg, re.IGNORECASE)
                            if m:  # seed
                                seed = int(m.group(1))
                                continue

                            m = re.match(r"s (\d+)", parg, re.IGNORECASE)
                            if m:  # steps
                                sample_steps = max(1, min(1000, int(m.group(1))))
                                continue

                            m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
                            if m:  # scale
                                scale = float(m.group(1))
                                continue

                            m = re.match(r"n (.+)", parg, re.IGNORECASE)
                            if m:  # negative prompt
                                negative_prompt = m.group(1)
                                continue

                        except ValueError as ex:
                            print(f"Exception in parsing: {parg}")
                            print(ex)

                if seed is not None:
                    torch.manual_seed(seed)
                    torch.cuda.manual_seed(seed)

                if prompt_replacement is not None:
                    prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
                    if negative_prompt is not None:
                        negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

                height = max(64, height - height % 8)  # round to divisible by 8
                width = max(64, width - width % 8)  # round to divisible by 8
                print(f"prompt: {prompt}")
                print(f"negative_prompt: {negative_prompt}")
                print(f"height: {height}")
                print(f"width: {width}")
                print(f"sample_steps: {sample_steps}")
                print(f"scale: {scale}")
                image = pipeline(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=sample_steps,
                    guidance_scale=scale,
                    negative_prompt=negative_prompt,
                ).images[0]

                ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
                num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
                seed_suffix = "" if seed is None else f"_{seed}"

                if is_generating_only:
                    img_filename = (
                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
                    )
                else:
                    img_filename = (
                        f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{i:04d}{seed_suffix}.png"
                    )
                if is_sample_only:
                    # write the prompt next to the image as a .txt file
                    img_path_no_ext = os.path.join(save_dir, img_filename[:-4])
                    with open(img_path_no_ext + ".txt", "w") as f:
                        f.write(prompt)

                image.save(os.path.join(save_dir, img_filename))

                # send logs only when wandb is enabled
                try:
                    wandb_tracker = accelerator.get_tracker("wandb")
                    try:
                        import wandb
                    except ImportError:  # checked once beforehand, so this should not raise
                        raise ImportError("No wandb / wandb is not installed")

                    wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
                except Exception:  # wandb is disabled
                    pass

    # clear the pipeline and cache to reduce vram usage
    del pipeline
    torch.cuda.empty_cache()

    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)
    vae.to(org_vae_device)


# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
    if noise_offset is None or noise_offset < 0.0000001:
        return noise
    # one random offset per sample and channel, broadcast over height and width
    noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
    return noise
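

# Illustrative sketch (not part of the original module; `_example_apply_noise_offset`
# is hypothetical): offset noise shifts the mean of each sample's noise map
# without changing its high-frequency structure.
def _example_apply_noise_offset():
    noise = torch.randn(4, 4, 64, 64)
    assert apply_noise_offset(noise, 0.1).shape == noise.shape
    # None or a near-zero offset is a no-op
    assert torch.equal(apply_noise_offset(noise, None), noise)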


if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import PromptEmbeds


def concat_prompt_embeddings(
        unconditional: 'PromptEmbeds',
        conditional: 'PromptEmbeds',
        n_imgs: int,
):
    from toolkit.stable_diffusion_model import PromptEmbeds
    text_embeds = torch.cat(
        [unconditional.text_embeds, conditional.text_embeds]
    ).repeat_interleave(n_imgs, dim=0)
    pooled_embeds = None
    if unconditional.pooled_embeds is not None and conditional.pooled_embeds is not None:
        pooled_embeds = torch.cat(
            [unconditional.pooled_embeds, conditional.pooled_embeds]
        ).repeat_interleave(n_imgs, dim=0)
    return PromptEmbeds([text_embeds, pooled_embeds])


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors files."""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    b.seek(0)
    header = b.read(8)
    # the first 8 bytes are the little-endian length of the JSON header
    n = int.from_bytes(header, "little")

    # skip the header and hash only the tensor payload
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
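

# Illustrative usage sketch (not part of the original module; the path and
# helper name are placeholders): the file must be opened in binary mode. Since
# only the tensor payload after the JSON header is hashed, metadata-only edits
# do not change the hash.
def _example_addnet_hash(path="model.safetensors"):
    with open(path, "rb") as f:
        return addnet_hash_safetensors(f)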


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors files."""
    m = hashlib.sha256()

    # hash a fixed 0x10000-byte window starting at offset 0x100000
    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


if TYPE_CHECKING:
    from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection


def text_tokenize(
        tokenizer: 'CLIPTokenizer',
        prompts: list[str],
        truncate: bool = True,
        max_length: int = None,
        max_length_multiplier: int = 4,
):
    if max_length is None:
        if truncate:
            max_length = tokenizer.model_max_length
        else:
            # allow up to 4x the max length for long prompts
            max_length = tokenizer.model_max_length * max_length_multiplier

    input_ids = tokenizer(
        prompts,
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    if truncate or max_length == tokenizer.model_max_length:
        return input_ids
    else:
        # remove additional padding
        num_chunks = input_ids.shape[1] // tokenizer.model_max_length
        chunks = torch.chunk(input_ids, chunks=num_chunks, dim=1)

        # new list to store the non-redundant chunks
        non_redundant_chunks = []

        for chunk in chunks:
            # keep the chunk unless every element equals its first element (i.e. pure padding)
            if not chunk.eq(chunk[0, 0]).all():
                non_redundant_chunks.append(chunk)

        input_ids = torch.cat(non_redundant_chunks, dim=1)
        return input_ids
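

# Illustrative sketch (not part of the original module; the checkpoint name is
# an assumption): with truncate=False a prompt is padded out to a multiple of
# the tokenizer window and all-padding windows are dropped afterwards, so the
# returned length is always a multiple of tokenizer.model_max_length.
def _example_text_tokenize():
    from transformers import CLIPTokenizer  # requires transformers
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    ids = text_tokenize(tokenizer, ["a photo of a cat"], truncate=False)
    assert ids.shape[1] % tokenizer.model_max_length == 0
    return ids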


# https://github.com/huggingface/diffusers/blob/78922ed7c7e66c20aa95159c7b7a6057ba7d590d/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L334-L348
def text_encode_xl(
        text_encoder: Union['CLIPTextModel', 'CLIPTextModelWithProjection'],
        tokens: torch.FloatTensor,
        num_images_per_prompt: int = 1,
        max_length: int = 77,  # not sure what default to put here, always pass one?
        truncate: bool = True,
):
    if truncate:
        # normal short prompt, 77 tokens max
        prompt_embeds = text_encoder(
            tokens.to(text_encoder.device), output_hidden_states=True
        )
        pooled_prompt_embeds = prompt_embeds[0]
        prompt_embeds = prompt_embeds.hidden_states[-2]  # always the penultimate layer
    else:
        # handle long prompts by encoding max_length-sized sections
        prompt_embeds_list = []
        tokens = tokens.to(text_encoder.device)
        pooled_prompt_embeds = None
        for i in range(0, tokens.shape[-1], max_length):
            # TODO: run the sections through the encoder in a single batch
            section_tokens = tokens[:, i: i + max_length]
            embeds = text_encoder(section_tokens, output_hidden_states=True)
            pooled_prompt_embed = embeds[0]
            if pooled_prompt_embeds is None:
                # we only want the first section's pooled output (I think)
                pooled_prompt_embeds = pooled_prompt_embed
            prompt_embed = embeds.hidden_states[-2]  # always the penultimate layer
            prompt_embeds_list.append(prompt_embed)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=1)

    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, pooled_prompt_embeds


def encode_prompts_xl(
        tokenizers: list['CLIPTokenizer'],
        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection']],
        prompts: list[str],
        prompts2: Union[list[str], None],
        num_images_per_prompt: int = 1,
        use_text_encoder_1: bool = True,  # sdxl
        use_text_encoder_2: bool = True,  # sdxl
        truncate: bool = True,
        max_length=None,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    # concatenation of text_encoder and text_encoder_2's penultimate layer outputs
    text_embeds_list = []
    pooled_text_embeds = None  # always text_encoder_2's pooled output
    if prompts2 is None:
        prompts2 = prompts

    for idx, (tokenizer, text_encoder) in enumerate(zip(tokenizers, text_encoders)):
        # todo: we are using a blank string to ignore an encoder for now;
        # find a better way to do this (zeroing?, removing it from the unet?)
        prompt_list_to_use = prompts if idx == 0 else prompts2
        if idx == 0 and not use_text_encoder_1:
            prompt_list_to_use = ["" for _ in prompts]
        if idx == 1 and not use_text_encoder_2:
            prompt_list_to_use = ["" for _ in prompts]

        text_tokens_input_ids = text_tokenize(tokenizer, prompt_list_to_use, truncate=truncate, max_length=max_length)
        # set the max length for the next encoder
        if idx == 0:
            max_length = text_tokens_input_ids.shape[-1]

        text_embeds, pooled_text_embeds = text_encode_xl(
            text_encoder, text_tokens_input_ids, num_images_per_prompt, max_length=tokenizer.model_max_length,
            truncate=truncate
        )

        text_embeds_list.append(text_embeds)

    bs_embed = pooled_text_embeds.shape[0]
    pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )

    return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds


# ref for long prompts: https://github.com/huggingface/diffusers/issues/2136
def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
    if max_length is None and not truncate:
        raise ValueError("max_length must be set if truncate is False")

    tokens = tokens.to(text_encoder.device)

    if truncate:
        return text_encoder(tokens)[0]
    else:
        # handle long prompts by encoding max_length-sized sections
        prompt_embeds_list = []
        for i in range(0, tokens.shape[-1], max_length):
            prompt_embeds = text_encoder(tokens[:, i: i + max_length])[0]
            prompt_embeds_list.append(prompt_embeds)

        return torch.cat(prompt_embeds_list, dim=1)


def encode_prompts(
        tokenizer: 'CLIPTokenizer',
        text_encoder: 'CLIPTextModel',
        prompts: list[str],
        truncate: bool = True,
        max_length=None,
):
    if max_length is None:
        max_length = tokenizer.model_max_length
    text_tokens = text_tokenize(tokenizer, prompts, truncate=truncate, max_length=max_length)
    text_embeddings = text_encode(text_encoder, text_tokens, truncate=truncate, max_length=max_length)

    return text_embeddings
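

# Illustrative usage sketch (not part of the original module; the checkpoint
# name is an assumption): encode_prompts chains text_tokenize and text_encode;
# with the default truncation the result is (batch, 77, hidden) for CLIP-L.
def _example_encode_prompts():
    from transformers import CLIPTextModel, CLIPTokenizer  # requires transformers
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    embeds = encode_prompts(tokenizer, text_encoder, ["a photo of a cat"])
    assert embeds.shape[:2] == (1, tokenizer.model_max_length)
    return embeds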


# for XL
def get_add_time_ids(
        height: int,
        width: int,
        dynamic_crops: bool = False,
        dtype: torch.dtype = torch.float32,
):
    if dynamic_crops:
        # random float scale between 1 and 3
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        # random crop position (guard against a zero range when the scale is ~1)
        crops_coords_top_left = (
            torch.randint(0, max(1, original_size[0] - height), (1,)).item(),
            torch.randint(0, max(1, original_size[1] - width), (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)

    # expected length: 6
    add_time_ids = list(original_size + crops_coords_top_left + target_size)

    # expected total: 2816
    passed_add_embed_dim = (
        UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
        + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
    )
    if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
        raise ValueError(
            f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, "
            f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check "
            f"`unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )

    add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
    return add_time_ids
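

# Illustrative sketch (not part of the original module; `_example_add_time_ids`
# is hypothetical): for a fixed-crop 1024x1024 sample the six entries are
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w); the SDXL UNet
# embeds each as 256 dims and adds the 1280-dim pooled projection to reach 2816.
def _example_add_time_ids():
    ids = get_add_time_ids(1024, 1024, dynamic_crops=False)
    assert ids.shape == (1, 6)
    assert ids.tolist() == [[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]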


def concat_embeddings(
        unconditional: torch.FloatTensor,
        conditional: torch.FloatTensor,
        n_imgs: int,
):
    return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
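

# Illustrative sketch (not part of the original module; `_example_concat_embeddings`
# is hypothetical): classifier-free guidance batching puts the unconditional
# embeddings first, then the conditional ones, each repeated once per image.
def _example_concat_embeddings():
    uncond = torch.zeros(1, 77, 768)
    cond = torch.ones(1, 77, 768)
    batch = concat_embeddings(uncond, cond, n_imgs=2)
    # rows 0-1 are unconditional, rows 2-3 conditional
    assert batch.shape == (4, 77, 768)
    assert torch.equal(batch[0], batch[1])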


def add_all_snr_to_noise_scheduler(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return
    # compute it and cache it on the scheduler
    noise_scheduler.all_snr = get_all_snr(noise_scheduler, device)


def get_all_snr(noise_scheduler, device):
    if hasattr(noise_scheduler, "all_snr"):
        return noise_scheduler.all_snr.to(device)
    # compute it: SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    with torch.no_grad():
        alphas_cumprod = noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        alpha = sqrt_alphas_cumprod
        sigma = sqrt_one_minus_alphas_cumprod
        all_snr = (alpha / sigma) ** 2
        all_snr.requires_grad = False
    return all_snr.to(device)


class LearnableSNRGamma:
    """
    A trainer for a learnable SNR gamma. It adapts to the dataset and adjusts
    the SNR multiplier to balance the loss over the timesteps.
    """

    def __init__(self, noise_scheduler: Union['DDPMScheduler'], device='cuda'):
        self.device = device
        self.noise_scheduler: Union['DDPMScheduler'] = noise_scheduler
        self.offset_1 = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=device))
        self.offset_2 = torch.nn.Parameter(torch.tensor(0.777, dtype=torch.float32, device=device))
        self.scale = torch.nn.Parameter(torch.tensor(4.14, dtype=torch.float32, device=device))
        self.gamma = torch.nn.Parameter(torch.tensor(2.03, dtype=torch.float32, device=device))
        self.optimizer = torch.optim.AdamW([self.offset_1, self.offset_2, self.gamma, self.scale], lr=0.01)
        self.buffer = []
        self.max_buffer_size = 20

    def forward(self, loss, timesteps):
        # run our own train loop for the learnable snr gamma and return the values detached
        loss = loss.detach()
        with torch.no_grad():
            loss_chunks = torch.chunk(loss, loss.shape[0], dim=0)
            for loss_chunk in loss_chunks:
                self.buffer.append(loss_chunk.mean().detach())
                if len(self.buffer) > self.max_buffer_size:
                    self.buffer.pop(0)
        all_snr = get_all_snr(self.noise_scheduler, loss.device)
        snr: torch.Tensor = torch.stack([all_snr[t] for t in timesteps]).detach().float().to(loss.device)
        base_snrs = snr.clone().detach()
        snr.requires_grad = True
        snr = (snr + self.offset_1) * self.scale + self.offset_2

        gamma_over_snr = torch.div(torch.ones_like(snr) * self.gamma, snr)
        snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
        snr_adjusted_loss = loss * snr_weight
        with torch.no_grad():
            # target: running mean of recent per-sample losses
            target = torch.mean(torch.stack(self.buffer)).detach()

        # local_loss = torch.mean(torch.abs(snr_adjusted_loss - target))
        squared_differences = (snr_adjusted_loss - target) ** 2
        local_loss = torch.mean(squared_differences)
        local_loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        return base_snrs, self.gamma.detach(), self.offset_1.detach(), self.offset_2.detach(), self.scale.detach()


def apply_learnable_snr_gos(
        loss,
        timesteps,
        learnable_snr_trainer: LearnableSNRGamma
):
    snr, gamma, offset_1, offset_2, scale = learnable_snr_trainer.forward(loss, timesteps)

    snr = (snr + offset_1) * scale + offset_2

    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    snr_weight = torch.abs(gamma_over_snr).float().to(loss.device)  # directly using gamma over snr
    snr_adjusted_loss = loss * snr_weight

    return snr_adjusted_loss
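

# Illustrative usage sketch (not part of the original module; values are
# placeholders): the trainer is created once per run and fed the raw
# per-sample loss each step; it updates its own parameters internally.
def _example_learnable_snr():
    scheduler = DDPMScheduler(num_train_timesteps=SCHEDULER_TIMESTEPS)
    trainer = LearnableSNRGamma(scheduler, device='cpu')
    loss = torch.rand(4)  # per-sample loss, reduced over all but the batch dim
    timesteps = torch.randint(0, SCHEDULER_TIMESTEPS, (4,))
    weighted = apply_learnable_snr_gos(loss, timesteps, trainer)
    assert weighted.shape == loss.shape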


def apply_snr_weight(
        loss,
        timesteps,
        noise_scheduler: Union['DDPMScheduler'],
        gamma,
        fixed=False,
):
    # gets the cached SNR table from the noise scheduler if it exists, or computes it if not
    all_snr = get_all_snr(noise_scheduler, loss.device)

    snr = torch.stack([all_snr[t] for t in timesteps])
    gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
    if fixed:
        snr_weight = gamma_over_snr.float().to(loss.device)  # directly using gamma over snr
    else:
        # min-SNR-gamma: cap the weight at 1
        snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device)
    snr_adjusted_loss = loss * snr_weight

    return snr_adjusted_loss
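

# Illustrative sketch (not part of the original module; `_example_apply_snr_weight`
# is hypothetical): with fixed=False this is min-SNR-gamma weighting
# (https://arxiv.org/abs/2303.09556), capping the weight at min(gamma / SNR(t), 1)
# to down-weight easy low-noise timesteps.
def _example_apply_snr_weight():
    scheduler = DDPMScheduler(num_train_timesteps=SCHEDULER_TIMESTEPS)
    loss = torch.rand(4)
    timesteps = torch.randint(0, SCHEDULER_TIMESTEPS, (4,))
    weighted = apply_snr_weight(loss, timesteps, scheduler, gamma=5.0)
    assert weighted.shape == loss.shape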