diff --git a/.gitmodules b/.gitmodules index 657cf28b..a98073dc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,16 @@ [submodule "repositories/sd-scripts"] path = repositories/sd-scripts url = https://github.com/kohya-ss/sd-scripts.git + commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c [submodule "repositories/leco"] path = repositories/leco url = https://github.com/p1atdev/LECO + commit = 9294adf40218e917df4516737afb13f069a6789d [submodule "repositories/batch_annotator"] path = repositories/batch_annotator url = https://github.com/ostris/batch-annotator + commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf [submodule "repositories/ipadapter"] path = repositories/ipadapter url = https://github.com/tencent-ailab/IP-Adapter.git + commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14 diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 827e05c4..3fe2f2f5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess): elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) - + + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + # forward ODE target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() else: target = noise diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 2482c26d..77122af9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -68,6 +68,8 @@ import transformers import diffusers import hashlib +from toolkit.util.get_model import get_model_class + def flush(): torch.cuda.empty_cache() gc.collect() @@ -666,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # # prepare all the models stuff for accelerator (hopefully we dont miss any) self.sd.vae = self.accelerator.prepare(self.sd.vae) if self.sd.unet is not None: - self.sd.unet_unwrapped = self.sd.unet self.sd.unet = self.accelerator.prepare(self.sd.unet) # todo always tdo it? 
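+ # note: unet_unwrapped is now exposed as a property on the model class (see toolkit/models/base_model.py below),
+ # so the manual assignment removed above is no longer needed here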
self.modules_being_trained.append(self.sd.unet) @@ -1103,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess): if timestep_type is None: timestep_type = self.train_config.timestep_type + patch_size = 1 + if self.sd.is_flux: + # flux is a patch size of 1, but latents are divided by 2, so we need to double it + patch_size = 2 + elif hasattr(self.sd.unet.config, 'patch_size'): + patch_size = self.sd.unet.config.patch_size + self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, timestep_type=timestep_type, - latents=latents + latents=latents, + patch_size=patch_size, ) else: self.sd.noise_scheduler.set_timesteps( @@ -1401,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.name_or_path = latest_save_path self.load_training_state_from_metadata(latest_save_path) - # get the noise scheduler - arch = 'sd' - if self.model_config.is_pixart: - arch = 'pixart' - if self.model_config.is_flux: - arch = 'flux' - if self.model_config.is_lumina2: - arch = 'lumina2' - sampler = get_sampler( - self.train_config.noise_scheduler, - { - "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", - }, - arch=arch, - ) + ModelClass = get_model_class(self.model_config) + # if the model class has get_train_scheduler static method + if hasattr(ModelClass, 'get_train_scheduler'): + sampler = ModelClass.get_train_scheduler() + else: + # get the noise scheduler + arch = 'sd' + if self.model_config.is_pixart: + arch = 'pixart' + if self.model_config.is_flux: + arch = 'flux' + if self.model_config.is_lumina2: + arch = 'lumina2' + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + arch=arch, + ) if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') @@ -1423,7 +1437,7 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.refiner_name_or_path = previous_refiner_save self.load_training_state_from_metadata(previous_refiner_save) - self.sd = StableDiffusion( + self.sd = ModelClass( device=self.device, model_config=model_config_to_load, dtype=self.train_config.dtype, @@ -1559,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess): # if is_lycoris: # preset = PRESET['full'] # NetworkClass.apply_preset(preset) + + if hasattr(self.sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = self.sd.target_lora_modules self.network = NetworkClass( text_encoder=text_encoder, @@ -1587,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess): network_config=self.network_config, network_type=self.network_config.type, transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, **network_kwargs ) diff --git a/requirements.txt b/requirements.txt index 4040e760..cef9b658 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch==2.5.1 torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec -transformers +git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7 +transformers==4.49.0 lycoris-lora==1.8.3 flatten_json pyyaml diff --git a/testing/test_vae.py b/testing/test_vae.py index 44b31f63..463ab555 100644 --- a/testing/test_vae.py +++ b/testing/test_vae.py @@ -29,7 +29,7 @@ def paramiter_count(model): return int(paramiter_count) -def 
calculate_metrics(vae, images, max_imgs=-1): +def calculate_metrics(vae, images, max_imgs=-1, save_output=False): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") vae = vae.to(device) lpips_model = lpips.LPIPS(net='alex').to(device) @@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1): # ]) # needs values between -1 and 1 to_tensor = ToTensor() + + # remove _reconstructed.png files + images = [img for img in images if not img.endswith("_reconstructed.png")] if max_imgs > 0 and len(images) > max_imgs: images = images[:max_imgs] @@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1): avg_rfid = 0 avg_psnr = sum(psnr_scores) / len(psnr_scores) avg_lpips = sum(lpips_scores) / len(lpips_scores) + + if save_output: + filename_no_ext = os.path.splitext(os.path.basename(img_path))[0] + folder = os.path.dirname(img_path) + save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png") + reconstructed = (reconstructed + 1) / 2 + reconstructed = reconstructed.clamp(0, 1) + reconstructed = transforms.ToPILImage()(reconstructed[0].cpu()) + reconstructed.save(save_path) return avg_rfid, avg_psnr, avg_lpips @@ -91,18 +103,23 @@ def main(): parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + # boolean store true + parser.add_argument("--save_output", action="store_true", help="Save the output images") args = parser.parse_args() if os.path.isfile(args.vae_path): vae = AutoencoderKL.from_single_file(args.vae_path) else: - vae = AutoencoderKL.from_pretrained(args.vae_path) + try: + vae = AutoencoderKL.from_pretrained(args.vae_path) + except: + vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.eval() vae = vae.to(device) print(f"Model has {paramiter_count(vae)} parameters") images = load_images(args.image_folder) - avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output) # print(f"Average rFID: {avg_rfid}") print(f"Average PSNR: {avg_psnr}") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 07b3e2b7..e92f7cbe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -432,6 +432,9 @@ class TrainConfig: self.force_consistent_noise = kwargs.get('force_consistent_noise', False) +ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] + + class ModelConfig: def __init__(self, **kwargs): self.name_or_path: str = kwargs.get('name_or_path', None) @@ -509,6 +512,36 @@ class ModelConfig: self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3) self.te_name_or_path = kwargs.get("te_name_or_path", None) + + self.arch: ModelArch = kwargs.get("arch", None) + + # handle migrating to new model arch + if self.arch is None: + if kwargs.get('is_v2', False): + self.arch = 'sd2' + elif kwargs.get('is_v3', False): + self.arch = 'sd3' + elif kwargs.get('is_xl', False): + self.arch = 'sdxl' + elif kwargs.get('is_pixart', False): + self.arch = 'pixart' + elif kwargs.get('is_pixart_sigma', False): + self.arch = 'pixart_sigma' + elif kwargs.get('is_auraflow', False): + self.arch = 'auraflow' + elif 
kwargs.get('is_flux', False): + self.arch = 'flux' + elif kwargs.get('is_flex2', False): + self.arch = 'flex2' + elif kwargs.get('is_lumina2', False): + self.arch = 'lumina2' + elif kwargs.get('is_vega', False): + self.arch = 'vega' + elif kwargs.get('is_ssd', False): + self.arch = 'ssd' + else: + self.arch = 'sd1' + class EMAConfig: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 1b308cd4..84ac02b1 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): transformer_only: bool = False, peft_format: bool = False, is_assistant_adapter: bool = False, + is_transformer: bool = False, **kwargs ) -> None: """ @@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.network_config: NetworkConfig = kwargs.get("network_config", None) self.peft_format = peft_format + self.is_transformer = is_transformer + # always do peft for flux only for now - if self.is_flux or self.is_v3 or self.is_lumina2: + if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer: # don't do peft format for lokr if self.network_type.lower() != "lokr": self.peft_format = True @@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2: + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: unet_prefix = f"lora_transformer" if self.peft_format: unet_prefix = "transformer" @@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if self.transformer_only and self.is_v3 and is_unet: if "transformer_blocks" not in lora_name: skip = True + + # handle custom models + if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True if (is_linear or is_conv2d) and not skip: diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py new file mode 100644 index 00000000..cae29ffd --- /dev/null +++ b/toolkit/models/base_model.py @@ -0,0 +1,1426 @@ +import copy +import gc +import json +import random +import shutil +import typing +from typing import Union, List, Literal +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from torch.nn import Parameter +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.ip_adapter import IPAdapter +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +from toolkit.models.decorator import Decorator +from toolkit.paths import KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.saving import save_ldm_model_from_diffusers +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + 
FluxTransformer2DModel +from toolkit.models.lumina2 import Lumina2Transformer2DModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline +from transformers import T5EncoderModel, UMT5EncoderModel +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from toolkit.accelerator import get_accelerator, unwrap_model +from typing import TYPE_CHECKING +from toolkit.print import print_acc +from transformers import Gemma2Model, Qwen2Model, LlamaModel + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" +SD_PREFIX_REFINER_UNET = "refiner_unet" +SD_PREFIX_TEXT_ENCODER = "te" + +SD_PREFIX_TEXT_ENCODER1 = "te0" +SD_PREFIX_TEXT_ENCODER2 = "te1" + +# prefixed diffusers keys +DO_NOT_TRAIN_WEIGHTS = [ + "unet_time_embedding.linear_1.bias", + "unet_time_embedding.linear_1.weight", + "unet_time_embedding.linear_2.bias", + "unet_time_embedding.linear_2.weight", + "refiner_unet_time_embedding.linear_1.bias", + "refiner_unet_time_embedding.linear_1.weight", + "refiner_unet_time_embedding.linear_2.bias", + "refiner_unet_time_embedding.linear_2.weight", +] + +DeviceStatePreset = Literal['cache_latents', 'generate'] + + +class BlankNetwork: + + def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_merged_in = False + self.can_merge_in = False + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + def train(self): + pass + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + + +class BaseModel: + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + self.accelerator = get_accelerator() + self.custom_pipeline = custom_pipeline + self.device = str(self.accelerator.device) + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = self.accelerator.device + + self.vae_device_torch = self.accelerator.device + self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) + + self.te_device_torch = self.accelerator.device + self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) + + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + self.device_state = None + + self.pipeline: Union[None, 'StableDiffusionPipeline', + 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.model: Union[None, 'Transformer2DModel', 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', + List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', + 'IPAdapter', 'ReferenceAdapter', None] = None + 
self.decorator: Union[Decorator, None] = None + self.arch: ModelArch = model_config.arch + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + + self.quantize_device = self.device_torch + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + self._after_sample_img_hooks = [] + self._status_update_hooks = [] + self.is_transformer = False + + # properties for old arch for backwards compatibility + @property + def unet(self): + return self.model + + # set unet to model + @unet.setter + def unet(self, value): + self.model = value + + @property + def unet_unwrapped(self): + return unwrap_model(self.model) + + @property + def model_unwrapped(self): + return unwrap_model(self.model) + + @property + def is_xl(self): + return self.arch == 'sdxl' + + @property + def is_v2(self): + return self.arch == 'sd2' + + @property + def is_ssd(self): + return self.arch == 'ssd' + + @property + def is_v3(self): + return self.arch == 'sd3' + + @property + def is_vega(self): + return self.arch == 'vega' + + @property + def is_pixart(self): + return self.arch == 'pixart' + + @property + def is_auraflow(self): + return self.arch == 'auraflow' + + @property + def is_flux(self): + return self.arch == 'flux' + + @property + def is_flex2(self): + return self.arch == 'flex2' + + @property + def is_lumina2(self): + return self.arch == 'lumina2' + + # these must be implemented in child classes + def load_model(self): + # override this in child classes + raise NotImplementedError( + "load_model must be implemented in child classes") + + def get_generation_pipeline(self): + # override this in child classes + raise NotImplementedError( + "get_generation_pipeline must be implemented in child classes") + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # override this in child classes + raise NotImplementedError( + "generate_single_image must be implemented in child classes") + + def get_noise_prediction( + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + raise NotImplementedError( + "get_noise_prediction must be implemented in child classes") + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + raise NotImplementedError( + "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") + + def save_model(self, output_path, meta, save_dtype): + # todo handle dtype without overloading anything (vram, cpu, etc) + unwrap_model(self.pipeline).save_pretrained( + save_directory=output_path, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + # end must be implemented in child classes + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + elif self.text_encoder is not None: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in 
self.text_encoder: + te.eval() + elif self.text_encoder is not None: + self.text_encoder.eval() + + def _after_sample_image(self, img_num, total_imgs): + # process all hooks + for hook in self._after_sample_img_hooks: + hook(img_num, total_imgs) + + def add_after_sample_image_hook(self, func): + self._after_sample_img_hooks.append(func) + + def _status_update(self, status: str): + for hook in self._status_update_hooks: + hook(status) + + def print_and_status_update(self, status: str): + print_acc(status) + self._status_update(status) + + def add_status_update_hook(self, func): + self._status_update_hooks.append(func) + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, + StableDiffusionXLPipeline] = None, + ): + network = unwrap_model(self.network) + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print_acc("Unloading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to( + self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print_acc("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if network is not None: + network.eval() + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set( + [x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and network.can_merge_in: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + pipeline = self.get_generation_pipeline() + try: + pipeline.set_progress_bar_config(disable=True) + except: + pass + + start_multiplier = 1.0 + if network is not None: + start_multiplier = network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open( + gen_config.adapter_image_path).convert("RGB") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
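+ # (assumption) the 2x resize likely compensates for the T2IAdapter's extra downscale
+ # factor relative to the VAE, keeping the control features aligned with the latent resolution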
+ validation_image = validation_image.resize( + (gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize( + (gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, gen_config.prompt_2, force_all=True) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = 
self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter( + conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter( + unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values( + extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like( + extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError( + "Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + + img = self.generate_single_image( + pipeline, + gen_config, + conditional_embeds, + unconditional_embeds, + generator, + extra, + ) + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() + + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if network is not None: + network.train() + network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print_acc("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + else: + self.assistant_lora.is_active = True + + if self.model_config.inference_lora_path is not None: + print_acc("Unloading inference lora") + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + flush() + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + VAE_SCALE_FACTOR = 2 ** ( + 
len(self.vae.config['block_out_channels']) - 1) + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + num_channels = self.unet_unwrapped.config['in_channels'] + if self.is_flux: + # has 64 channels in for some reason + num_channels = 16 + noise = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ), + device=self.unet.device, + ) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + **kwargs, + ) -> torch.FloatTensor: + original_samples_chunks = torch.chunk( + original_samples, original_samples.shape[0], dim=0) + noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + timesteps_chunks = [timesteps_chunks[0]] * \ + len(original_samples_chunks) + + noisy_latents_chunks = [] + + for idx in range(original_samples.shape[0]): + noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], + timesteps_chunks[idx]) + noisy_latents_chunks.append(noisy_latents) + + noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + return noisy_latents + + def predict_noise( + self, + latents: torch.Tensor, + text_embeddings: Union[PromptEmbeds, None] = None, + timestep: Union[int, torch.Tensor] = 1, + guidance_scale=7.5, + guidance_rescale=0, + add_time_ids=None, + conditional_embeddings: Union[PromptEmbeds, None] = None, + unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=False, + detach_unconditional=False, + rescale_cfg=None, + return_conditional_pred=False, + guidance_embedding_scale=1.0, + bypass_guidance_embedding=False, + **kwargs, + ): + conditional_pred = None + # get the embeddings + if text_embeddings is None and conditional_embeddings is None: + raise ValueError( + "Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. 
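+ # i.e. if the latent batch already matches the embedding batch we assume a conditional-only
+ # pass and skip CFG; if the embeddings are exactly 2x the latents they are treated as
+ # [negative, positive] pairs and guidance is applied further down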
+ do_classifier_free_guidance = True + + # check if batch size of embeddings matches batch size of latents + if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + raise ValueError( + "Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + + # handle controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat( + [item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input( + mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input( + latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat( + [timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise 
residual + if self.unet.device != self.device_torch: + self.unet.to(self.device_torch) + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + + noise_pred = self.get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + **kwargs + ) + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean( + [1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std( + [1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean( + [1, 2, 3], keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred = (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to( + self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to( + self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to( + self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + 
start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + **kwargs, + ) + # some schedulers need to run separately, so do that. (euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + \ + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 = [prompt2] + + return self.get_prompt_embeds(prompt) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** ( + len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(self.device) + latents = 
latents.to(device, dtype=dtype) + latents = ( + latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to( + self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() + if vae: + for k, v in self.vae.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux or self.is_lumina2 or self.is_transformer: + for name, param in 
self.unet.named_parameters(recurse=True, prefix="transformer"): + named_params[name] = param + else: + for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if self.model_config.ignore_if_contains is not None: + # remove params that contain the ignore_if_contains from named params + for key in list(named_params.keys()): + if any([s in key for s in self.model_config.ignore_if_contains]): + del named_params[key] + if self.model_config.only_if_contains is not None: + # remove params that do not contain the only_if_contains from named params + for key in list(named_params.keys()): + if not any([s in key for s in self.model_config.only_if_contains]): + del named_params[key] + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + self.save_model( + output_path=output_file, + meta=meta, + save_dtype=save_dtype + ) + + def prepare_optimizer_params( + self, + unet=False, + text_encoder=False, + text_encoder_lr=None, + unet_lr=None, + refiner_lr=None, + refiner=False, + default_lr=1e-6, + ): + # todo maybe only get locon ones? + # not all items are saved, to make it match, we need to match out save mappings + # and not train anything not mapped. Also add learning rate + version = 'sd1' + if self.is_xl: + version = 'sdxl' + if self.is_v2: + version = 'sd2' + mapping_filename = f"stable_diffusion_{version}.json" + mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) + with open(mapping_path, 'r') as f: + mapping = json.load(f) + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + + trainable_parameters = [] + + # we use state dict to find params + + if unet: + named_params = self.named_parameters( + vae=False, unet=unet, text_encoder=False, state_dict_keys=True) + unet_lr = unet_lr if unet_lr is not None else default_lr + params = [] + if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2: + for param in named_params.values(): + if param.requires_grad: + params.append(param) + else: + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": unet_lr} + trainable_parameters.append(param_data) + print_acc(f"Found {len(params)} trainable parameter in unet") + + if text_encoder: + named_params = self.named_parameters( + vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) + text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": text_encoder_lr} + trainable_parameters.append(param_data) + + print_acc( + f"Found {len(params)} trainable 
parameter in text encoder") + + if refiner: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, + state_dict_keys=True) + refiner_lr = refiner_lr if refiner_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + diffusers_key = f"refiner_{diffusers_key}" + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": refiner_lr} + trainable_parameters.append(param_data) + + print_acc(f"Found {len(params)} trainable parameter in refiner") + + return trainable_parameters + + def save_device_state(self): + # saves the current device state for all modules + # this is useful for when we want to alter the state and restore it + unet_has_grad = self.get_model_has_grad() + + self.device_state = { + **empty_preset, + 'vae': { + 'training': self.vae.training, + 'device': self.vae.device, + }, + 'unet': { + 'training': self.unet.training, + 'device': self.unet.device, + 'requires_grad': unet_has_grad, + }, + } + if isinstance(self.text_encoder, list): + self.device_state['text_encoder']: List[dict] = [] + for encoder in self.text_encoder: + te_has_grad = self.get_te_has_grad() + self.device_state['text_encoder'].append({ + 'training': encoder.training, + 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': te_has_grad + }) + else: + te_has_grad = self.get_te_has_grad() + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! 
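+ # ReferenceAdapter exposes no single module to introspect for grad state here,
+ # so requires_grad is assumed True for now (see todo above)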
+ requires_grad = True + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': adapter_device, + 'requires_grad': requires_grad, + } + + if self.refiner_unet is not None: + self.device_state['refiner_unet'] = { + 'training': self.refiner_unet.training, + 'device': self.refiner_unet.device, + 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, + } + + def restore_device_state(self): + # restores the device state for all modules + # this is useful for when we want to alter the state and restore it + if self.device_state is None: + return + self.set_device_state(self.device_state) + self.device_state = None + + def set_device_state(self, state): + if state['vae']['training']: + self.vae.train() + else: + self.vae.eval() + self.vae.to(state['vae']['device']) + if state['unet']['training']: + self.unet.train() + else: + self.unet.eval() + self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_( + state['text_encoder'][i]['requires_grad']) + else: + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_( + state['text_encoder']['requires_grad']) + else: + if state['text_encoder']['training']: + self.text_encoder.train() + else: + self.text_encoder.eval() + self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_( + state['text_encoder']['requires_grad']) + + if self.adapter is not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() + + if self.refiner_unet is not None: + self.refiner_unet.to(state['refiner_unet']['device']) + self.refiner_unet.requires_grad_( + state['refiner_unet']['requires_grad']) + if state['refiner_unet']['training']: + self.refiner_unet.train() + else: + self.refiner_unet.eval() + flush() + + def set_device_state_preset(self, device_state_preset: DeviceStatePreset): + # sets a preset for device state + + # save current state first + self.save_device_state() + + active_modules = [] + training_modules = [] + if device_state_preset in ['cache_latents']: + active_modules = ['vae'] + if device_state_preset in ['cache_clip']: + active_modules = ['clip'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', + 'text_encoder', 'adapter', 'refiner_unet'] + + state = copy.deepcopy(empty_preset) + # vae + state['vae'] = { + 'training': 'vae' in training_modules, + 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, + } + + # unet + state['unet'] = { + 'training': 'unet' in training_modules, + 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, + } + + if self.refiner_unet is not None: + state['refiner_unet'] = { + 'training': 'refiner_unet' in training_modules, + 'device': self.device_torch if 'refiner_unet' in 
active_modules else 'cpu', + 'requires_grad': 'refiner_unet' in training_modules, + } + + # text encoder + if isinstance(self.text_encoder, list): + state['text_encoder'] = [] + for i, encoder in enumerate(self.text_encoder): + state['text_encoder'].append({ + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + }) + else: + state['text_encoder'] = { + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + } + + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + + self.set_device_state(state) + + def text_encoder_to(self, *args, **kwargs): + if isinstance(self.text_encoder, list): + for encoder in self.text_encoder: + encoder.to(*args, **kwargs) + else: + self.text_encoder.to(*args, **kwargs) diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py new file mode 100644 index 00000000..593fa977 --- /dev/null +++ b/toolkit/models/cogview4.py @@ -0,0 +1,466 @@ +# DONT USE THIS!. IT DOES NOT WORK YET! +# Will revisit this when they release more info on how it was trained. + +import weakref +from diffusers import CogView4Pipeline +import torch +import yaml + +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +import torch +import diffusers +from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize +from transformers import GlmModel, AutoTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler +from typing import TYPE_CHECKING +from toolkit.accelerator import unwrap_model +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# remove this after a bug is fixed in diffusers code. This is a workaround. 
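+ # (assumption) something in the CogView4 prompt-encoding path appears to access
+ # `text_encoder.model.device`; FakeModel below proxies `device` back to the GlmModel
+ # through a weakref so that attribute lookup keeps working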
+ + +class FakeModel: + def __init__(self, model): + self.model_ref = weakref.ref(model) + pass + + @property + def device(self): + return self.model_ref().device + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.25, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 0.75, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "time_shift_type": "linear", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + + +class CogView4(BaseModel): + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['CogView4Transformer2DModel'] + + # cache for holding noise + self.effective_noise = None + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def load_model(self): + dtype = self.torch_dtype + base_model_path = "THUDM/CogView4-6B" + model_path = self.model_config.name_or_path + + self.print_and_status_update("Loading CogView4 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading GlmModel") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = GlmModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing GlmModel") + quantize(text_encoder, weights=qfloat8) + freeze(text_encoder) + flush() + + # hack to fix diffusers bug workaround + text_encoder.model = FakeModel(text_encoder) + + self.print_and_status_update("Loading transformer") + transformer = CogView4Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for CogViewModels models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for CogViewModels models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for CogViewModels models currently") + + flush() + + if self.model_config.quantize: + quantization_args = self.model_config.quantize_kwargs + if 'exclude' not in quantization_args: + quantization_args['exclude'] = [] + if 'include' not in quantization_args: + quantization_args['include'] = [] + + # Be more specific with the include pattern to exactly 
match transformer blocks + quantization_args['include'] += ["transformer_blocks.*"] + + # Exclude all LayerNorm layers within transformer blocks + quantization_args['exclude'] += [ + "transformer_blocks.*.norm1", + "transformer_blocks.*.norm2", + "transformer_blocks.*.norm2_context", + "transformer_blocks.*.attn1.norm_q", + "transformer_blocks.*.attn1.norm_k" + ] + + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, **quantization_args) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = CogView4.get_train_scheduler() + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + self.print_and_status_update("Making pipe") + pipe: CogView4Pipeline = CogView4Pipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + self.pipeline = pipe + self.model = transformer + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + def get_generation_pipeline(self): + scheduler = CogView4.get_train_scheduler() + pipeline = CogView4Pipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + return pipeline + + def generate_single_image( + self, + pipeline: CogView4Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # target_size = (height, width) + target_size = latent_model_input.shape[-2:] + # multiply by 8 + target_size = (target_size[0] * 8, target_size[1] * 8) + crops_coords_top_left = torch.tensor( + [(0, 0)], dtype=self.torch_dtype, device=self.device_torch) + + original_size = torch.tensor( + [target_size], dtype=self.torch_dtype, device=self.device_torch) + target_size = original_size.clone() + noise_pred_cond = self.model( + hidden_states=latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds, + timestep=timestep, + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, + return_dict=False, + )[0] + return noise_pred_cond + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + 
prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + dtype=self.torch_dtype, + ) + return PromptEmbeds(prompt_embeds) + + def get_model_has_grad(self): + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: CogView4Transformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + effective_noise = self.effective_noise + batch = kwargs.get('batch') + if batch is None: + raise ValueError("Batch is not provided") + if noise is None: + raise ValueError("Noise is not provided") + # return batch.latents + # return (batch.latents - noise).detach() + return (noise - batch.latents).detach() + # return (batch.latents).detach() + # return (effective_noise - batch.latents).detach() + + def _get_low_res_latents(self, latents): + # todo prevent needing to do this and grab the tensor another way. + with torch.no_grad(): + # Decode latents to image space + images = self.decode_latents( + latents, device=latents.device, dtype=latents.dtype) + + # Downsample by a factor of 2 using bilinear interpolation + B, C, H, W = images.shape + low_res_images = torch.nn.functional.interpolate( + images, + size=(H // 2, W // 2), + mode="bilinear", + align_corners=False + ) + + # Upsample back to original resolution to match expected VAE input dimensions + upsampled_low_res_images = torch.nn.functional.interpolate( + low_res_images, + size=(H, W), + mode="bilinear", + align_corners=False + ) + + # Encode the low-resolution images back to latent space + low_res_latents = self.encode_images( + upsampled_low_res_images, device=latents.device, dtype=latents.dtype) + return low_res_latents + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents_chunks = None + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # # Flowmatching interpolation between original and noise + # if t > relay_start_point: + # # Standard flowmatching - direct linear interpolation + # noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx] + # 
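# A standalone sketch (illustrative, not part of cogview4.py) of the flow-matching
# convention behind get_loss_target() above: with the forward ODE
# x_t = (1 - t) * x0 + t * noise, the velocity the model should predict is
# d x_t / d t = noise - x0, which is exactly the (noise - batch.latents) target.
import torch

x0 = torch.randn(1, 4, 8, 8)        # clean latents
noise = torch.randn_like(x0)        # gaussian noise
t = 0.3                             # timestep rescaled to [0, 1]

x_t = (1.0 - t) * x0 + t * noise    # what add_noise() produces
target = noise - x0                 # velocity target used for the loss

# integrating the true velocity for the remaining (1 - t) lands exactly on the noise
assert torch.allclose(x_t + (1.0 - t) * target, noise, atol=1e-5)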
effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise + # else: + # # Relay flowmatching case - only compute low_res_latents if needed + # if low_res_latents_chunks is None: + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Calculate the relay ratio (0 to 1) + # t_ratio = t.float() / relay_start_point + # t_ratio = torch.clamp(t_ratio, 0.0, 1.0) + + # # First blend between original and low-res based on t_ratio + # z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx] + + # added_lor_res_noise = z0_t - original_samples_chunks[idx] + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx] + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx] + # # lrln = lrln * (1 - t_01) + + # # make the noise an interpolation between noise and low_res_latents with + # # being noise at t_01=1 and low_res_latents at t_01=0 + # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln + # # new_noise = noise_chunks[idx] + lrln + # # new_noise = noise_chunks[idx] + lrln + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(new_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py new file mode 100644 index 00000000..045e9b1c --- /dev/null +++ 
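The commented-out add_noise variants above describe a relay flow-matching experiment: above a relay point, latents and noise are interpolated as usual; below it, the clean latents are first drifted toward a low-resolution version of themselves, and that drift is folded into an effective noise for the loss target. The following is a compact, hedged restatement of that idea; the low-res stand-in here is a simple pooled and upsampled latent rather than the VAE round trip used in _get_low_res_latents, and none of it is active code in this diff.

import torch
import torch.nn.functional as F

def relay_add_noise(x0, noise, low_res, t, relay_start=500.0):
    # t is on the 0-1000 scale used by the scheduler
    t01 = t / 1000.0
    if t > relay_start:
        # plain flow matching above the relay point
        return (1 - t01) * x0 + t01 * noise, noise
    # below the relay point, first drift x0 toward its low-res counterpart ...
    ratio = min(max(t / relay_start, 0.0), 1.0)
    z0_t = (1 - ratio) * x0 + ratio * low_res
    # ... then interpolate the blended state with noise as usual
    x_t = (1 - t01) * z0_t + t01 * noise
    effective_noise = noise + (z0_t - x0)   # what the loss should treat as "noise"
    return x_t, effective_noise

x0 = torch.randn(1, 4, 16, 16)
low_res = F.interpolate(F.avg_pool2d(x0, 2), scale_factor=2, mode="bilinear")
x_t, eff = relay_add_noise(x0, torch.randn_like(x0), low_res, t=250.0)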
b/toolkit/models/wan21.py @@ -0,0 +1,82 @@ +# WIP, coming soon ish +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds +from toolkit.paths import REPOS_ROOT +import sys +import os + +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + + +class Wan21(BaseModel): + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + raise NotImplementedError("Wan21 is not implemented yet") + # these must be implemented in child classes + + def load_model(self): + pass + + def get_generation_pipeline(self): + # override this in child classes + raise NotImplementedError( + "get_generation_pipeline must be implemented in child classes") + + def generate_single_image( + self, + pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # override this in child classes + raise NotImplementedError( + "generate_single_image must be implemented in child classes") + + def get_noise_prediction( + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + raise NotImplementedError( + "get_noise_prediction must be implemented in child classes") + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + raise NotImplementedError( + "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 2a7a1cfd..1e0ae2ab 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) # flatten second half to max - hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() + hbsmntw_weighing[num_timesteps // + 2:] = hbsmntw_weighing[num_timesteps // 2:].max() # Create linear timesteps from 1000 to 0 timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') @@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(self.timesteps == t).nonzero().item() + for t in timesteps] # Get the weights for the timesteps if v2: @@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): sigmas = self.sigmas.to(device=device, dtype=dtype) schedule_timesteps = self.timesteps.to(device) timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps 
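# The re-wrapped lines above only change formatting; the underlying pattern maps each
# sampled timestep back to its index in the scheduler's timestep table so the matching
# sigma (or loss weight) can be gathered and broadcast over the latent dimensions.
# A standalone sketch with toy tables:
import torch

timesteps_table = torch.linspace(1000, 0, 1000)   # plays the role of scheduler.timesteps
sigmas_table = torch.linspace(1.0, 0.0, 1000)     # plays the role of scheduler.sigmas

sampled = timesteps_table[torch.tensor([10, 500, 999])]   # timesteps drawn for a batch of 3
step_indices = [(timesteps_table == t).nonzero().item() for t in sampled]
sigma = sigmas_table[step_indices].flatten()

# pad to the latent rank so it broadcasts over (B, C, H, W)
while sigma.dim() < 4:
    sigma = sigma.unsqueeze(-1)
print(step_indices, tuple(sigma.shape))   # [10, 500, 999] (3, 1, 1, 1)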
== t).nonzero().item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero().item() + for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: @@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise t_01 = (timesteps / 1000).to(original_samples.device) - noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + # forward ODE + noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise + # reverse ODE + # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples return noisy_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: return sample - def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None): + def set_train_timesteps( + self, + num_timesteps, + device, + timestep_type='linear', + latents=None, + patch_size=1 + ): self.timestep_type = timestep_type if timestep_type == 'linear': timesteps = torch.linspace(1000, 0, num_timesteps, device=device) @@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): self.timesteps = timesteps.to(device=device) return timesteps - elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift': + elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']: # matches inference dynamic shifting timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps + self._sigma_to_t(self.sigma_max), self._sigma_to_t( + self.sigma_min), num_timesteps ) sigmas = timesteps / self.config.num_train_timesteps - - if latents is None: - raise ValueError('latents is None') - - h = latents.shape[2] // 2 # Divide by ph - w = latents.shape[3] // 2 # Divide by pw - image_seq_len = h * w - # todo need to know the mu for the shift - mu = calculate_shift( - image_seq_len, - self.config.get("base_image_seq_len", 256), - self.config.get("max_image_seq_len", 4096), - self.config.get("base_shift", 0.5), - self.config.get("max_shift", 1.16), - ) - sigmas = self.time_shift(mu, 1.0, sigmas) + if self.config.use_dynamic_shifting: + if latents is None: + raise ValueError('latents is None') - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + # for flux we double up the patch size before sending her to simulate the latent reduction + h = latents.shape[2] + w = latents.shape[3] + image_seq_len = h * w // (patch_size**2) + + mu = calculate_shift( + image_seq_len, + self.config.get("base_image_seq_len", 256), + self.config.get("max_image_seq_len", 4096), + self.config.get("base_shift", 0.5), + self.config.get("max_shift", 1.16), + ) + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + 
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + + sigmas = torch.from_numpy(sigmas).to( + dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat( + [sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat( + [sigmas, torch.zeros(1, device=sigmas.device)]) self.timesteps = timesteps.to(device=device) self.sigmas = sigmas - + self.timesteps = timesteps.to(device=device) return timesteps - + elif timestep_type == 'lognorm_blend': # disgtribute timestepd to the center/early and blend in linear alpha = 0.75 @@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): t1 = ((1 - t1/t1.max()) * 1000) # add half of linear - t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device) + t2 = torch.linspace(1000, 0, int( + num_timesteps * (1 - alpha)), device=device) timesteps = torch.cat((t1, t2)) # Sort the timesteps in descending order diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7a2dcdff..65736178 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae from toolkit import train_tools -from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch from toolkit.metadata import get_meta_for_safetensors from toolkit.models.decorator import Decorator from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT @@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT from huggingface_hub import hf_hub_download from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance -from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc @@ -160,7 +161,6 @@ class StableDiffusion: self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] - self.unet_unwrapped: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler @@ -177,16 +177,17 @@ class 
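Stepping back to the set_train_timesteps change above: with dynamic shifting enabled, the training sigmas are now shifted by a mu derived from the token count the transformer actually sees, h * w // patch_size**2, which is why the patch size is threaded through from the trainer. Below is a self-contained sketch of that math; the constants mirror the defaults read from the scheduler config (base_shift 0.5, max_shift 1.16), and the exponential shift shown is the Flux-style variant, so treat both as illustrative rather than the exact code path for every arch.

import math
import numpy as np

def shift_mu(image_seq_len, base_seq_len=256, max_seq_len=4096,
             base_shift=0.5, max_shift=1.16):
    # linear ramp of the shift parameter with image sequence length,
    # mirroring the calculate_shift helper used above
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return image_seq_len * m + (base_shift - m * base_seq_len)

# 1024x1024 image, 8x VAE -> 128x128 latents; patch_size 2 -> 64x64 = 4096 tokens
h, w, patch_size = 128, 128, 2
image_seq_len = (h * w) // (patch_size ** 2)

mu = shift_mu(image_seq_len)
sigmas = np.linspace(1.0, 1.0 / 1000, 1000)                      # un-shifted sigmas
shifted = math.exp(mu) / (math.exp(mu) + (1.0 / sigmas - 1.0))   # exponential time shift
print(image_seq_len, round(mu, 3), shifted[:3])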
StableDiffusion: self.network = None self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None self.decorator: Union[Decorator, None] = None - self.is_xl = model_config.is_xl - self.is_v2 = model_config.is_v2 - self.is_ssd = model_config.is_ssd - self.is_v3 = model_config.is_v3 - self.is_vega = model_config.is_vega - self.is_pixart = model_config.is_pixart - self.is_auraflow = model_config.is_auraflow - self.is_flux = model_config.is_flux - self.is_flex2 = model_config.is_flex2 - self.is_lumina2 = model_config.is_lumina2 + self.arch: ModelArch = model_config.arch + # self.is_xl = model_config.is_xl + # self.is_v2 = model_config.is_v2 + # self.is_ssd = model_config.is_ssd + # self.is_v3 = model_config.is_v3 + # self.is_vega = model_config.is_vega + # self.is_pixart = model_config.is_pixart + # self.is_auraflow = model_config.is_auraflow + # self.is_flux = model_config.is_flux + # self.is_flex2 = model_config.is_flex2 + # self.is_lumina2 = model_config.is_lumina2 self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_2 = model_config.use_text_encoder_2 @@ -204,6 +205,53 @@ class StableDiffusion: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + # todo update this based on the model + self.is_transformer = False + + # properties for old arch for backwards compatibility + @property + def is_xl(self): + return self.arch == 'sdxl' + + @property + def is_v2(self): + return self.arch == 'sd2' + + @property + def is_ssd(self): + return self.arch == 'ssd' + + @property + def is_v3(self): + return self.arch == 'sd3' + + @property + def is_vega(self): + return self.arch == 'vega' + + @property + def is_pixart(self): + return self.arch == 'pixart' + + @property + def is_auraflow(self): + return self.arch == 'auraflow' + + @property + def is_flux(self): + return self.arch == 'flux' + + @property + def is_flex2(self): + return self.arch == 'flex2' + + @property + def is_lumina2(self): + return self.arch == 'lumina2' + + @property + def unet_unwrapped(self): + return unwrap_model(self.unet) def load_model(self): if self.is_loaded: @@ -935,7 +983,6 @@ class StableDiffusion: if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: # pixart and sd3 dont use a unet self.unet = pipe.transformer - self.unet_unwrapped = pipe.transformer else: self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) @@ -1734,7 +1781,8 @@ class StableDiffusion: self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor + timesteps: torch.IntTensor, + **kwargs, ) -> torch.FloatTensor: original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py new file mode 100644 index 00000000..4d1668f8 --- /dev/null +++ b/toolkit/util/get_model.py @@ -0,0 +1,12 @@ +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.config_modules import ModelConfig + +def get_model_class(config: ModelConfig): + if config.arch == "wan21": + from toolkit.models.wan21 import Wan21 + return Wan21 + elif config.arch == "cogview4": + from toolkit.models.cogview4 import CogView4 + return CogView4 + else: + return StableDiffusion \ No newline at end of file diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py new file mode 
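To show how the pieces above fit together, here is a hedged sketch of consuming get_model_class: the factory returns an architecture-specific wrapper (StableDiffusion by default, CogView4 or Wan21 for the new arch strings), and a class that defines get_train_scheduler can pin its own training scheduler instead of going through the generic get_sampler path. The trainer wiring here is simplified, and the constructor arguments are assumptions based on the signatures in this diff.

from toolkit.config_modules import ModelConfig
from toolkit.util.get_model import get_model_class

def build_model(model_config: ModelConfig, device: str = 'cuda'):
    ModelClass = get_model_class(model_config)      # StableDiffusion / CogView4 / Wan21
    scheduler = None
    if hasattr(ModelClass, 'get_train_scheduler'):
        # arch-specific classes (e.g. CogView4) ship their own flow-match scheduler
        scheduler = ModelClass.get_train_scheduler()
    sd = ModelClass(
        device=device,
        model_config=model_config,
        dtype='bf16',
        noise_scheduler=scheduler,
    )
    return sd, scheduler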
100644 index 00000000..fd7b3178 --- /dev/null +++ b/toolkit/util/quantize.py @@ -0,0 +1,55 @@ +from fnmatch import fnmatch +from typing import Any, Dict, List, Optional, Union +import torch + +from optimum.quanto.quantize import _quantize_submodule +from optimum.quanto.tensor import Optimizer, qtype + +# the quantize function in quanto had a bug where it was using exclude instead of include + + +def quantize( + model: torch.nn.Module, + weights: Optional[Union[str, qtype]] = None, + activations: Optional[Union[str, qtype]] = None, + optimizer: Optional[Optimizer] = None, + include: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, +): + """Quantize the specified model submodules + + Recursively quantize the submodules of the specified parent model. + + Only modules that have quantized counterparts will be quantized. + + If include patterns are specified, the submodule name must match one of them. + + If exclude patterns are specified, the submodule must not match one of them. + + Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See + https://docs.python.org/3/library/fnmatch.html for more details. + + Note: quantization happens in-place and modifies the original model and its descendants. + + Args: + model (`torch.nn.Module`): the model whose submodules will be quantized. + weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization. + activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization. + include (`Optional[Union[str, List[str]]]`): + Patterns constituting the allowlist. If provided, module names must match at + least one pattern from the allowlist. + exclude (`Optional[Union[str, List[str]]]`): + Patterns constituting the denylist. If provided, module names must not match + any patterns from the denylist. + """ + if include is not None: + include = [include] if isinstance(include, str) else include + if exclude is not None: + exclude = [exclude] if isinstance(exclude, str) else exclude + for name, m in model.named_modules(): + if include is not None and not any(fnmatch(name, pattern) for pattern in include): + continue + if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude): + continue + _quantize_submodule(model, name, m, weights=weights, + activations=activations, optimizer=optimizer)
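Finally, a short usage sketch of the patched quantize() above, following the include/exclude pattern style used in the CogView4 loader. The toy module tree is made up for illustration; as the docstring says, modules without quantized counterparts are simply skipped.

import torch
from optimum.quanto import freeze, qfloat8
from toolkit.util.quantize import quantize  # the patched version above

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = torch.nn.Linear(64, 64)
        self.norm1 = torch.nn.LayerNorm(64)

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([Block() for _ in range(2)])
        self.proj_out = torch.nn.Linear(64, 64)

model = ToyTransformer()
# quantize only the transformer blocks, but keep their norms in full precision
quantize(
    model,
    weights=qfloat8,
    include=["transformer_blocks.*"],
    exclude=["transformer_blocks.*.norm1"],
)
freeze(model)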