# mirror of https://github.com/ostris/ai-toolkit.git
# synced 2026-01-26 16:39:47 +00:00
import copy
import gc
import json
import random
import shutil
import typing
from typing import Union, List, Literal, Iterator
import sys
import os
from collections import OrderedDict
import yaml
from PIL import Image
from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
from safetensors.torch import save_file, load_file
from torch import autocast
from torch.nn import Parameter
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm
from torchvision.transforms import Resize, transforms

from toolkit.assistant_lora import load_assistant_lora_from_path
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
    convert_vae_state_dict, load_vae
from toolkit import train_tools
from toolkit.config_modules import ModelConfig, GenerateImageConfig
from toolkit.metadata import get_meta_for_safetensors
from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT
from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from einops import rearrange, repeat
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, PixArtTransformer2DModel, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, \
    StableDiffusion3Img2ImgPipeline, PixArtSigmaPipeline, AuraFlowPipeline, AuraFlowTransformer2DModel, FluxPipeline, \
    FluxTransformer2DModel, FlowMatchEulerDiscreteScheduler
import diffusers
from diffusers import \
    AutoencoderKL, \
    UNet2DConditionModel
from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler, PixArtSigmaPipeline
from transformers import T5EncoderModel, BitsAndBytesConfig, UMT5EncoderModel, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
from huggingface_hub import hf_hub_download

from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork

# tell it to shut up
diffusers.logging.set_verbosity(diffusers.logging.ERROR)

SD_PREFIX_VAE = "vae"
SD_PREFIX_UNET = "unet"
SD_PREFIX_REFINER_UNET = "refiner_unet"
SD_PREFIX_TEXT_ENCODER = "te"

SD_PREFIX_TEXT_ENCODER1 = "te0"
SD_PREFIX_TEXT_ENCODER2 = "te1"

# prefixed diffusers keys
DO_NOT_TRAIN_WEIGHTS = [
    "unet_time_embedding.linear_1.bias",
    "unet_time_embedding.linear_1.weight",
    "unet_time_embedding.linear_2.bias",
    "unet_time_embedding.linear_2.weight",
    "refiner_unet_time_embedding.linear_1.bias",
    "refiner_unet_time_embedding.linear_1.weight",
    "refiner_unet_time_embedding.linear_2.bias",
    "refiner_unet_time_embedding.linear_2.weight",
]

DeviceStatePreset = Literal['cache_latents', 'generate']


class BlankNetwork:

    def __init__(self):
        self.multiplier = 1.0
        self.is_active = True
        self.is_merged_in = False
        self.can_merge_in = False

    def __enter__(self):
        self.is_active = True

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.is_active = False


def flush():
    torch.cuda.empty_cache()
    gc.collect()


UNET_IN_CHANNELS = 4  # Stable Diffusion's in_channels is fixed at 4. XL is the same.
# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8


class StableDiffusion:

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='fp16',
            custom_pipeline=None,
            noise_scheduler=None,
            quantize_device=None,
    ):
        self.custom_pipeline = custom_pipeline
        self.device = device
        self.dtype = dtype
        self.torch_dtype = get_torch_dtype(dtype)
        self.device_torch = torch.device(self.device)

        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
            model_config.vae_device)
        self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)

        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
            model_config.te_device)
        self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)

        self.model_config = model_config
        self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon"

        self.device_state = None

        self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline']
        self.vae: Union[None, 'AutoencoderKL']
        self.unet: Union[None, 'UNet2DConditionModel']
        self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]]
        self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']]
        self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler

        self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None

        # sdxl stuff
        self.logit_scale = None
        self.ckppt_info = None
        self.is_loaded = False

        # to hold network if there is one
        self.network = None
        self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None
        self.is_xl = model_config.is_xl
        self.is_v2 = model_config.is_v2
        self.is_ssd = model_config.is_ssd
        self.is_v3 = model_config.is_v3
        self.is_vega = model_config.is_vega
        self.is_pixart = model_config.is_pixart
        self.is_auraflow = model_config.is_auraflow
        self.is_flux = model_config.is_flux

        self.use_text_encoder_1 = model_config.use_text_encoder_1
        self.use_text_encoder_2 = model_config.use_text_encoder_2

        self.config_file = None

        self.is_flow_matching = False
        if self.is_flux or self.is_v3 or self.is_auraflow:
            self.is_flow_matching = True

        self.quantize_device = quantize_device if quantize_device is not None else self.device
        self.low_vram = self.model_config.low_vram

        # the assistant lora gets merged in and previewed with a -1 weight
        self.invert_assistant_lora = False

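    # A minimal usage sketch (illustrative only; ModelConfig constructor
    # keywords beyond what the code reads back, name_or_path and is_xl, are
    # assumptions about the config module):
    #
    #   config = ModelConfig(name_or_path="stabilityai/stable-diffusion-xl-base-1.0", is_xl=True)
    #   sd = StableDiffusion(device="cuda:0", model_config=config, dtype="fp16")
    #   sd.load_model()
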
    def load_model(self):
        if self.is_loaded:
            return
        dtype = get_torch_dtype(self.dtype)

        # move the betas, alphas, and alphas_cumprod to device. Sometimes they get stuck on cpu, not sure why
        # self.noise_scheduler.betas = self.noise_scheduler.betas.to(self.device_torch)
        # self.noise_scheduler.alphas = self.noise_scheduler.alphas.to(self.device_torch)
        # self.noise_scheduler.alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(self.device_torch)

        model_path = self.model_config.name_or_path
        if 'civitai.com' in self.model_config.name_or_path:
            # this is a civitai model, use the civitai loader
            from toolkit.civitai import get_model_path_from_url
            model_path = get_model_path_from_url(self.model_config.name_or_path)

        load_args = {}
        if self.noise_scheduler:
            load_args['scheduler'] = self.noise_scheduler

        if self.model_config.vae_path is not None:
            load_args['vae'] = load_vae(self.model_config.vae_path, dtype)
        if self.model_config.is_xl or self.model_config.is_ssd or self.model_config.is_vega:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
                pipln = StableDiffusionXLPipeline
                # pipln = StableDiffusionKDiffusionXLPipeline

            # see if path exists
            if not os.path.exists(model_path) or os.path.isdir(model_path):
                # try to load with default diffusers
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    # variant="fp16",
                    use_safetensors=True,
                    **load_args
                )
            else:
                pipe = pipln.from_single_file(
                    model_path,
                    device=self.device_torch,
                    torch_dtype=self.torch_dtype,
                )

            if 'vae' in load_args and load_args['vae'] is not None:
                pipe.vae = load_args['vae']
            flush()

            text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
            for text_encoder in text_encoders:
                text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
                text_encoder.requires_grad_(False)
                text_encoder.eval()
            text_encoder = text_encoders

            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)

            if self.model_config.experimental_xl:
                print("Experimental XL mode enabled")
                print("Loading and injecting alt weights")
                # load the mismatched weight and force it in
                raw_state_dict = load_file(model_path)
                replacement_weight = raw_state_dict['conditioner.embedders.1.model.text_projection'].clone()
                del raw_state_dict
                # get the state dict for the 2nd text encoder
                te1_state_dict = text_encoders[1].state_dict()
                # replace weight with mismatched weight
                te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
                flush()
                print("Injecting alt weights")
        elif self.model_config.is_v3:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
                pipln = StableDiffusion3Pipeline

            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

            model_id = "stabilityai/stable-diffusion-3-medium"
            text_encoder3 = T5EncoderModel.from_pretrained(
                model_id,
                subfolder="text_encoder_3",
                # quantization_config=quantization_config,
                revision="refs/pr/26",
                device_map="cuda"
            )

            # see if path exists
            if not os.path.exists(model_path) or os.path.isdir(model_path):
                try:
                    # try to load with default diffusers
                    pipe = pipln.from_pretrained(
                        model_path,
                        dtype=dtype,
                        device=self.device_torch,
                        text_encoder_3=text_encoder3,
                        # variant="fp16",
                        use_safetensors=True,
                        revision="refs/pr/26",
                        repo_type="model",
                        ignore_patterns=["*.md", "*.gitattributes"],
                        **load_args
                    )
                except Exception as e:
                    print(f"Error loading from pretrained: {e}")
                    raise e

            else:
                pipe = pipln.from_single_file(
                    model_path,
                    device=self.device_torch,
                    torch_dtype=self.torch_dtype,
                    text_encoder_3=text_encoder3,
                    **load_args
                )

            flush()

            text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
            tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
            # replace the to function with a no-op since it throws an error instead of a warning
            # text_encoders[2].to = lambda *args, **kwargs: None
            for text_encoder in text_encoders:
                text_encoder.to(self.device_torch, dtype=dtype)
                text_encoder.requires_grad_(False)
                text_encoder.eval()
            text_encoder = text_encoders


        elif self.model_config.is_pixart:
            te_kwargs = {}
            # handle quantization of TE
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True

            main_model_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
            if self.model_config.is_pixart_sigma:
                main_model_path = "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers"

            main_model_path = model_path

            # load the TE in 8bit mode
            text_encoder = T5EncoderModel.from_pretrained(
                main_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype,
                **te_kwargs
            )

            # load the transformer
            subfolder = "transformer"
            # check if it is just the unet
            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                subfolder = None

            if te_is_quantized:
                # replace the to function with a no-op since it throws an error instead of a warning
                text_encoder.to = lambda *args, **kwargs: None

            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)

            if self.model_config.is_pixart_sigma:
                # load the transformer only from the save
                transformer = Transformer2DModel.from_pretrained(
                    model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                    torch_dtype=self.torch_dtype,
                    subfolder='transformer'
                )
                pipe: PixArtSigmaPipeline = PixArtSigmaPipeline.from_pretrained(
                    main_model_path,
                    transformer=transformer,
                    text_encoder=text_encoder,
                    dtype=dtype,
                    device=self.device_torch,
                    **load_args
                )

            else:
                # load the transformer only from the save
                transformer = Transformer2DModel.from_pretrained(model_path, torch_dtype=self.torch_dtype,
                                                                 subfolder=subfolder)
                pipe: PixArtAlphaPipeline = PixArtAlphaPipeline.from_pretrained(
                    main_model_path,
                    transformer=transformer,
                    text_encoder=text_encoder,
                    dtype=dtype,
                    device=self.device_torch,
                    **load_args
                ).to(self.device_torch)

            if self.model_config.unet_sample_size is not None:
                pipe.transformer.config.sample_size = self.model_config.unet_sample_size
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

            flush()
            # text_encoder = pipe.text_encoder
            # text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer

            pipe.vae = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
            if self.noise_scheduler is None:
                self.noise_scheduler = pipe.scheduler


        elif self.model_config.is_auraflow:
            te_kwargs = {}
            # handle quantization of TE
            te_is_quantized = False
            if self.model_config.text_encoder_bits == 8:
                te_kwargs['load_in_8bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True
            elif self.model_config.text_encoder_bits == 4:
                te_kwargs['load_in_4bit'] = True
                te_kwargs['device_map'] = "auto"
                te_is_quantized = True

            main_model_path = model_path

            # load the TE in 8bit mode
            text_encoder = UMT5EncoderModel.from_pretrained(
                main_model_path,
                subfolder="text_encoder",
                torch_dtype=self.torch_dtype,
                **te_kwargs
            )

            # load the transformer
            subfolder = "transformer"
            # check if it is just the unet
            if os.path.exists(model_path) and not os.path.exists(os.path.join(model_path, subfolder)):
                subfolder = None

            if te_is_quantized:
                # replace the to function with a no-op since it throws an error instead of a warning
                text_encoder.to = lambda *args, **kwargs: None

            # load the transformer only from the save
            transformer = AuraFlowTransformer2DModel.from_pretrained(
                model_path if self.model_config.unet_path is None else self.model_config.unet_path,
                torch_dtype=self.torch_dtype,
                subfolder='transformer'
            )
            pipe: AuraFlowPipeline = AuraFlowPipeline.from_pretrained(
                main_model_path,
                transformer=transformer,
                text_encoder=text_encoder,
                dtype=dtype,
                device=self.device_torch,
                **load_args
            )

            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)

            # patch auraflow so it can handle other aspect ratios
            # patch_auraflow_pos_embed(pipe.transformer.pos_embed)

            flush()
            # text_encoder = pipe.text_encoder
            # text_encoder.to(self.device_torch, dtype=dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
            tokenizer = pipe.tokenizer

        elif self.model_config.is_flux:
            print("Loading Flux model")
            base_model_path = "black-forest-labs/FLUX.1-schnell"
            print("Loading transformer")
            subfolder = 'transformer'
            transformer_path = model_path
            local_files_only = False
            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
            if os.path.exists(transformer_path):
                subfolder = None
                transformer_path = os.path.join(transformer_path, 'transformer')
                # check if the path is a full checkpoint.
                te_folder_path = os.path.join(model_path, 'text_encoder')
                # if we have the te, this folder is a full checkpoint, use it as the base
                if os.path.exists(te_folder_path):
                    base_model_path = model_path

            transformer = FluxTransformer2DModel.from_pretrained(
                transformer_path,
                subfolder=subfolder,
                torch_dtype=dtype,
                # low_cpu_mem_usage=False,
                # device_map=None
            )
            if not self.low_vram:
                # with low_vram we leave it on the cpu instead. Quantizing is slower there,
                # but it allows training on the primary gpu
                transformer.to(torch.device(self.quantize_device), dtype=dtype)
            flush()

            if self.model_config.assistant_lora_path is not None:
                if self.model_config.lora_path:
                    raise ValueError("Cannot load both assistant lora and lora at the same time")

                if not self.is_flux:
                    raise ValueError("Assistant lora is only supported for flux models currently")

                # handle downloading from the hub if needed
                if not os.path.exists(self.model_config.assistant_lora_path):
                    print(f"Grabbing assistant lora from the hub: {self.model_config.assistant_lora_path}")
                    new_lora_path = hf_hub_download(
                        self.model_config.assistant_lora_path,
                        filename="pytorch_lora_weights.safetensors"
                    )
                    # replace the path
                    self.model_config.assistant_lora_path = new_lora_path

                # for flux, we assume it is flux schnell. We cannot merge the assistant lora into
                # quantized weights and unmerge it later, so it would have to run unmerged (slow).
                # Since schnell samples in just 4 steps, it is better to merge it in now and sample
                # slowly later; otherwise training runs at half speed.
                # So we merge in now and sample with a -1 weight later.
                self.invert_assistant_lora = True
                # trigger it to get merged in
                self.model_config.lora_path = self.model_config.assistant_lora_path

            if self.model_config.lora_path is not None:
                print("Fusing in LoRA")
                # if doing low vram, do this on the gpu, painfully slow otherwise
                if self.low_vram:
                    print(" - this process is painfully slow with 'low_vram' enabled. Disable it if possible.")
                # need the pipe to do this unfortunately for now
                # we have to fuse in the weights before quantizing
                pipe: FluxPipeline = FluxPipeline(
                    scheduler=None,
                    text_encoder=None,
                    tokenizer=None,
                    text_encoder_2=None,
                    tokenizer_2=None,
                    vae=None,
                    transformer=transformer,
                )
                pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
                pipe.fuse_lora()
                # unfortunately, not an easier way with peft
                pipe.unload_lora_weights()

            if self.model_config.quantize:
                quantization_type = qfloat8
                print("Quantizing transformer")
                quantize(transformer, weights=quantization_type)
                freeze(transformer)
                transformer.to(self.device_torch)
            else:
                transformer.to(self.device_torch, dtype=dtype)

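            # Note on the optimum.quanto flow above (a sketch of the library's
            # standard usage, not project-specific behavior):
            # quantize(model, weights=qfloat8) swaps eligible weights for
            # float8 QTensors, and freeze(model) makes the quantization
            # permanent so the full-precision copies can be released. Trained
            # LoRA layers then run in fp16/bf16 on top of the frozen base.
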
            flush()

            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
            print("Loading vae")
            vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
            flush()

            print("Loading t5")
            tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
                                                            torch_dtype=dtype)

            text_encoder_2.to(self.device_torch, dtype=dtype)
            flush()

            print("Quantizing T5")
            quantize(text_encoder_2, weights=qfloat8)
            freeze(text_encoder_2)
            flush()

            print("Loading clip")
            text_encoder = CLIPTextModel.from_pretrained(base_model_path, subfolder="text_encoder", torch_dtype=dtype)
            tokenizer = CLIPTokenizer.from_pretrained(base_model_path, subfolder="tokenizer", torch_dtype=dtype)
            text_encoder.to(self.device_torch, dtype=dtype)

            print("making pipe")
            pipe: FluxPipeline = FluxPipeline(
                scheduler=scheduler,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                text_encoder_2=None,
                tokenizer_2=tokenizer_2,
                vae=vae,
                transformer=None,
            )
            pipe.text_encoder_2 = text_encoder_2
            pipe.transformer = transformer

            print("preparing")

            text_encoder = [pipe.text_encoder, pipe.text_encoder_2]
            tokenizer = [pipe.tokenizer, pipe.tokenizer_2]

            pipe.transformer = pipe.transformer.to(self.device_torch)

            flush()
            text_encoder[0].to(self.device_torch)
            text_encoder[0].requires_grad_(False)
            text_encoder[0].eval()
            text_encoder[1].to(self.device_torch)
            text_encoder[1].requires_grad_(False)
            text_encoder[1].eval()
            pipe.transformer = pipe.transformer.to(self.device_torch)
            flush()
        else:
            if self.custom_pipeline is not None:
                pipln = self.custom_pipeline
            else:
                pipln = StableDiffusionPipeline

            if self.model_config.text_encoder_bits < 16:
                # this is only supported for T5 models for now
                te_kwargs = {}
                # handle quantization of TE
                te_is_quantized = False
                if self.model_config.text_encoder_bits == 8:
                    te_kwargs['load_in_8bit'] = True
                    te_kwargs['device_map'] = "auto"
                    te_is_quantized = True
                elif self.model_config.text_encoder_bits == 4:
                    te_kwargs['load_in_4bit'] = True
                    te_kwargs['device_map'] = "auto"
                    te_is_quantized = True

                text_encoder = T5EncoderModel.from_pretrained(
                    model_path,
                    subfolder="text_encoder",
                    torch_dtype=self.te_torch_dtype,
                    **te_kwargs
                )
                # replace the to function with a no-op since it throws an error instead of a warning
                text_encoder.to = lambda *args, **kwargs: None

                load_args['text_encoder'] = text_encoder

            # see if path exists
            if not os.path.exists(model_path) or os.path.isdir(model_path):
                # try to load with default diffusers
                pipe = pipln.from_pretrained(
                    model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    load_safety_checker=False,
                    requires_safety_checker=False,
                    safety_checker=None,
                    # variant="fp16",
                    trust_remote_code=True,
                    **load_args
                )
            else:
                pipe = pipln.from_single_file(
                    model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    load_safety_checker=False,
                    requires_safety_checker=False,
                    torch_dtype=self.torch_dtype,
                    safety_checker=None,
                    trust_remote_code=True,
                    **load_args
                )
            flush()

            pipe.register_to_config(requires_safety_checker=False)
            text_encoder = pipe.text_encoder
            text_encoder.to(self.te_device_torch, dtype=self.te_torch_dtype)
            text_encoder.requires_grad_(False)
            text_encoder.eval()
            tokenizer = pipe.tokenizer

        # the scheduler doesn't always get set on load, so we set it here
        pipe.scheduler = self.noise_scheduler

        # add hacks to unet to help training
        # pipe.unet = prepare_unet_for_training(pipe.unet)

        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
            # pixart and sd3 don't use a unet
            self.unet = pipe.transformer
        else:
            self.unet: 'UNet2DConditionModel' = pipe.unet
        self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype)
        self.vae.eval()
        self.vae.requires_grad_(False)
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        self.vae_scale_factor = VAE_SCALE_FACTOR
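        # Worked example: SD/SDXL VAEs have block_out_channels of length 4, so
        # the scale factor is 2 ** (4 - 1) = 8 and a 1024x1024 image maps to a
        # 128x128 latent grid.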
        self.unet.to(self.device_torch, dtype=dtype)
        self.unet.requires_grad_(False)
        self.unet.eval()

        # load any loras we have
        if self.model_config.lora_path is not None and not self.is_flux:
            pipe.load_lora_weights(self.model_config.lora_path, adapter_name="lora1")
            pipe.fuse_lora()
            # unfortunately, not an easier way with peft
            pipe.unload_lora_weights()

        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.pipeline = pipe
        self.load_refiner()
        self.is_loaded = True

        if self.model_config.assistant_lora_path is not None:
            print("Loading assistant lora")
            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
                self.model_config.assistant_lora_path, self)

            if self.invert_assistant_lora:
                # invert and disable during training
                self.assistant_lora.multiplier = -1.0
                self.assistant_lora.is_active = False

        if self.is_pixart and self.vae_scale_factor == 16:
            # TODO make our own pipeline?
            # we generate an image 2x larger, so we need to copy the sizes from larger ones down
            # ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
            for key in ASPECT_RATIO_256_BIN.keys():
                ASPECT_RATIO_256_BIN[key] = [ASPECT_RATIO_256_BIN[key][0] * 2, ASPECT_RATIO_256_BIN[key][1] * 2]
            for key in ASPECT_RATIO_512_BIN.keys():
                ASPECT_RATIO_512_BIN[key] = [ASPECT_RATIO_512_BIN[key][0] * 2, ASPECT_RATIO_512_BIN[key][1] * 2]
            for key in ASPECT_RATIO_1024_BIN.keys():
                ASPECT_RATIO_1024_BIN[key] = [ASPECT_RATIO_1024_BIN[key][0] * 2, ASPECT_RATIO_1024_BIN[key][1] * 2]
            for key in ASPECT_RATIO_2048_BIN.keys():
                ASPECT_RATIO_2048_BIN[key] = [ASPECT_RATIO_2048_BIN[key][0] * 2, ASPECT_RATIO_2048_BIN[key][1] * 2]

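    # Note on the bin doubling at the end of load_model (our reading of the
    # code, stated as an assumption): the pipeline's aspect-ratio bins map
    # ratios to pixel sizes assuming an 8x VAE, so with a 16x VAE the binned
    # pixel sizes are doubled to keep the latent resolution the transformer
    # expects.
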
    def te_train(self):
        if isinstance(self.text_encoder, list):
            for te in self.text_encoder:
                te.train()
        else:
            self.text_encoder.train()

    def te_eval(self):
        if isinstance(self.text_encoder, list):
            for te in self.text_encoder:
                te.eval()
        else:
            self.text_encoder.eval()

    def load_refiner(self):
        # for now, we are just going to rely on the TE from the base model
        # which is TE2 for SDXL and TE for SD (no refiner currently)
        # and completely ignore a TE that may or may not be packaged with the refiner
        if self.model_config.refiner_name_or_path is not None:
            refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
            # load the refiner model
            dtype = get_torch_dtype(self.dtype)
            model_path = self.model_config.refiner_name_or_path
            if not os.path.exists(model_path) or os.path.isdir(model_path):
                # TODO only load unet??
                refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                    model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    # variant="fp16",
                    use_safetensors=True,
                ).to(self.device_torch)
            else:
                refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
                    model_path,
                    dtype=dtype,
                    device=self.device_torch,
                    torch_dtype=self.torch_dtype,
                    original_config_file=refiner_config_path,
                ).to(self.device_torch)

            self.refiner_unet = refiner.unet
            del refiner
            flush()

    @torch.no_grad()
    def generate_images(
            self,
            image_configs: List[GenerateImageConfig],
            sampler=None,
            pipeline: Union[None, StableDiffusionPipeline, StableDiffusionXLPipeline] = None,
    ):
        merge_multiplier = 1.0
        flush()
        # if using the assistant, unfuse it
        if self.model_config.assistant_lora_path is not None:
            print("Unloading assistant lora")
            if self.invert_assistant_lora:
                self.assistant_lora.is_active = True
                # move weights on to the device
                self.assistant_lora.force_to(self.device_torch, self.torch_dtype)
            else:
                self.assistant_lora.is_active = False

        if self.network is not None:
            self.network.eval()
            network = self.network
            # check if we have the same network weight for all samples. If we do, we can merge
            # in the network to drastically speed up inference
            unique_network_weights = set([x.network_multiplier for x in image_configs])
            if len(unique_network_weights) == 1 and self.network.can_merge_in:
                can_merge_in = True
                merge_multiplier = unique_network_weights.pop()
                network.merge_in(merge_weight=merge_multiplier)
        else:
            network = BlankNetwork()

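        # Merging is purely an inference-speed optimization: with one shared
        # multiplier the LoRA deltas are folded into the base weights once,
        # instead of running the wrapped forward on every denoising step.
        # merge_out() near the end of this method restores the originals.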
        self.save_device_state()
        self.set_device_state_preset('generate')

        # save the current seed state for training
        rng_state = torch.get_rng_state()
        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

        if pipeline is None:
            noise_scheduler = self.noise_scheduler
            if sampler is not None:
                if sampler.startswith("sample_"):  # sample_dpmpp_2m
                    # using ksampler
                    noise_scheduler = get_sampler(
                        'lms', {
                            "prediction_type": self.prediction_type,
                        })
                else:
                    noise_scheduler = get_sampler(
                        sampler,
                        {
                            "prediction_type": self.prediction_type,
                        },
                        'sd' if not self.is_pixart else 'pixart'
                    )

            try:
                noise_scheduler = noise_scheduler.to(self.device_torch, self.torch_dtype)
            except Exception:
                pass

            if sampler is not None and sampler.startswith("sample_") and self.is_xl:
                # using kdiffusion
                Pipe = StableDiffusionKDiffusionXLPipeline
            elif self.is_xl:
                Pipe = StableDiffusionXLPipeline
            elif self.is_v3:
                Pipe = StableDiffusion3Pipeline
            else:
                Pipe = StableDiffusionPipeline

            extra_args = {}
            if self.adapter is not None:
                if isinstance(self.adapter, T2IAdapter):
                    if self.is_xl:
                        Pipe = StableDiffusionXLAdapterPipeline
                    else:
                        Pipe = StableDiffusionAdapterPipeline
                    extra_args['adapter'] = self.adapter
                elif isinstance(self.adapter, ControlNetModel):
                    if self.is_xl:
                        Pipe = StableDiffusionXLControlNetPipeline
                    else:
                        Pipe = StableDiffusionControlNetPipeline
                    extra_args['controlnet'] = self.adapter
                elif isinstance(self.adapter, ReferenceAdapter):
                    # pass the noise scheduler to the adapter
                    self.adapter.noise_scheduler = noise_scheduler
            else:
                if self.is_xl:
                    extra_args['add_watermarker'] = False

            # TODO add clip skip
            if self.is_xl:
                pipeline = Pipe(
                    vae=self.vae,
                    unet=self.unet,
                    text_encoder=self.text_encoder[0],
                    text_encoder_2=self.text_encoder[1],
                    tokenizer=self.tokenizer[0],
                    tokenizer_2=self.tokenizer[1],
                    scheduler=noise_scheduler,
                    **extra_args
                ).to(self.device_torch)
                pipeline.watermark = None
            elif self.is_flux:
                if self.model_config.use_flux_cfg:
                    pipeline = FluxWithCFGPipeline(
                        vae=self.vae,
                        transformer=self.unet,
                        text_encoder=self.text_encoder[0],
                        text_encoder_2=self.text_encoder[1],
                        tokenizer=self.tokenizer[0],
                        tokenizer_2=self.tokenizer[1],
                        scheduler=noise_scheduler,
                        **extra_args
                    )
                else:
                    pipeline = FluxPipeline(
                        vae=self.vae,
                        transformer=self.unet,
                        text_encoder=self.text_encoder[0],
                        text_encoder_2=self.text_encoder[1],
                        tokenizer=self.tokenizer[0],
                        tokenizer_2=self.tokenizer[1],
                        scheduler=noise_scheduler,
                        **extra_args
                    )
                pipeline.watermark = None
            elif self.is_v3:
                pipeline = Pipe(
                    vae=self.vae,
                    transformer=self.unet,
                    text_encoder=self.text_encoder[0],
                    text_encoder_2=self.text_encoder[1],
                    text_encoder_3=self.text_encoder[2],
                    tokenizer=self.tokenizer[0],
                    tokenizer_2=self.tokenizer[1],
                    tokenizer_3=self.tokenizer[2],
                    scheduler=noise_scheduler,
                    **extra_args
                )
            elif self.is_pixart:
                pipeline = PixArtSigmaPipeline(
                    vae=self.vae,
                    transformer=self.unet,
                    text_encoder=self.text_encoder,
                    tokenizer=self.tokenizer,
                    scheduler=noise_scheduler,
                    **extra_args
                )
            elif self.is_auraflow:
                pipeline = AuraFlowPipeline(
                    vae=self.vae,
                    transformer=self.unet,
                    text_encoder=self.text_encoder,
                    tokenizer=self.tokenizer,
                    scheduler=noise_scheduler,
                    **extra_args
                )
            else:
                pipeline = Pipe(
                    vae=self.vae,
                    unet=self.unet,
                    text_encoder=self.text_encoder,
                    tokenizer=self.tokenizer,
                    scheduler=noise_scheduler,
                    safety_checker=None,
                    feature_extractor=None,
                    requires_safety_checker=False,
                    **extra_args
                )
            flush()
            # disable the progress bar
            pipeline.set_progress_bar_config(disable=True)

            if sampler is not None and sampler.startswith("sample_"):
                pipeline.set_scheduler(sampler)

        refiner_pipeline = None
        if self.refiner_unet:
            # build the refiner pipeline
            refiner_pipeline = StableDiffusionXLImg2ImgPipeline(
                vae=pipeline.vae,
                unet=self.refiner_unet,
                text_encoder=None,
                text_encoder_2=pipeline.text_encoder_2,
                tokenizer=None,
                tokenizer_2=pipeline.tokenizer_2,
                scheduler=pipeline.scheduler,
                add_watermarker=False,
                requires_aesthetics_score=True,
            ).to(self.device_torch)
            # refiner_pipeline.register_to_config(requires_aesthetics_score=False)
            refiner_pipeline.watermark = None
            refiner_pipeline.set_progress_bar_config(disable=True)
            flush()

        start_multiplier = 1.0
        if self.network is not None:
            start_multiplier = self.network.multiplier

        # pipeline.to(self.device_torch)

        with network:
            with torch.no_grad():
                if self.network is not None:
                    assert self.network.is_active

                for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False):
                    gen_config = image_configs[i]

                    extra = {}
                    validation_image = None
                    if self.adapter is not None and gen_config.adapter_image_path is not None:
                        validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
                        if isinstance(self.adapter, T2IAdapter):
                            # not sure why this is double??
                            validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
                            extra['image'] = validation_image
                            extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
                        if isinstance(self.adapter, ControlNetModel):
                            validation_image = validation_image.resize((gen_config.width, gen_config.height))
                            extra['image'] = validation_image
                            extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale
                        if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                            transform = transforms.Compose([
                                transforms.ToTensor(),
                            ])
                            validation_image = transform(validation_image)
                        if isinstance(self.adapter, CustomAdapter):
                            # todo allow loading multiple
                            transform = transforms.Compose([
                                transforms.ToTensor(),
                            ])
                            validation_image = transform(validation_image)
                            self.adapter.num_images = 1
                        if isinstance(self.adapter, ReferenceAdapter):
                            # need -1 to 1
                            validation_image = transforms.ToTensor()(validation_image)
                            validation_image = validation_image * 2.0 - 1.0
                            validation_image = validation_image.unsqueeze(0)
                            self.adapter.set_reference_images(validation_image)

                    if self.network is not None:
                        self.network.multiplier = gen_config.network_multiplier
                    torch.manual_seed(gen_config.seed)
                    torch.cuda.manual_seed(gen_config.seed)

                    if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
                            and gen_config.adapter_image_path is not None:
                        # run through the adapter to saturate the embeds
                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                        self.adapter(conditional_clip_embeds)

                    if self.adapter is not None and isinstance(self.adapter, CustomAdapter):
                        # let the adapter condition the prompts
                        gen_config.prompt = self.adapter.condition_prompt(
                            gen_config.prompt,
                            is_unconditional=False,
                        )
                        gen_config.prompt_2 = gen_config.prompt
                        gen_config.negative_prompt = self.adapter.condition_prompt(
                            gen_config.negative_prompt,
                            is_unconditional=True,
                        )
                        gen_config.negative_prompt_2 = gen_config.negative_prompt

                    if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
                        self.adapter.trigger_pre_te(
                            tensors_0_1=validation_image,
                            is_training=False,
                            has_been_preprocessed=False,
                            quad_count=4
                        )

                    # encode the prompt ourselves so we can do fun stuff with embeddings
                    if isinstance(self.adapter, CustomAdapter):
                        self.adapter.is_unconditional_run = False
                    conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

                    if isinstance(self.adapter, CustomAdapter):
                        self.adapter.is_unconditional_run = True
                    unconditional_embeds = self.encode_prompt(
                        gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
                    )
                    if isinstance(self.adapter, CustomAdapter):
                        self.adapter.is_unconditional_run = False

                    # allow any manipulations of the embeddings to take place
                    gen_config.post_process_embeddings(
                        conditional_embeds,
                        unconditional_embeds,
                    )

                    if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                            and gen_config.adapter_image_path is not None:
                        # apply the image projection
                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                        unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
                                                                                                    True)
                        conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                        unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)

                    if self.adapter is not None and isinstance(self.adapter,
                                                               CustomAdapter) and validation_image is not None:
                        conditional_embeds = self.adapter.condition_encoded_embeds(
                            tensors_0_1=validation_image,
                            prompt_embeds=conditional_embeds,
                            is_training=False,
                            has_been_preprocessed=False,
                            is_generating_samples=True,
                        )
                        unconditional_embeds = self.adapter.condition_encoded_embeds(
                            tensors_0_1=validation_image,
                            prompt_embeds=unconditional_embeds,
                            is_training=False,
                            has_been_preprocessed=False,
                            is_unconditional=True,
                            is_generating_samples=True,
                        )

                    if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
                            gen_config.extra_values) > 0:
                        extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
                                                    dtype=self.torch_dtype)
                        # apply extra values to the embeddings
                        self.adapter.add_extra_values(extra_values, is_unconditional=False)
                        self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)

                    if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                        # if we have a refiner loaded, set the denoising end at the refiner start
                        extra['denoising_end'] = gen_config.refiner_start_at
                        extra['output_type'] = 'latent'
                        if not self.is_xl:
                            raise ValueError("Refiner is only supported for XL models")

                    conditional_embeds = conditional_embeds.to(self.device_torch, dtype=self.unet.dtype)
                    unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=self.unet.dtype)

                    if self.is_xl:
                        # fix guidance rescale for sdxl
                        # was trained on 0.7 (I believe)

                        grs = gen_config.guidance_rescale
                        # if grs is None or grs < 0.00001:
                        #     grs = 0.7
                        # grs = 0.0

                        if sampler is not None and sampler.startswith("sample_"):
                            extra['use_karras_sigmas'] = True
                        extra = {
                            **extra,
                            **gen_config.extra_kwargs,
                        }

                        img = pipeline(
                            # prompt=gen_config.prompt,
                            # prompt_2=gen_config.prompt_2,
                            prompt_embeds=conditional_embeds.text_embeds,
                            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                            negative_prompt_embeds=unconditional_embeds.text_embeds,
                            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                            # negative_prompt=gen_config.negative_prompt,
                            # negative_prompt_2=gen_config.negative_prompt_2,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            guidance_rescale=grs,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                    elif self.is_v3:
                        img = pipeline(
                            prompt_embeds=conditional_embeds.text_embeds,
                            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                            negative_prompt_embeds=unconditional_embeds.text_embeds,
                            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                    elif self.is_flux:
                        if self.model_config.use_flux_cfg:
                            img = pipeline(
                                prompt_embeds=conditional_embeds.text_embeds,
                                pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                                negative_prompt_embeds=unconditional_embeds.text_embeds,
                                negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                                height=gen_config.height,
                                width=gen_config.width,
                                num_inference_steps=gen_config.num_inference_steps,
                                guidance_scale=gen_config.guidance_scale,
                                latents=gen_config.latents,
                                **extra
                            ).images[0]
                        else:
                            img = pipeline(
                                prompt_embeds=conditional_embeds.text_embeds,
                                pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                                # negative_prompt_embeds=unconditional_embeds.text_embeds,
                                # negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                                height=gen_config.height,
                                width=gen_config.width,
                                num_inference_steps=gen_config.num_inference_steps,
                                guidance_scale=gen_config.guidance_scale,
                                latents=gen_config.latents,
                                **extra
                            ).images[0]
                    elif self.is_pixart:
                        # needs attention masks for some reason
                        img = pipeline(
                            prompt=None,
                            prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                            prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
                                                                                       dtype=self.unet.dtype),
                            negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
                                                                                       dtype=self.unet.dtype),
                            negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
                                                                                                  dtype=self.unet.dtype),
                            negative_prompt=None,
                            # negative_prompt=gen_config.negative_prompt,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                    elif self.is_auraflow:
                        pipeline: AuraFlowPipeline = pipeline

                        img = pipeline(
                            prompt=None,
                            prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
                            prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
                                                                                       dtype=self.unet.dtype),
                            negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
                                                                                       dtype=self.unet.dtype),
                            negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
                                                                                                  dtype=self.unet.dtype),
                            negative_prompt=None,
                            # negative_prompt=gen_config.negative_prompt,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]
                    else:
                        img = pipeline(
                            # prompt=gen_config.prompt,
                            prompt_embeds=conditional_embeds.text_embeds,
                            negative_prompt_embeds=unconditional_embeds.text_embeds,
                            # negative_prompt=gen_config.negative_prompt,
                            height=gen_config.height,
                            width=gen_config.width,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            latents=gen_config.latents,
                            **extra
                        ).images[0]

                    if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                        # slice off just the last 1280 on the last dim, as the refiner does not use the first text encoder
                        # todo, should we just use the text encoder from the refiner? Fine-tuned versions will differ
                        refiner_text_embeds = conditional_embeds.text_embeds[:, :, -1280:]
                        refiner_unconditional_text_embeds = unconditional_embeds.text_embeds[:, :, -1280:]
                        # run through the refiner
                        img = refiner_pipeline(
                            # prompt=gen_config.prompt,
                            # prompt_2=gen_config.prompt_2,

                            # slice these as it does not use both text encoders
                            # height=gen_config.height,
                            # width=gen_config.width,
                            prompt_embeds=refiner_text_embeds,
                            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
                            negative_prompt_embeds=refiner_unconditional_text_embeds,
                            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
                            num_inference_steps=gen_config.num_inference_steps,
                            guidance_scale=gen_config.guidance_scale,
                            guidance_rescale=grs,
                            denoising_start=gen_config.refiner_start_at,
                            denoising_end=gen_config.num_inference_steps,
                            image=img.unsqueeze(0)
                        ).images[0]

                    gen_config.save_image(img, i)

                if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter):
                    self.adapter.clear_memory()

        # clear pipeline and cache to reduce vram usage
        del pipeline
        if refiner_pipeline is not None:
            del refiner_pipeline
        torch.cuda.empty_cache()

        # restore training state
        torch.set_rng_state(rng_state)
        if cuda_rng_state is not None:
            torch.cuda.set_rng_state(cuda_rng_state)

        self.restore_device_state()
        if self.network is not None:
            self.network.train()
            self.network.multiplier = start_multiplier

        self.unet.to(self.device_torch, dtype=self.torch_dtype)
        if network.is_merged_in:
            network.merge_out(merge_multiplier)
        # self.tokenizer.to(original_device_dict['tokenizer'])

        # re-fuse the loras
        if self.model_config.assistant_lora_path is not None:
            print("Loading assistant lora")
            if self.invert_assistant_lora:
                self.assistant_lora.is_active = False
                # move weights off the device
                self.assistant_lora.force_to('cpu', self.torch_dtype)
            else:
                self.assistant_lora.is_active = True

        flush()

    def get_latent_noise(
            self,
            height=None,
            width=None,
            pixel_height=None,
            pixel_width=None,
            batch_size=1,
            noise_offset=0.0,
    ):
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        if height is None and pixel_height is None:
            raise ValueError("height or pixel_height must be specified")
        if width is None and pixel_width is None:
            raise ValueError("width or pixel_width must be specified")
        if height is None:
            height = pixel_height // VAE_SCALE_FACTOR
        if width is None:
            width = pixel_width // VAE_SCALE_FACTOR

        num_channels = self.unet.config['in_channels']
        if self.is_flux:
            # the flux transformer reports 64 in_channels because it takes packed
            # 2x2 latent patches; the raw latent itself has 16 channels
            num_channels = 16
        noise = torch.randn(
            (
                batch_size,
                num_channels,
                height,
                width,
            ),
            device=self.unet.device,
        )
        noise = apply_noise_offset(noise, noise_offset)
        return noise

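    # Shape example for get_latent_noise (follows directly from the code
    # above): with the usual 8x VAE, pixel_width=pixel_height=512 gives 64x64
    # latents, so the returned noise is (batch_size, 4, 64, 64) for SD/SDXL
    # and (batch_size, 16, 64, 64) for Flux.
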
    def get_time_ids_from_latents(self, latents: torch.Tensor, requires_aesthetic_score=False):
        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
        if self.is_xl:
            bs, ch, h, w = list(latents.shape)

            height = h * VAE_SCALE_FACTOR
            width = w * VAE_SCALE_FACTOR

            dtype = latents.dtype
            # just do it without any cropping nonsense
            target_size = (height, width)
            original_size = (height, width)
            crops_coords_top_left = (0, 0)
            if requires_aesthetic_score:
                # refiner
                # https://huggingface.co/papers/2307.01952
                aesthetic_score = 6.0  # simulate one
                add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            else:
                add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_time_ids = torch.tensor([add_time_ids])
            add_time_ids = add_time_ids.to(latents.device, dtype=dtype)

            batch_time_ids = torch.cat(
                [add_time_ids for _ in range(bs)]
            )
            return batch_time_ids
        else:
            return None

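    # Worked example: for a 1024x1024 generation with no cropping, the SDXL
    # micro-conditioning vector built above is
    #   [1024, 1024, 0, 0, 1024, 1024]  (original size, crop top-left, target size)
    # and for the refiner the target size is replaced by the simulated
    # aesthetic score: [1024, 1024, 0, 0, 6.0].
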
    def add_noise(
            self,
            original_samples: torch.FloatTensor,
            noise: torch.FloatTensor,
            timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
        noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
        timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)

        if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
            timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)

        noisy_latents_chunks = []

        for idx in range(original_samples.shape[0]):
            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
                                                           timesteps_chunks[idx])
            noisy_latents_chunks.append(noisy_latents)

        noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
        return noisy_latents

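    # add_noise processes the batch sample-by-sample so each latent can carry
    # its own timestep. A usage sketch (assuming a loaded instance `sd`):
    #
    #   noisy = sd.add_noise(latents, torch.randn_like(latents),
    #                        torch.randint(0, 1000, (latents.shape[0],)))
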
    def predict_noise(
            self,
            latents: torch.Tensor,
            text_embeddings: Union[PromptEmbeds, None] = None,
            timestep: Union[int, torch.Tensor] = 1,
            guidance_scale=7.5,
            guidance_rescale=0,
            add_time_ids=None,
            conditional_embeddings: Union[PromptEmbeds, None] = None,
            unconditional_embeddings: Union[PromptEmbeds, None] = None,
            is_input_scaled=False,
            detach_unconditional=False,
            rescale_cfg=None,
            return_conditional_pred=False,
            **kwargs,
    ):
        conditional_pred = None
        # get the embeddings
        if text_embeddings is None and conditional_embeddings is None:
            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
        if text_embeddings is None and unconditional_embeddings is not None:
            text_embeddings = concat_prompt_embeds([
                unconditional_embeddings,  # negative embedding
                conditional_embeddings,  # positive embedding
            ])
        elif text_embeddings is None and conditional_embeddings is not None:
            # not doing cfg
            text_embeddings = conditional_embeddings

        # CFG compares the negative and positive predictions. If we were given
        # concatenated embeddings, we are doing CFG; otherwise we are not,
        # which takes half the time.
        do_classifier_free_guidance = True

        # check if the batch size of the embeddings matches the batch size of the latents
        if latents.shape[0] == text_embeddings.text_embeds.shape[0]:
            do_classifier_free_guidance = False
        elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]:
            raise ValueError("Batch size of latents must be the same or half the batch size of text embeddings")
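        # Batch contract example: latents of shape (2, 4, 64, 64) with
        # text_embeds of batch 2 run without CFG; text_embeds of batch 4 are
        # interpreted as [uncond, cond] stacked and enable CFG; any other
        # combination raises above.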
        latents = latents.to(self.device_torch)
        text_embeddings = text_embeddings.to(self.device_torch)
        timestep = timestep.to(self.device_torch)

        # if timestep is zero dim, unsqueeze it
        if len(timestep.shape) == 0:
            timestep = timestep.unsqueeze(0)

        # if we only have 1 timestep, we can just use the same timestep for all
        if timestep.shape[0] == 1 and latents.shape[0] > 1:
            # check if it is rank 1 or 2
            if len(timestep.shape) == 1:
                timestep = timestep.repeat(latents.shape[0])
            else:
                timestep = timestep.repeat(latents.shape[0], 0)

        # handle t2i adapters
        if 'down_intrablock_additional_residuals' in kwargs:
            # go through each item and concat if doing cfg and it doesn't have the same shape
            for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']):
                if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
                    kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)

        # handle controlnet
        if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs:
            # go through each item and concat if doing cfg and it doesn't have the same shape
            for idx, item in enumerate(kwargs['down_block_additional_residuals']):
                if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
                    kwargs['down_block_additional_residuals'][idx] = torch.cat([item] * 2, dim=0)
            for idx, item in enumerate(kwargs['mid_block_additional_residual']):
                if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]:
                    kwargs['mid_block_additional_residual'][idx] = torch.cat([item] * 2, dim=0)

        def scale_model_input(model_input, timestep_tensor):
            if is_input_scaled:
                return model_input
            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
            timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
            out_chunks = []
            # unsqueeze if timestep is zero dim
            for idx in range(model_input.shape[0]):
                # if scheduler has step_index
                if hasattr(self.noise_scheduler, '_step_index'):
                    self.noise_scheduler._step_index = None
                out_chunks.append(
                    self.noise_scheduler.scale_model_input(mi_chunks[idx], timestep_chunks[idx])
                )
            return torch.cat(out_chunks, dim=0)

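        # scale_model_input above applies the scheduler's per-timestep input
        # scaling one sample at a time, resetting the scheduler's cached
        # _step_index between calls so each lookup is by timestep. For
        # schedulers like DDPM this is a no-op, while Karras-style schedulers
        # rescale the input by sigma.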
        if self.is_xl:
            with torch.no_grad():
                # 16, 6 for bs of 4
                if add_time_ids is None:
                    add_time_ids = self.get_time_ids_from_latents(latents)

                if do_classifier_free_guidance:
                    # todo check this with larger batches
                    add_time_ids = torch.cat([add_time_ids] * 2)

            if do_classifier_free_guidance:
                latent_model_input = torch.cat([latents] * 2)
                timestep = torch.cat([timestep] * 2)
            else:
                latent_model_input = latents

            latent_model_input = scale_model_input(latent_model_input, timestep)

            added_cond_kwargs = {
                # todo can we zero the second text encoder here? or match a blank string?
                "text_embeds": text_embeddings.pooled_embeds,
                "time_ids": add_time_ids,
            }

            if self.model_config.refiner_name_or_path is not None:
                # we have the refiner on the second half of everything. Do both.
                if do_classifier_free_guidance:
                    raise ValueError("Refiner is not supported with classifier free guidance")

                if self.unet.training:
                    input_chunks = torch.chunk(latent_model_input, 2, dim=0)
                    timestep_chunks = torch.chunk(timestep, 2, dim=0)
                    added_cond_kwargs_chunked = {
                        "text_embeds": torch.chunk(text_embeddings.pooled_embeds, 2, dim=0),
                        "time_ids": torch.chunk(add_time_ids, 2, dim=0),
                    }
                    text_embeds_chunks = torch.chunk(text_embeddings.text_embeds, 2, dim=0)

                    # predict the noise residual
                    base_pred = self.unet(
                        input_chunks[0],
                        timestep_chunks[0],
                        encoder_hidden_states=text_embeds_chunks[0],
                        added_cond_kwargs={
                            "text_embeds": added_cond_kwargs_chunked['text_embeds'][0],
                            "time_ids": added_cond_kwargs_chunked['time_ids'][0],
                        },
                        **kwargs,
                    ).sample

                    refiner_pred = self.refiner_unet(
                        input_chunks[1],
                        timestep_chunks[1],
                        encoder_hidden_states=text_embeds_chunks[1][:, :, -1280:],
                        # just use the second text encoder
                        added_cond_kwargs={
                            "text_embeds": added_cond_kwargs_chunked['text_embeds'][1],
                            # "time_ids": added_cond_kwargs_chunked['time_ids'][1],
                            "time_ids": self.get_time_ids_from_latents(input_chunks[1], requires_aesthetic_score=True),
                        },
                        **kwargs,
                    ).sample

                    noise_pred = torch.cat([base_pred, refiner_pred], dim=0)
                else:
                    noise_pred = self.refiner_unet(
                        latent_model_input,
                        timestep,
                        encoder_hidden_states=text_embeddings.text_embeds[:, :, -1280:],
                        # just use the second text encoder
                        added_cond_kwargs={
                            "text_embeds": text_embeddings.pooled_embeds,
                            "time_ids": self.get_time_ids_from_latents(latent_model_input,
                                                                       requires_aesthetic_score=True),
                        },
                        **kwargs,
                    ).sample

            else:

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input.to(self.device_torch, self.torch_dtype),
                    timestep,
                    encoder_hidden_states=text_embeddings.text_embeds,
                    added_cond_kwargs=added_cond_kwargs,
                    **kwargs,
                ).sample

            conditional_pred = noise_pred

            if do_classifier_free_guidance:
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                conditional_pred = noise_pred_text
                noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                )

# https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
|
|
if guidance_rescale > 0.0:
|
|
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
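
                # rescale_noise_cfg rescales the guided prediction toward the std of the conditional
                # prediction, then blends the two by guidance_rescale, countering the over-saturation
                # that high CFG scales tend to cause.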

        else:
            with torch.no_grad():
                if do_classifier_free_guidance:
                    # if we are doing classifier free guidance, need to double up
                    latent_model_input = torch.cat([latents] * 2, dim=0)
                    timestep = torch.cat([timestep] * 2)
                else:
                    latent_model_input = latents

                latent_model_input = scale_model_input(latent_model_input, timestep)

                # check if we need to concat timesteps
                if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1:
                    ts_bs = timestep.shape[0]
                    if ts_bs != latent_model_input.shape[0]:
                        if ts_bs == 1:
                            timestep = torch.cat([timestep] * latent_model_input.shape[0])
                        elif ts_bs * 2 == latent_model_input.shape[0]:
                            timestep = torch.cat([timestep] * 2, dim=0)
                        else:
                            raise ValueError(
                                f"Batch size of timesteps {timestep.shape[0]} must equal or be half the batch size of latents {latent_model_input.shape[0]}")

            # predict the noise residual
            if self.is_pixart:
                VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)
                batch_size, ch, h, w = list(latents.shape)

                height = h * VAE_SCALE_FACTOR
                width = w * VAE_SCALE_FACTOR

                if self.pipeline.transformer.config.sample_size == 256:
                    aspect_ratio_bin = ASPECT_RATIO_2048_BIN
                elif self.pipeline.transformer.config.sample_size == 128:
                    aspect_ratio_bin = ASPECT_RATIO_1024_BIN
                elif self.pipeline.transformer.config.sample_size == 64:
                    aspect_ratio_bin = ASPECT_RATIO_512_BIN
                elif self.pipeline.transformer.config.sample_size == 32:
                    aspect_ratio_bin = ASPECT_RATIO_256_BIN
                else:
                    raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
                orig_height, orig_width = height, width
                height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
                                                                                        ratios=aspect_ratio_bin)
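
                # sample_size is the latent-space resolution; multiplied by the VAE scale factor
                # (8 for these models) it gives the pixel training resolution, so e.g. a
                # sample_size of 128 maps to the 1024 aspect-ratio bins.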

                added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
                if self.unet.config.sample_size == 128 or (
                        self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
                    resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                    aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                    resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
                    aspect_ratio = aspect_ratio.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)

                    if do_classifier_free_guidance:
                        resolution = torch.cat([resolution, resolution], dim=0)
                        aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)

                    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

                noise_pred = self.unet(
                    latent_model_input.to(self.device_torch, self.torch_dtype),
                    encoder_hidden_states=text_embeddings.text_embeds,
                    encoder_attention_mask=text_embeddings.attention_mask,
                    timestep=timestep,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    **kwargs
                )[0]

                # the model predicts a learned sigma in the second half of the channel dim;
                # keep only the noise prediction
                if self.unet.config.out_channels // 2 == self.unet.config.in_channels:
                    noise_pred = noise_pred.chunk(2, dim=1)[0]
            else:
                if self.unet.device != self.device_torch:
                    self.unet.to(self.device_torch)
                if self.unet.dtype != self.torch_dtype:
                    self.unet = self.unet.to(dtype=self.torch_dtype)
                if self.is_flux:
                    with torch.no_grad():

                        bs, c, h, w = latent_model_input.shape
                        # pack 2x2 latent patches into the sequence dimension
                        latent_model_input_packed = rearrange(
                            latent_model_input,
                            "b c (h ph) (w pw) -> b (h w) (c ph pw)",
                            ph=2,
                            pw=2
                        )
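
                        # e.g. for a 1024x1024 image the latents are (bs, 16, 128, 128), which
                        # packs to (bs, 64 * 64, 16 * 2 * 2) == (bs, 4096, 64)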

                        img_ids = torch.zeros(h // 2, w // 2, 3)
                        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
                        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
                        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch)

                        txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

                        # handle guidance for the guidance-distilled checkpoints
                        guidance_scale = 1.0  # held constant here; the CFG scale is not fed to the embedding
                        if self.unet.config.guidance_embeds:
                            guidance = torch.tensor([guidance_scale], device=self.device_torch)
                            guidance = guidance.expand(latents.shape[0])
                        else:
                            guidance = None
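
                        # img_ids carry per-patch (row, col) positions in their last two channels and
                        # txt_ids are all zeros; the Flux transformer uses these ids for its rotary
                        # positional embeddings.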

                    cast_dtype = self.unet.dtype
                    # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
                    noise_pred = self.unet(
                        hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype),  # [1, 4096, 64]
                        # YiYi notes: divide by 1000 for now because we scale it by 1000 in the
                        # transformer model (we should not keep it, but I want to keep the inputs
                        # the same for the model for testing)
                        # todo make sure this doesn't change
                        timestep=timestep / 1000,  # timestep is on a 1000 scale
                        encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
                        pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                        txt_ids=txt_ids,  # [1, 512, 3]
                        img_ids=img_ids,  # [1, 4096, 3]
                        guidance=guidance,
                        return_dict=False,
                        **kwargs,
                    )[0]

                    if isinstance(noise_pred, QTensor):
                        noise_pred = noise_pred.dequantize()

                    # unpack the 2x2 patches back into latent space
                    noise_pred = rearrange(
                        noise_pred,
                        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                        h=latent_model_input.shape[2] // 2,
                        w=latent_model_input.shape[3] // 2,
                        ph=2,
                        pw=2,
                        c=latent_model_input.shape[1],
                    )
                elif self.is_v3:
                    noise_pred = self.unet(
                        hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
                        timestep=timestep,
                        encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                        pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
                        **kwargs,
                    ).sample
                elif self.is_auraflow:
                    # aura uses a timestep value between 0 and 1, with t=1 as noise and t=0 as the image
                    # broadcast to the batch dimension in a way that's compatible with ONNX/Core ML
                    t = torch.tensor([timestep / 1000]).expand(latent_model_input.shape[0])
                    t = t.to(self.device_torch, self.torch_dtype)

                    noise_pred = self.unet(
                        latent_model_input,
                        encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                        timestep=t,
                        return_dict=False,
                    )[0]
                else:
                    noise_pred = self.unet(
                        latent_model_input.to(self.device_torch, self.torch_dtype),
                        timestep=timestep,
                        encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
                        **kwargs,
                    ).sample

            conditional_pred = noise_pred

            if do_classifier_free_guidance:
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
                conditional_pred = noise_pred_text
                if detach_unconditional:
                    noise_pred_uncond = noise_pred_uncond.detach()
                noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                )
                if rescale_cfg is not None and rescale_cfg != guidance_scale:
                    with torch.no_grad():
                        # do cfg at the target scale so we can match it
                        target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
                                noise_pred_text - noise_pred_uncond
                        )
                        target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
                        target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()

                        pred_mean = noise_pred.mean([1, 2, 3], keepdim=True).detach()
                        pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach()

                    # match the mean and std
                    noise_pred = (noise_pred - pred_mean) / pred_std
                    noise_pred = (noise_pred * target_std) + target_mean
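
                    # rescale_cfg: compute the mean/std the prediction would have at the target
                    # scale, then renormalize the high-scale prediction to match; this keeps the
                    # gradient path while taming the magnitude shift from a large guidance_scale.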

                # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775
                if guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

        if return_conditional_pred:
            return noise_pred, conditional_pred
        return noise_pred

    def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
        if noise_scheduler is None:
            noise_scheduler = self.noise_scheduler
        # sometimes these tensors end up on the wrong device, no idea why
        if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler):
            try:
                noise_scheduler.betas = noise_scheduler.betas.to(self.device_torch)
                noise_scheduler.alphas = noise_scheduler.alphas.to(self.device_torch)
                noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to(self.device_torch)
            except Exception:
                pass

        mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
        latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0)
        timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
        out_chunks = []
        if len(timestep_chunks) == 1 and len(mi_chunks) > 1:
            # expand timestep to match
            timestep_chunks = timestep_chunks * len(mi_chunks)

        for idx in range(model_input.shape[0]):
            # reset the step index so the scheduler recalculates it for each chunk
            if hasattr(noise_scheduler, '_step_index'):
                noise_scheduler._step_index = None
            if hasattr(noise_scheduler, 'is_scale_input_called'):
                noise_scheduler.is_scale_input_called = True
            out_chunks.append(
                noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[0]
            )
        return torch.cat(out_chunks, dim=0)
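
    # step_scheduler above steps each sample through the scheduler separately because samples
    # in a training batch can sit at different timesteps; stateful schedulers (e.g. Euler)
    # need their _step_index reset per sample for the lookups to stay independent.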

    # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746
    def diffuse_some_steps(
            self,
            latents: torch.FloatTensor,
            text_embeddings: PromptEmbeds,
            total_timesteps: int = 1000,
            start_timesteps=0,
            guidance_scale=1,
            add_time_ids=None,
            bleed_ratio: float = 0.5,
            bleed_latents: torch.FloatTensor = None,
            is_input_scaled=False,
            return_first_prediction=False,
            **kwargs,
    ):
        timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]

        first_prediction = None

        for timestep in tqdm(timesteps_to_run, leave=False):
            timestep = timestep.unsqueeze_(0)
            noise_pred, conditional_pred = self.predict_noise(
                latents,
                text_embeddings,
                timestep,
                guidance_scale=guidance_scale,
                add_time_ids=add_time_ids,
                is_input_scaled=is_input_scaled,
                return_conditional_pred=True,
                **kwargs,
            )

            if return_first_prediction and first_prediction is None:
                first_prediction = conditional_pred

            # some schedulers need to run per-sample, so step_scheduler handles that (euler for example)
            latents = self.step_scheduler(noise_pred, latents, timestep)

            # if not the last step, and bleeding, bleed in some latents
            if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]:
                latents = (latents * (1 - bleed_ratio)) + (bleed_latents * bleed_ratio)

            # only skip the first scaling
            is_input_scaled = False

        if return_first_prediction:
            return latents, first_prediction
        return latents

    def encode_prompt(
            self,
            prompt,
            prompt2=None,
            num_images_per_prompt=1,
            force_all=False,
            long_prompts=False,
            max_length=None,
            dropout_prob=0.0,
    ) -> PromptEmbeds:
        # sd1.5 embeddings are (bs, 77, 768)
        # if it is not a list, make it one
        if not isinstance(prompt, list):
            prompt = [prompt]

        if prompt2 is not None and not isinstance(prompt2, list):
            prompt2 = [prompt2]
        if self.is_xl:
            # todo make this a config
            # 50% chance to use an encoder anyway even if it is disabled
            # allows the other TE to compensate for the disabled one
            # use_encoder_1 = self.use_text_encoder_1 or force_all or random.random() > 0.5
            # use_encoder_2 = self.use_text_encoder_2 or force_all or random.random() > 0.5
            use_encoder_1 = True
            use_encoder_2 = True

            return PromptEmbeds(
                train_tools.encode_prompts_xl(
                    self.tokenizer,
                    self.text_encoder,
                    prompt,
                    prompt2,
                    num_images_per_prompt=num_images_per_prompt,
                    use_text_encoder_1=use_encoder_1,
                    use_text_encoder_2=use_encoder_2,
                    truncate=not long_prompts,
                    max_length=max_length,
                    dropout_prob=dropout_prob,
                )
            )
        if self.is_v3:
            return PromptEmbeds(
                train_tools.encode_prompts_sd3(
                    self.tokenizer,
                    self.text_encoder,
                    prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    truncate=not long_prompts,
                    max_length=max_length,
                    dropout_prob=dropout_prob,
                    pipeline=self.pipeline,
                )
            )
        elif self.is_pixart:
            embeds, attention_mask = train_tools.encode_prompts_pixart(
                self.tokenizer,
                self.text_encoder,
                prompt,
                truncate=not long_prompts,
                max_length=300 if self.model_config.is_pixart_sigma else 120,
                dropout_prob=dropout_prob
            )
            return PromptEmbeds(
                embeds,
                attention_mask=attention_mask,
            )
        elif self.is_auraflow:
            embeds, attention_mask = train_tools.encode_prompts_auraflow(
                self.tokenizer,
                self.text_encoder,
                prompt,
                truncate=not long_prompts,
                max_length=256,
                dropout_prob=dropout_prob
            )
            return PromptEmbeds(
                embeds,
                attention_mask=attention_mask,  # not used
            )
        elif self.is_flux:
            prompt_embeds, pooled_prompt_embeds = train_tools.encode_prompts_flux(
                self.tokenizer,  # list
                self.text_encoder,  # list
                prompt,
                truncate=not long_prompts,
                max_length=512,
                dropout_prob=dropout_prob
            )
            pe = PromptEmbeds(
                prompt_embeds
            )
            pe.pooled_embeds = pooled_prompt_embeds
            return pe

        elif isinstance(self.text_encoder, T5EncoderModel):
            embeds, attention_mask = train_tools.encode_prompts_pixart(
                self.tokenizer,
                self.text_encoder,
                prompt,
                truncate=not long_prompts,
                max_length=77,  # todo set this higher when not transfer learning
                dropout_prob=dropout_prob
            )
            return PromptEmbeds(
                embeds,
                # do we want an attn mask here?
                # attention_mask=attention_mask,
            )
        else:
            return PromptEmbeds(
                train_tools.encode_prompts(
                    self.tokenizer,
                    self.text_encoder,
                    prompt,
                    truncate=not long_prompts,
                    max_length=max_length,
                    dropout_prob=dropout_prob
                )
            )

    @torch.no_grad()
    def encode_images(
            self,
            image_list: List[torch.Tensor],
            device=None,
            dtype=None
    ):
        if device is None:
            device = self.vae_device_torch
        if dtype is None:
            dtype = self.vae_torch_dtype

        # move the vae to the device if it is on cpu
        if self.vae.device == 'cpu':
            self.vae.to(device)
        self.vae.eval()
        self.vae.requires_grad_(False)
        # move to device and dtype
        image_list = [image.to(device, dtype=dtype) for image in image_list]

        VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)

        # resize images if not divisible by the VAE scale factor
        for i in range(len(image_list)):
            image = image_list[i]
            if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
                image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
                                        image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)

        images = torch.stack(image_list)
        if isinstance(self.vae, AutoencoderTiny):
            latents = self.vae.encode(images, return_dict=False)[0]
        else:
            latents = self.vae.encode(images).latent_dist.sample()
        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0

        # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
        # z = self.scale_factor * (z - self.shift_factor)
        latents = self.vae.config['scaling_factor'] * (latents - shift)
        latents = latents.to(device, dtype=dtype)

        return latents
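
    # decode_latents below applies the inverse transform: x = z / scaling_factor + shift_factor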

    def decode_latents(
            self,
            latents: torch.Tensor,
            device=None,
            dtype=None
    ):
        if device is None:
            device = self.device
        if dtype is None:
            dtype = self.torch_dtype

        # move the vae to the device if it is on cpu
        if self.vae.device == 'cpu':
            self.vae.to(self.device)
        latents = latents.to(device, dtype=dtype)
        # invert z = scaling_factor * (x - shift_factor); treat a missing shift as 0
        shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
        latents = (latents / self.vae.config['scaling_factor']) + shift
        images = self.vae.decode(latents).sample
        images = images.to(device, dtype=dtype)

        return images

    def encode_image_prompt_pairs(
            self,
            prompt_list: List[str],
            image_list: List[torch.Tensor],
            device=None,
            dtype=None
    ):
        # todo check image types and expand and rescale as needed
        # device and dtype are for the outputs
        if device is None:
            device = self.device
        if dtype is None:
            dtype = self.torch_dtype

        embedding_list = []
        latent_list = []
        # embed the prompts
        for prompt in prompt_list:
            embedding = self.encode_prompt(prompt).to(self.device_torch, dtype=dtype)
            embedding_list.append(embedding)

        return embedding_list, latent_list

    def get_weight_by_name(self, name):
        # weights begin with te{te_num}_ for text encoders
        # weights begin with unet_ for the unet
        if name.startswith('te'):
            key = name[4:]
            # text encoder
            te_num = int(name[2])
            if isinstance(self.text_encoder, list):
                return self.text_encoder[te_num].state_dict()[key]
            else:
                return self.text_encoder.state_dict()[key]
        elif name.startswith('unet'):
            key = name[5:]
            # unet
            return self.unet.state_dict()[key]

        raise ValueError(f"Unknown weight name: {name}")

    def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False):
        return inject_trigger_into_prompt(
            prompt,
            trigger=trigger,
            to_replace_list=to_replace_list,
            add_if_not_present=add_if_not_present,
        )

    def state_dict(self, vae=True, text_encoder=True, unet=True):
        state_dict = OrderedDict()
        if vae:
            for k, v in self.vae.state_dict().items():
                new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
                state_dict[new_key] = v
        if text_encoder:
            if isinstance(self.text_encoder, list):
                for i, encoder in enumerate(self.text_encoder):
                    for k, v in encoder.state_dict().items():
                        new_key = k if k.startswith(
                            f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}"
                        state_dict[new_key] = v
            else:
                for k, v in self.text_encoder.state_dict().items():
                    new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}"
                    state_dict[new_key] = v
        if unet:
            for k, v in self.unet.state_dict().items():
                new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
                state_dict[new_key] = v
        return state_dict
    def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \
            OrderedDict[str, Parameter]:
        named_params: OrderedDict[str, Parameter] = OrderedDict()
        if vae:
            for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"):
                named_params[name] = param
        if text_encoder:
            if isinstance(self.text_encoder, list):
                for i, encoder in enumerate(self.text_encoder):
                    if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0:
                        # don't add these params
                        continue
                    if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1:
                        # don't add these params
                        continue

                    for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"):
                        named_params[name] = param
            else:
                for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"):
                    named_params[name] = param
        if unet:
            if self.is_flux:
                # (disabled experiment) just train the middle blocks of each transformer stack
                # block_list = []
                # num_transformer_blocks = 2
                # start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
                # for i in range(num_transformer_blocks):
                #     block_list.append(self.unet.transformer_blocks[start_block + i])
                #
                # num_single_transformer_blocks = 4
                # start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
                # for i in range(num_single_transformer_blocks):
                #     block_list.append(self.unet.single_transformer_blocks[start_block + i])
                #
                # for block in block_list:
                #     for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                #         named_params[name] = param

                # train the guidance embedding
                if self.unet.config.guidance_embeds:
                    transformer: FluxTransformer2DModel = self.unet
                    for name, param in transformer.time_text_embed.named_parameters(recurse=True,
                                                                                    prefix=f"{SD_PREFIX_UNET}"):
                        named_params[name] = param

                for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
                                                                                 prefix=f"{SD_PREFIX_UNET}"):
                    named_params[name] = param
                for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
                                                                                        prefix=f"{SD_PREFIX_UNET}"):
                    named_params[name] = param
            else:
                for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
                    named_params[name] = param

        if refiner:
            for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"):
                named_params[name] = param

        # convert to state dict keys, just replace the first . with a _ on keys
        if state_dict_keys:
            new_named_params = OrderedDict()
            for k, v in named_params.items():
                # replace only the first . with an _
                new_key = k.replace('.', '_', 1)
                new_named_params[new_key] = v
            named_params = new_named_params

        return named_params
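
    # With state_dict_keys=True the first '.' after the prefix becomes '_', so a parameter
    # named 'unet.conv_in.weight' is returned as 'unet_conv_in.weight' to match saved keys.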

    def save_refiner(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16')):

        # load the full refiner since we only train the unet
        if self.model_config.refiner_name_or_path is None:
            raise ValueError("Refiner must be specified to save it")
        refiner_config_path = os.path.join(ORIG_CONFIGS_ROOT, 'sd_xl_refiner.yaml')
        # load the refiner model
        dtype = get_torch_dtype(self.dtype)
        model_path = self.model_config._original_refiner_name_or_path
        if not os.path.exists(model_path) or os.path.isdir(model_path):
            # TODO only load unet??
            refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
                model_path,
                dtype=dtype,
                device='cpu',
                # variant="fp16",
                use_safetensors=True,
            )
        else:
            refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
                model_path,
                dtype=dtype,
                device='cpu',
                torch_dtype=self.torch_dtype,
                original_config_file=refiner_config_path,
            )
        # replace the original unet with our trained one
        refiner.unet = self.refiner_unet
        flush()

        diffusers_state_dict = OrderedDict()
        for k, v in refiner.vae.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}"
            diffusers_state_dict[new_key] = v
        for k, v in refiner.text_encoder_2.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_TEXT_ENCODER2}_") else f"{SD_PREFIX_TEXT_ENCODER2}_{k}"
            diffusers_state_dict[new_key] = v
        for k, v in refiner.unet.state_dict().items():
            new_key = k if k.startswith(f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}"
            diffusers_state_dict[new_key] = v

        converted_state_dict = get_ldm_state_dict_from_diffusers(
            diffusers_state_dict,
            'sdxl_refiner',
            device='cpu',
            dtype=save_dtype
        )

        # make sure the parent folder exists
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        save_file(converted_state_dict, output_file, metadata=meta)

        if self.config_file is not None:
            output_path_no_ext = os.path.splitext(output_file)[0]
            output_config_path = f"{output_path_no_ext}.yaml"
            shutil.copyfile(self.config_file, output_config_path)

    def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None):
        version_string = '1'
        if self.is_v2:
            version_string = '2'
        if self.is_xl:
            version_string = 'sdxl'
        if self.is_ssd:
            # overwrite sdxl because both will be true here
            version_string = 'ssd'
        if self.is_ssd and self.is_vega:
            version_string = 'vega'
        # if the output file does not end in .safetensors, it is a directory and we are
        # saving in diffusers format
        if not output_file.endswith('.safetensors'):
            # diffusers
            # if self.is_pixart:
            #     self.unet.save_pretrained(
            #         save_directory=output_file,
            #         safe_serialization=True,
            #     )
            # else:
            if self.is_flux:
                # only save the unet
                transformer: FluxTransformer2DModel = self.unet
                transformer.save_pretrained(
                    save_directory=os.path.join(output_file, 'transformer'),
                    safe_serialization=True,
                )
            else:
                self.pipeline.save_pretrained(
                    save_directory=output_file,
                    safe_serialization=True,
                )
            # save out the meta config
            meta_path = os.path.join(output_file, 'aitk_meta.yaml')
            with open(meta_path, 'w') as f:
                yaml.dump(meta, f)

        else:
            save_ldm_model_from_diffusers(
                sd=self,
                output_file=output_file,
                meta=meta,
                save_dtype=save_dtype,
                sd_version=version_string,
            )
            if self.config_file is not None:
                output_path_no_ext = os.path.splitext(output_file)[0]
                output_config_path = f"{output_path_no_ext}.yaml"
                shutil.copyfile(self.config_file, output_config_path)
    def prepare_optimizer_params(
            self,
            unet=False,
            text_encoder=False,
            text_encoder_lr=None,
            unet_lr=None,
            refiner_lr=None,
            refiner=False,
            default_lr=1e-6,
    ):
        # todo maybe only get locon ones?
        # not all items are saved; to make it match, we need to match our save mappings
        # and not train anything that is not mapped. Also attach the learning rate
        version = 'sd1'
        if self.is_xl:
            version = 'sdxl'
        if self.is_v2:
            version = 'sd2'
        mapping_filename = f"stable_diffusion_{version}.json"
        mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename)
        with open(mapping_path, 'r') as f:
            mapping = json.load(f)
        ldm_diffusers_keymap = mapping['ldm_diffusers_keymap']

        trainable_parameters = []

        # we use the state dict to find params

        if unet:
            named_params = self.named_parameters(vae=False, unet=unet, text_encoder=False, state_dict_keys=True)
            unet_lr = unet_lr if unet_lr is not None else default_lr
            params = []
            if self.is_pixart or self.is_auraflow or self.is_flux:
                for param in named_params.values():
                    if param.requires_grad:
                        params.append(param)
            else:
                for key, diffusers_key in ldm_diffusers_keymap.items():
                    if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
                        if named_params[diffusers_key].requires_grad:
                            params.append(named_params[diffusers_key])
            param_data = {"params": params, "lr": unet_lr}
            trainable_parameters.append(param_data)
            print(f"Found {len(params)} trainable parameters in the unet")

        if text_encoder:
            named_params = self.named_parameters(vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True)
            text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr
            params = []
            for key, diffusers_key in ldm_diffusers_keymap.items():
                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
                    if named_params[diffusers_key].requires_grad:
                        params.append(named_params[diffusers_key])
            param_data = {"params": params, "lr": text_encoder_lr}
            trainable_parameters.append(param_data)

            print(f"Found {len(params)} trainable parameters in the text encoder")

        if refiner:
            named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True,
                                                 state_dict_keys=True)
            refiner_lr = refiner_lr if refiner_lr is not None else default_lr
            params = []
            for key, diffusers_key in ldm_diffusers_keymap.items():
                diffusers_key = f"refiner_{diffusers_key}"
                if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS:
                    if named_params[diffusers_key].requires_grad:
                        params.append(named_params[diffusers_key])
            param_data = {"params": params, "lr": refiner_lr}
            trainable_parameters.append(param_data)

            print(f"Found {len(params)} trainable parameters in the refiner")

        return trainable_parameters
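
    # the returned list is standard torch param groups, so it can be passed straight to an
    # optimizer, e.g. torch.optim.AdamW(sd.prepare_optimizer_params(unet=True, unet_lr=1e-5))
    # for an instance sd (illustrative usage)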

    def save_device_state(self):
        # saves the current device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux:
            unet_has_grad = self.unet.proj_out.weight.requires_grad
        else:
            unet_has_grad = self.unet.conv_in.weight.requires_grad

        self.device_state = {
            **empty_preset,
            'vae': {
                'training': self.vae.training,
                'device': self.vae.device,
            },
            'unet': {
                'training': self.unet.training,
                'device': self.unet.device,
                'requires_grad': unet_has_grad,
            },
        }
        if isinstance(self.text_encoder, list):
            self.device_state['text_encoder'] = []
            for encoder in self.text_encoder:
                try:
                    te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
                except AttributeError:
                    # T5-style encoders have a different module layout
                    te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
                self.device_state['text_encoder'].append({
                    'training': encoder.training,
                    'device': encoder.device,
                    # todo there has to be a better way to do this
                    'requires_grad': te_has_grad
                })
        else:
            if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel):
                te_has_grad = self.text_encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
            else:
                te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad

            self.device_state['text_encoder'] = {
                'training': self.text_encoder.training,
                'device': self.text_encoder.device,
                'requires_grad': te_has_grad
            }
        if self.adapter is not None:
            if isinstance(self.adapter, IPAdapter):
                requires_grad = self.adapter.image_proj_model.training
                adapter_device = self.unet.device
            elif isinstance(self.adapter, T2IAdapter):
                requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
                adapter_device = self.adapter.device
            elif isinstance(self.adapter, ControlNetModel):
                requires_grad = self.adapter.conv_in.training
                adapter_device = self.adapter.device
            elif isinstance(self.adapter, ClipVisionAdapter):
                requires_grad = self.adapter.embedder.training
                adapter_device = self.adapter.device
            elif isinstance(self.adapter, CustomAdapter):
                requires_grad = self.adapter.training
                adapter_device = self.adapter.device
            elif isinstance(self.adapter, ReferenceAdapter):
                # todo update this!!
                requires_grad = True
                adapter_device = self.adapter.device
            else:
                raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
            self.device_state['adapter'] = {
                'training': self.adapter.training,
                'device': adapter_device,
                'requires_grad': requires_grad,
            }

        if self.refiner_unet is not None:
            self.device_state['refiner_unet'] = {
                'training': self.refiner_unet.training,
                'device': self.refiner_unet.device,
                'requires_grad': self.refiner_unet.conv_in.weight.requires_grad,
            }

    def restore_device_state(self):
        # restores the device state for all modules
        # this is useful for when we want to alter the state and restore it
        if self.device_state is None:
            return
        self.set_device_state(self.device_state)
        self.device_state = None

    def set_device_state(self, state):
        if state['vae']['training']:
            self.vae.train()
        else:
            self.vae.eval()
        self.vae.to(state['vae']['device'])
        if state['unet']['training']:
            self.unet.train()
        else:
            self.unet.eval()
        self.unet.to(state['unet']['device'])
        if state['unet']['requires_grad']:
            self.unet.requires_grad_(True)
        else:
            self.unet.requires_grad_(False)
        if isinstance(self.text_encoder, list):
            for i, encoder in enumerate(self.text_encoder):
                if isinstance(state['text_encoder'], list):
                    if state['text_encoder'][i]['training']:
                        encoder.train()
                    else:
                        encoder.eval()
                    encoder.to(state['text_encoder'][i]['device'])
                    encoder.requires_grad_(state['text_encoder'][i]['requires_grad'])
                else:
                    if state['text_encoder']['training']:
                        encoder.train()
                    else:
                        encoder.eval()
                    encoder.to(state['text_encoder']['device'])
                    encoder.requires_grad_(state['text_encoder']['requires_grad'])
        else:
            if state['text_encoder']['training']:
                self.text_encoder.train()
            else:
                self.text_encoder.eval()
            self.text_encoder.to(state['text_encoder']['device'])
            self.text_encoder.requires_grad_(state['text_encoder']['requires_grad'])

        if self.adapter is not None:
            self.adapter.to(state['adapter']['device'])
            self.adapter.requires_grad_(state['adapter']['requires_grad'])
            if state['adapter']['training']:
                self.adapter.train()
            else:
                self.adapter.eval()

        if self.refiner_unet is not None:
            self.refiner_unet.to(state['refiner_unet']['device'])
            self.refiner_unet.requires_grad_(state['refiner_unet']['requires_grad'])
            if state['refiner_unet']['training']:
                self.refiner_unet.train()
            else:
                self.refiner_unet.eval()
        flush()

    def set_device_state_preset(self, device_state_preset: DeviceStatePreset):
        # sets a preset for the device state

        # save the current state first
        self.save_device_state()

        active_modules = []
        training_modules = []
        if device_state_preset in ['cache_latents']:
            active_modules = ['vae']
        if device_state_preset in ['cache_clip']:
            active_modules = ['clip']
        if device_state_preset in ['generate']:
            active_modules = ['vae', 'unet', 'text_encoder', 'adapter', 'refiner_unet']
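
        # anything not in active_modules is sent to cpu below, so a preset effectively
        # offloads every module the current phase does not need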

        state = copy.deepcopy(empty_preset)
        # vae
        state['vae'] = {
            'training': 'vae' in training_modules,
            'device': self.vae_device_torch if 'vae' in active_modules else 'cpu',
            'requires_grad': 'vae' in training_modules,
        }

        # unet
        state['unet'] = {
            'training': 'unet' in training_modules,
            'device': self.device_torch if 'unet' in active_modules else 'cpu',
            'requires_grad': 'unet' in training_modules,
        }

        if self.refiner_unet is not None:
            state['refiner_unet'] = {
                'training': 'refiner_unet' in training_modules,
                'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu',
                'requires_grad': 'refiner_unet' in training_modules,
            }

        # text encoder
        if isinstance(self.text_encoder, list):
            state['text_encoder'] = []
            for i, encoder in enumerate(self.text_encoder):
                state['text_encoder'].append({
                    'training': 'text_encoder' in training_modules,
                    'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                    'requires_grad': 'text_encoder' in training_modules,
                })
        else:
            state['text_encoder'] = {
                'training': 'text_encoder' in training_modules,
                'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu',
                'requires_grad': 'text_encoder' in training_modules,
            }

        if self.adapter is not None:
            state['adapter'] = {
                'training': 'adapter' in training_modules,
                'device': self.device_torch if 'adapter' in active_modules else 'cpu',
                'requires_grad': 'adapter' in training_modules,
            }

        self.set_device_state(state)