From acc79956aa00100bdf5037183768cdaafa59b5aa Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 1 Mar 2025 13:49:02 -0700 Subject: [PATCH 1/8] WIP create new class to add new models more easily --- jobs/process/BaseSDTrainProcess.py | 5 +- requirements.txt | 3 +- toolkit/config_modules.py | 33 + toolkit/models/base_model.py | 1467 ++++++++++++++++++++++++++++ toolkit/models/wan21.py | 56 ++ toolkit/stable_diffusion_model.py | 64 +- toolkit/util/get_model.py | 9 + 7 files changed, 1624 insertions(+), 13 deletions(-) create mode 100644 toolkit/models/base_model.py create mode 100644 toolkit/models/wan21.py create mode 100644 toolkit/util/get_model.py diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 2482c26d..b1493dbd 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -68,6 +68,8 @@ import transformers import diffusers import hashlib +from toolkit.util.get_model import get_model_class + def flush(): torch.cuda.empty_cache() gc.collect() @@ -1423,7 +1425,8 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.refiner_name_or_path = previous_refiner_save self.load_training_state_from_metadata(previous_refiner_save) - self.sd = StableDiffusion( + ModelClass = get_model_class(self.model_config) + self.sd = ModelClass( device=self.device, model_config=model_config_to_load, dtype=self.train_config.dtype, diff --git a/requirements.txt b/requirements.txt index 4040e760..f521b379 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ torch==2.5.1 torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec +# https://github.com/huggingface/diffusers/pull/10921 +git+https://github.com/huggingface/diffusers@refs/pull/10921/head transformers lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 3fc7728c..f251743c 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -423,6 +423,9 @@ class TrainConfig: self.force_consistent_noise = kwargs.get('force_consistent_noise', False) +ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] + + class ModelConfig: def __init__(self, **kwargs): self.name_or_path: str = kwargs.get('name_or_path', None) @@ -500,6 +503,36 @@ class ModelConfig: self.split_model_other_module_param_count_scale = kwargs.get("split_model_other_module_param_count_scale", 0.3) self.te_name_or_path = kwargs.get("te_name_or_path", None) + + self.arch: ModelArch = kwargs.get("model_arch", None) + + # handle migrating to new model arch + if self.arch is None: + if kwargs.get('is_v2', False): + self.arch = 'sd2' + elif kwargs.get('is_v3', False): + self.arch = 'sd3' + elif kwargs.get('is_xl', False): + self.arch = 'sdxl' + elif kwargs.get('is_pixart', False): + self.arch = 'pixart' + elif kwargs.get('is_pixart_sigma', False): + self.arch = 'pixart_sigma' + elif kwargs.get('is_auraflow', False): + self.arch = 'auraflow' + elif kwargs.get('is_flux', False): + self.arch = 'flux' + elif kwargs.get('is_flex2', False): + self.arch = 'flex2' + elif kwargs.get('is_lumina2', False): + self.arch = 'lumina2' + elif kwargs.get('is_vega', False): + self.arch = 'vega' + elif kwargs.get('is_ssd', False): + self.arch = 'ssd' + else: + self.arch = 'sd1' + class EMAConfig: diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py new file mode 100644 index 00000000..adc4e882 --- /dev/null +++ b/toolkit/models/base_model.py @@ -0,0 +1,1467 @@ +import copy +import gc +import json +import random +import shutil +import typing +from typing import Union, List, Literal +import os +from collections import OrderedDict +import copy +import yaml +from PIL import Image +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg +from torch.nn import Parameter +from tqdm import tqdm +from torchvision.transforms import Resize, transforms + +from toolkit.clip_vision_adapter import ClipVisionAdapter +from toolkit.custom_adapter import CustomAdapter +from toolkit.ip_adapter import IPAdapter +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +from toolkit.models.decorator import Decorator +from toolkit.paths import KEYMAPS_ROOT +from toolkit.prompt_utils import inject_trigger_into_prompt, PromptEmbeds, concat_prompt_embeds +from toolkit.reference_adapter import ReferenceAdapter +from toolkit.saving import save_ldm_model_from_diffusers +from toolkit.sd_device_states_presets import empty_preset +from toolkit.train_tools import get_torch_dtype, apply_noise_offset +import torch +from toolkit.pipelines import CustomStableDiffusionXLPipeline +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \ + LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \ + FluxTransformer2DModel +from toolkit.models.lumina2 import Lumina2Transformer2DModel +import diffusers +from diffusers import \ + AutoencoderKL, \ + UNet2DConditionModel +from diffusers import PixArtAlphaPipeline +from transformers import T5EncoderModel, UMT5EncoderModel +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection + +from toolkit.accelerator import get_accelerator, unwrap_model +from typing import TYPE_CHECKING +from toolkit.print import print_acc +from transformers import Gemma2Model, Qwen2Model, LlamaModel + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# tell it to shut up +diffusers.logging.set_verbosity(diffusers.logging.ERROR) + +SD_PREFIX_VAE = "vae" +SD_PREFIX_UNET = "unet" +SD_PREFIX_REFINER_UNET = "refiner_unet" +SD_PREFIX_TEXT_ENCODER = "te" + +SD_PREFIX_TEXT_ENCODER1 = "te0" +SD_PREFIX_TEXT_ENCODER2 = "te1" + +# prefixed diffusers keys +DO_NOT_TRAIN_WEIGHTS = [ + "unet_time_embedding.linear_1.bias", + "unet_time_embedding.linear_1.weight", + "unet_time_embedding.linear_2.bias", + "unet_time_embedding.linear_2.weight", + "refiner_unet_time_embedding.linear_1.bias", + "refiner_unet_time_embedding.linear_1.weight", + "refiner_unet_time_embedding.linear_2.bias", + "refiner_unet_time_embedding.linear_2.weight", +] + +DeviceStatePreset = Literal['cache_latents', 'generate'] + + +class BlankNetwork: + + def __init__(self): + self.multiplier = 1.0 + self.is_active = True + self.is_merged_in = False + self.can_merge_in = False + + def __enter__(self): + self.is_active = True + + def __exit__(self, exc_type, exc_val, exc_tb): + self.is_active = False + + def train(self): + pass + + +def flush(): + torch.cuda.empty_cache() + gc.collect() + + +UNET_IN_CHANNELS = 4 # Stable Diffusion の in_channels は 4 で固定。XLも同じ。 +# VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 + + +class BaseModel: + + def __init__( + self, + device, + model_config: ModelConfig, + dtype='fp16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + self.accelerator = get_accelerator() + self.custom_pipeline = custom_pipeline + self.device = str(self.accelerator.device) + self.dtype = dtype + self.torch_dtype = get_torch_dtype(dtype) + self.device_torch = self.accelerator.device + + self.vae_device_torch = self.accelerator.device + self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype) + + self.te_device_torch = self.accelerator.device + self.te_torch_dtype = get_torch_dtype(model_config.te_dtype) + + self.model_config = model_config + self.prediction_type = "v_prediction" if self.model_config.is_v_pred else "epsilon" + + self.device_state = None + + self.pipeline: Union[None, 'StableDiffusionPipeline', + 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] + self.vae: Union[None, 'AutoencoderKL'] + self.model: Union[None, 'Transformer2DModel', 'UNet2DConditionModel'] + self.text_encoder: Union[None, 'CLIPTextModel', + List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] + self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] + self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler + + self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None + self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None + + # sdxl stuff + self.logit_scale = None + self.ckppt_info = None + self.is_loaded = False + + # to hold network if there is one + self.network = None + self.adapter: Union['ControlNetModel', 'T2IAdapter', + 'IPAdapter', 'ReferenceAdapter', None] = None + self.decorator: Union[Decorator, None] = None + self.arch: ModelArch = model_config.arch + + self.use_text_encoder_1 = model_config.use_text_encoder_1 + self.use_text_encoder_2 = model_config.use_text_encoder_2 + + self.config_file = None + + self.is_flow_matching = False + + self.quantize_device = self.device_torch + self.low_vram = self.model_config.low_vram + + # merge in and preview active with -1 weight + self.invert_assistant_lora = False + self._after_sample_img_hooks = [] + self._status_update_hooks = [] + + # properties for old arch for backwards compatibility + @property + def unet(self): + return self.model + + @property + def unet_unwrapped(self): + return unwrap_model(self.model) + + @property + def model_unwrapped(self): + return unwrap_model(self.model) + + @property + def is_xl(self): + return self.arch == 'sdxl' + + @property + def is_v2(self): + return self.arch == 'sd2' + + @property + def is_ssd(self): + return self.arch == 'ssd' + + @property + def is_v3(self): + return self.arch == 'sd3' + + @property + def is_vega(self): + return self.arch == 'vega' + + @property + def is_pixart(self): + return self.arch == 'pixart' + + @property + def is_auraflow(self): + return self.arch == 'auraflow' + + @property + def is_flux(self): + return self.arch == 'flux' + + @property + def is_flex2(self): + return self.arch == 'flex2' + + @property + def is_lumina2(self): + return self.arch == 'lumina2' + + # these must be implemented in child classes + def load_model(self): + # override this in child classes + raise NotImplementedError( + "load_model must be implemented in child classes") + + def get_generation_pipeline(self): + # override this in child classes + raise NotImplementedError( + "get_generation_pipeline must be implemented in child classes") + + def generate_single_image( + self, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # override this in child classes + raise NotImplementedError( + "generate_single_image must be implemented in child classes") + + def get_noise_prediction( + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + raise NotImplementedError( + "get_noise_prediction must be implemented in child classes") + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + raise NotImplementedError( + "get_prompt_embeds must be implemented in child classes") + # end must be implemented in child classes + + def te_train(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.train() + elif self.text_encoder is not None: + self.text_encoder.train() + + def te_eval(self): + if isinstance(self.text_encoder, list): + for te in self.text_encoder: + te.eval() + elif self.text_encoder is not None: + self.text_encoder.eval() + + def _after_sample_image(self, img_num, total_imgs): + # process all hooks + for hook in self._after_sample_img_hooks: + hook(img_num, total_imgs) + + def add_after_sample_image_hook(self, func): + self._after_sample_img_hooks.append(func) + + def _status_update(self, status: str): + for hook in self._status_update_hooks: + hook(status) + + def print_and_status_update(self, status: str): + print_acc(status) + self._status_update(status) + + def add_status_update_hook(self, func): + self._status_update_hooks.append(func) + + @torch.no_grad() + def generate_images( + self, + image_configs: List[GenerateImageConfig], + sampler=None, + pipeline: Union[None, StableDiffusionPipeline, + StableDiffusionXLPipeline] = None, + ): + network = unwrap_model(self.network) + merge_multiplier = 1.0 + flush() + # if using assistant, unfuse it + if self.model_config.assistant_lora_path is not None: + print_acc("Unloading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to( + self.device_torch, self.torch_dtype) + else: + self.assistant_lora.is_active = False + + if self.model_config.inference_lora_path is not None: + print_acc("Loading inference lora") + self.assistant_lora.is_active = True + # move weights on to the device + self.assistant_lora.force_to(self.device_torch, self.torch_dtype) + + if network is not None: + network.eval() + # check if we have the same network weight for all samples. If we do, we can merge in th + # the network to drastically speed up inference + unique_network_weights = set( + [x.network_multiplier for x in image_configs]) + if len(unique_network_weights) == 1 and network.can_merge_in: + can_merge_in = True + merge_multiplier = unique_network_weights.pop() + network.merge_in(merge_weight=merge_multiplier) + else: + network = BlankNetwork() + + self.save_device_state() + self.set_device_state_preset('generate') + + # save current seed state for training + rng_state = torch.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None + + if pipeline is None: + pipeline = self.get_generation_pipeline() + try: + pipeline.set_progress_bar_config(disable=True) + except: + pass + + start_multiplier = 1.0 + if network is not None: + start_multiplier = network.multiplier + + # pipeline.to(self.device_torch) + + with network: + with torch.no_grad(): + if network is not None: + assert network.is_active + + for i in tqdm(range(len(image_configs)), desc=f"Generating Images", leave=False): + gen_config = image_configs[i] + + extra = {} + validation_image = None + if self.adapter is not None and gen_config.adapter_image_path is not None: + validation_image = Image.open( + gen_config.adapter_image_path).convert("RGB") + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? + validation_image = validation_image.resize( + (gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, ControlNetModel): + validation_image = validation_image.resize( + (gen_config.width, gen_config.height)) + extra['image'] = validation_image + extra['controlnet_conditioning_scale'] = gen_config.adapter_conditioning_scale + if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter): + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + if isinstance(self.adapter, CustomAdapter): + # todo allow loading multiple + transform = transforms.Compose([ + transforms.ToTensor(), + ]) + validation_image = transform(validation_image) + self.adapter.num_images = 1 + if isinstance(self.adapter, ReferenceAdapter): + # need -1 to 1 + validation_image = transforms.ToTensor()(validation_image) + validation_image = validation_image * 2.0 - 1.0 + validation_image = validation_image.unsqueeze(0) + self.adapter.set_reference_images(validation_image) + + if network is not None: + network.multiplier = gen_config.network_multiplier + torch.manual_seed(gen_config.seed) + torch.cuda.manual_seed(gen_config.seed) + + generator = torch.manual_seed(gen_config.seed) + + if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \ + and gen_config.adapter_image_path is not None: + # run through the adapter to saturate the embeds + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + self.adapter(conditional_clip_embeds) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + # handle condition the prompts + gen_config.prompt = self.adapter.condition_prompt( + gen_config.prompt, + is_unconditional=False, + ) + gen_config.prompt_2 = gen_config.prompt + gen_config.negative_prompt = self.adapter.condition_prompt( + gen_config.negative_prompt, + is_unconditional=True, + ) + gen_config.negative_prompt_2 = gen_config.negative_prompt + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + self.adapter.trigger_pre_te( + tensors_0_1=validation_image, + is_training=False, + has_been_preprocessed=False, + quad_count=4 + ) + + # encode the prompt ourselves so we can do fun stuff with embeddings + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + conditional_embeds = self.encode_prompt( + gen_config.prompt, gen_config.prompt_2, force_all=True) + + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = True + unconditional_embeds = self.encode_prompt( + gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True + ) + if isinstance(self.adapter, CustomAdapter): + self.adapter.is_unconditional_run = False + + # allow any manipulations to take place to embeddings + gen_config.post_process_embeddings( + conditional_embeds, + unconditional_embeds, + ) + + if self.decorator is not None: + # apply the decorator to the embeddings + conditional_embeds.text_embeds = self.decorator( + conditional_embeds.text_embeds) + unconditional_embeds.text_embeds = self.decorator( + unconditional_embeds.text_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, IPAdapter) \ + and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( + validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, + True) + conditional_embeds = self.adapter( + conditional_embeds, conditional_clip_embeds, is_unconditional=False) + unconditional_embeds = self.adapter( + unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): + conditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=conditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_generating_samples=True, + ) + unconditional_embeds = self.adapter.condition_encoded_embeds( + tensors_0_1=validation_image, + prompt_embeds=unconditional_embeds, + is_training=False, + has_been_preprocessed=False, + is_unconditional=True, + is_generating_samples=True, + ) + + if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len( + gen_config.extra_values) > 0: + extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, + dtype=self.torch_dtype) + # apply extra values to the embeddings + self.adapter.add_extra_values( + extra_values, is_unconditional=False) + self.adapter.add_extra_values(torch.zeros_like( + extra_values), is_unconditional=True) + pass # todo remove, for debugging + + if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0: + # if we have a refiner loaded, set the denoising end at the refiner start + extra['denoising_end'] = gen_config.refiner_start_at + extra['output_type'] = 'latent' + if not self.is_xl: + raise ValueError( + "Refiner is only supported for XL models") + + conditional_embeds = conditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + unconditional_embeds = unconditional_embeds.to( + self.device_torch, dtype=self.unet.dtype) + + img = self.generate_single_image( + gen_config, + conditional_embeds, + unconditional_embeds, + generator, + extra, + ) + + gen_config.save_image(img, i) + gen_config.log_image(img, i) + self._after_sample_image(i, len(image_configs)) + flush() + + if self.adapter is not None and isinstance(self.adapter, ReferenceAdapter): + self.adapter.clear_memory() + + # clear pipeline and cache to reduce vram usage + del pipeline + torch.cuda.empty_cache() + + # restore training state + torch.set_rng_state(rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) + + self.restore_device_state() + if network is not None: + network.train() + network.multiplier = start_multiplier + + self.unet.to(self.device_torch, dtype=self.torch_dtype) + if network.is_merged_in: + network.merge_out(merge_multiplier) + # self.tokenizer.to(original_device_dict['tokenizer']) + + # refuse loras + if self.model_config.assistant_lora_path is not None: + print_acc("Loading assistant lora") + if self.invert_assistant_lora: + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + else: + self.assistant_lora.is_active = True + + if self.model_config.inference_lora_path is not None: + print_acc("Unloading inference lora") + self.assistant_lora.is_active = False + # move weights off the device + self.assistant_lora.force_to('cpu', self.torch_dtype) + flush() + + def get_latent_noise( + self, + height=None, + width=None, + pixel_height=None, + pixel_width=None, + batch_size=1, + noise_offset=0.0, + ): + VAE_SCALE_FACTOR = 2 ** ( + len(self.vae.config['block_out_channels']) - 1) + if height is None and pixel_height is None: + raise ValueError("height or pixel_height must be specified") + if width is None and pixel_width is None: + raise ValueError("width or pixel_width must be specified") + if height is None: + height = pixel_height // VAE_SCALE_FACTOR + if width is None: + width = pixel_width // VAE_SCALE_FACTOR + + num_channels = self.unet_unwrapped.config['in_channels'] + if self.is_flux: + # has 64 channels in for some reason + num_channels = 16 + noise = torch.randn( + ( + batch_size, + num_channels, + height, + width, + ), + device=self.unet.device, + ) + noise = apply_noise_offset(noise, noise_offset) + return noise + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor + ) -> torch.FloatTensor: + original_samples_chunks = torch.chunk( + original_samples, original_samples.shape[0], dim=0) + noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + timesteps_chunks = [timesteps_chunks[0]] * \ + len(original_samples_chunks) + + noisy_latents_chunks = [] + + for idx in range(original_samples.shape[0]): + noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], + timesteps_chunks[idx]) + noisy_latents_chunks.append(noisy_latents) + + noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + return noisy_latents + + def predict_noise( + self, + latents: torch.Tensor, + text_embeddings: Union[PromptEmbeds, None] = None, + timestep: Union[int, torch.Tensor] = 1, + guidance_scale=7.5, + guidance_rescale=0, + add_time_ids=None, + conditional_embeddings: Union[PromptEmbeds, None] = None, + unconditional_embeddings: Union[PromptEmbeds, None] = None, + is_input_scaled=False, + detach_unconditional=False, + rescale_cfg=None, + return_conditional_pred=False, + guidance_embedding_scale=1.0, + bypass_guidance_embedding=False, + **kwargs, + ): + conditional_pred = None + # get the embeddings + if text_embeddings is None and conditional_embeddings is None: + raise ValueError( + "Either text_embeddings or conditional_embeddings must be specified") + if text_embeddings is None and unconditional_embeddings is not None: + text_embeddings = concat_prompt_embeds([ + unconditional_embeddings, # negative embedding + conditional_embeddings, # positive embedding + ]) + elif text_embeddings is None and conditional_embeddings is not None: + # not doing cfg + text_embeddings = conditional_embeddings + + # CFG is comparing neg and positive, if we have concatenated embeddings + # then we are doing it, otherwise we are not and takes half the time. + do_classifier_free_guidance = True + + # check if batch size of embeddings matches batch size of latents + if latents.shape[0] == text_embeddings.text_embeds.shape[0]: + do_classifier_free_guidance = False + elif latents.shape[0] * 2 != text_embeddings.text_embeds.shape[0]: + raise ValueError( + "Batch size of latents must be the same or half the batch size of text embeddings") + latents = latents.to(self.device_torch) + text_embeddings = text_embeddings.to(self.device_torch) + timestep = timestep.to(self.device_torch) + + # if timestep is zero dim, unsqueeze it + if len(timestep.shape) == 0: + timestep = timestep.unsqueeze(0) + + # if we only have 1 timestep, we can just use the same timestep for all + if timestep.shape[0] == 1 and latents.shape[0] > 1: + # check if it is rank 1 or 2 + if len(timestep.shape) == 1: + timestep = timestep.repeat(latents.shape[0]) + else: + timestep = timestep.repeat(latents.shape[0], 0) + + # handle t2i adapters + if 'down_intrablock_additional_residuals' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_intrablock_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_intrablock_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + + # handle controlnet + if 'down_block_additional_residuals' in kwargs and 'mid_block_additional_residual' in kwargs: + # go through each item and concat if doing cfg and it doesnt have the same shape + for idx, item in enumerate(kwargs['down_block_additional_residuals']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['down_block_additional_residuals'][idx] = torch.cat([ + item] * 2, dim=0) + for idx, item in enumerate(kwargs['mid_block_additional_residual']): + if do_classifier_free_guidance and item.shape[0] != text_embeddings.text_embeds.shape[0]: + kwargs['mid_block_additional_residual'][idx] = torch.cat( + [item] * 2, dim=0) + + def scale_model_input(model_input, timestep_tensor): + if is_input_scaled: + return model_input + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + # unsqueeze if timestep is zero dim + for idx in range(model_input.shape[0]): + # if scheduler has step_index + if hasattr(self.noise_scheduler, '_step_index'): + self.noise_scheduler._step_index = None + out_chunks.append( + self.noise_scheduler.scale_model_input( + mi_chunks[idx], timestep_chunks[idx]) + ) + return torch.cat(out_chunks, dim=0) + + with torch.no_grad(): + if do_classifier_free_guidance: + # if we are doing classifier free guidance, need to double up + latent_model_input = torch.cat([latents] * 2, dim=0) + timestep = torch.cat([timestep] * 2) + else: + latent_model_input = latents + + latent_model_input = scale_model_input( + latent_model_input, timestep) + + # check if we need to concat timesteps + if isinstance(timestep, torch.Tensor) and len(timestep.shape) > 1: + ts_bs = timestep.shape[0] + if ts_bs != latent_model_input.shape[0]: + if ts_bs == 1: + timestep = torch.cat( + [timestep] * latent_model_input.shape[0]) + elif ts_bs * 2 == latent_model_input.shape[0]: + timestep = torch.cat([timestep] * 2, dim=0) + else: + raise ValueError( + f"Batch size of latents {latent_model_input.shape[0]} must be the same or half the batch size of timesteps {timestep.shape[0]}") + + # predict the noise residual + if self.unet.device != self.device_torch: + self.unet.to(self.device_torch) + if self.unet.dtype != self.torch_dtype: + self.unet = self.unet.to(dtype=self.torch_dtype) + + noise_pred = self.get_noise_prediction( + latent_model_input=latent_model_input, + timestep=timestep, + text_embeddings=text_embeddings, + **kwargs + ) + + conditional_pred = noise_pred + + if do_classifier_free_guidance: + # perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0) + conditional_pred = noise_pred_text + if detach_unconditional: + noise_pred_uncond = noise_pred_uncond.detach() + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if rescale_cfg is not None and rescale_cfg != guidance_scale: + with torch.no_grad(): + # do cfg at the target rescale so we can match it + target_pred_mean_std = noise_pred_uncond + rescale_cfg * ( + noise_pred_text - noise_pred_uncond + ) + target_mean = target_pred_mean_std.mean( + [1, 2, 3], keepdim=True).detach() + target_std = target_pred_mean_std.std( + [1, 2, 3], keepdim=True).detach() + + pred_mean = noise_pred.mean( + [1, 2, 3], keepdim=True).detach() + pred_std = noise_pred.std([1, 2, 3], keepdim=True).detach() + + # match the mean and std + noise_pred = (noise_pred - pred_mean) / pred_std + noise_pred = (noise_pred * target_std) + target_mean + + # https://github.com/huggingface/diffusers/blob/7a91ea6c2b53f94da930a61ed571364022b21044/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L775 + if guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + if return_conditional_pred: + return noise_pred, conditional_pred + return noise_pred + + def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None): + if noise_scheduler is None: + noise_scheduler = self.noise_scheduler + # // sometimes they are on the wrong device, no idea why + if isinstance(noise_scheduler, DDPMScheduler) or isinstance(noise_scheduler, LCMScheduler): + try: + noise_scheduler.betas = noise_scheduler.betas.to( + self.device_torch) + noise_scheduler.alphas = noise_scheduler.alphas.to( + self.device_torch) + noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.to( + self.device_torch) + except Exception as e: + pass + + mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0) + latent_chunks = torch.chunk(latent_input, latent_input.shape[0], dim=0) + timestep_chunks = torch.chunk( + timestep_tensor, timestep_tensor.shape[0], dim=0) + out_chunks = [] + if len(timestep_chunks) == 1 and len(mi_chunks) > 1: + # expand timestep to match + timestep_chunks = timestep_chunks * len(mi_chunks) + + for idx in range(model_input.shape[0]): + # Reset it so it is unique for the + if hasattr(noise_scheduler, '_step_index'): + noise_scheduler._step_index = None + if hasattr(noise_scheduler, 'is_scale_input_called'): + noise_scheduler.is_scale_input_called = True + out_chunks.append( + noise_scheduler.step(mi_chunks[idx], timestep_chunks[idx], latent_chunks[idx], return_dict=False)[ + 0] + ) + return torch.cat(out_chunks, dim=0) + + # ref: https://github.com/huggingface/diffusers/blob/0bab447670f47c28df60fbd2f6a0f833f75a16f5/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L746 + def diffuse_some_steps( + self, + latents: torch.FloatTensor, + text_embeddings: PromptEmbeds, + total_timesteps: int = 1000, + start_timesteps=0, + guidance_scale=1, + add_time_ids=None, + bleed_ratio: float = 0.5, + bleed_latents: torch.FloatTensor = None, + is_input_scaled=False, + return_first_prediction=False, + **kwargs, + ): + timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps] + + first_prediction = None + + for timestep in tqdm(timesteps_to_run, leave=False): + timestep = timestep.unsqueeze_(0) + noise_pred, conditional_pred = self.predict_noise( + latents, + text_embeddings, + timestep, + guidance_scale=guidance_scale, + add_time_ids=add_time_ids, + is_input_scaled=is_input_scaled, + return_conditional_pred=True, + **kwargs, + ) + # some schedulers need to run separately, so do that. (euler for example) + + if return_first_prediction and first_prediction is None: + first_prediction = conditional_pred + + latents = self.step_scheduler(noise_pred, latents, timestep) + + # if not last step, and bleeding, bleed in some latents + if bleed_latents is not None and timestep != self.noise_scheduler.timesteps[-1]: + latents = (latents * (1 - bleed_ratio)) + \ + (bleed_latents * bleed_ratio) + + # only skip first scaling + is_input_scaled = False + + # return latents_steps + if return_first_prediction: + return latents, first_prediction + return latents + + def encode_prompt( + self, + prompt, + prompt2=None, + num_images_per_prompt=1, + force_all=False, + long_prompts=False, + max_length=None, + dropout_prob=0.0, + ) -> PromptEmbeds: + # sd1.5 embeddings are (bs, 77, 768) + prompt = prompt + # if it is not a list, make it one + if not isinstance(prompt, list): + prompt = [prompt] + + if prompt2 is not None and not isinstance(prompt2, list): + prompt2 = [prompt2] + + return self.get_prompt_embeds(prompt) + + @torch.no_grad() + def encode_images( + self, + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + if device is None: + device = self.vae_device_torch + if dtype is None: + dtype = self.vae_torch_dtype + + latent_list = [] + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(device) + self.vae.eval() + self.vae.requires_grad_(False) + # move to device and dtype + image_list = [image.to(device, dtype=dtype) for image in image_list] + + VAE_SCALE_FACTOR = 2 ** ( + len(self.vae.config['block_out_channels']) - 1) + + # resize images if not divisible by 8 + for i in range(len(image_list)): + image = image_list[i] + if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0: + image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, + image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image) + + images = torch.stack(image_list) + if isinstance(self.vae, AutoencoderTiny): + latents = self.vae.encode(images, return_dict=False)[0] + else: + latents = self.vae.encode(images).latent_dist.sample() + shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0 + + # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303 + # z = self.scale_factor * (z - self.shift_factor) + latents = self.vae.config['scaling_factor'] * (latents - shift) + latents = latents.to(device, dtype=dtype) + + return latents + + def decode_latents( + self, + latents: torch.Tensor, + device=None, + dtype=None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + # Move to vae to device if on cpu + if self.vae.device == 'cpu': + self.vae.to(self.device) + latents = latents.to(device, dtype=dtype) + latents = ( + latents / self.vae.config['scaling_factor']) + self.vae.config['shift_factor'] + images = self.vae.decode(latents).sample + images = images.to(device, dtype=dtype) + + return images + + def encode_image_prompt_pairs( + self, + prompt_list: List[str], + image_list: List[torch.Tensor], + device=None, + dtype=None + ): + # todo check image types and expand and rescale as needed + # device and dtype are for outputs + if device is None: + device = self.device + if dtype is None: + dtype = self.torch_dtype + + embedding_list = [] + latent_list = [] + # embed the prompts + for prompt in prompt_list: + embedding = self.encode_prompt(prompt).to( + self.device_torch, dtype=dtype) + embedding_list.append(embedding) + + return embedding_list, latent_list + + def get_weight_by_name(self, name): + # weights begin with te{te_num}_ for text encoder + # weights begin with unet_ for unet_ + if name.startswith('te'): + key = name[4:] + # text encoder + te_num = int(name[2]) + if isinstance(self.text_encoder, list): + return self.text_encoder[te_num].state_dict()[key] + else: + return self.text_encoder.state_dict()[key] + elif name.startswith('unet'): + key = name[5:] + # unet + return self.unet.state_dict()[key] + + raise ValueError(f"Unknown weight name: {name}") + + def inject_trigger_into_prompt(self, prompt, trigger=None, to_replace_list=None, add_if_not_present=False): + return inject_trigger_into_prompt( + prompt, + trigger=trigger, + to_replace_list=to_replace_list, + add_if_not_present=add_if_not_present, + ) + + def state_dict(self, vae=True, text_encoder=True, unet=True): + state_dict = OrderedDict() + if vae: + for k, v in self.vae.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_VAE}") else f"{SD_PREFIX_VAE}_{k}" + state_dict[new_key] = v + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + for k, v in encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}{i}_") else f"{SD_PREFIX_TEXT_ENCODER}{i}_{k}" + state_dict[new_key] = v + else: + for k, v in self.text_encoder.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_TEXT_ENCODER}_") else f"{SD_PREFIX_TEXT_ENCODER}_{k}" + state_dict[new_key] = v + if unet: + for k, v in self.unet.state_dict().items(): + new_key = k if k.startswith( + f"{SD_PREFIX_UNET}_") else f"{SD_PREFIX_UNET}_{k}" + state_dict[new_key] = v + return state_dict + + def named_parameters(self, vae=True, text_encoder=True, unet=True, refiner=False, state_dict_keys=False) -> \ + OrderedDict[ + str, Parameter]: + named_params: OrderedDict[str, Parameter] = OrderedDict() + if vae: + for name, param in self.vae.named_parameters(recurse=True, prefix=f"{SD_PREFIX_VAE}"): + named_params[name] = param + if text_encoder: + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if self.is_xl and not self.model_config.use_text_encoder_1 and i == 0: + # dont add these params + continue + if self.is_xl and not self.model_config.use_text_encoder_2 and i == 1: + # dont add these params + continue + + for name, param in encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}{i}"): + named_params[name] = param + else: + for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): + named_params[name] = param + if unet: + if self.is_flux or self.is_lumina2: + for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): + named_params[name] = param + else: + for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"): + named_params[name] = param + + if self.model_config.ignore_if_contains is not None: + # remove params that contain the ignore_if_contains from named params + for key in list(named_params.keys()): + if any([s in key for s in self.model_config.ignore_if_contains]): + del named_params[key] + if self.model_config.only_if_contains is not None: + # remove params that do not contain the only_if_contains from named params + for key in list(named_params.keys()): + if not any([s in key for s in self.model_config.only_if_contains]): + del named_params[key] + + if refiner: + for name, param in self.refiner_unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_REFINER_UNET}"): + named_params[name] = param + + # convert to state dict keys, jsut replace . with _ on keys + if state_dict_keys: + new_named_params = OrderedDict() + for k, v in named_params.items(): + # replace only the first . with an _ + new_key = k.replace('.', '_', 1) + new_named_params[new_key] = v + named_params = new_named_params + + return named_params + + def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): + version_string = '1' + if self.is_v2: + version_string = '2' + if self.is_xl: + version_string = 'sdxl' + if self.is_ssd: + # overwrite sdxl because both wil be true here + version_string = 'ssd' + if self.is_ssd and self.is_vega: + version_string = 'vega' + # if output file does not end in .safetensors, then it is a directory and we are + # saving in diffusers format + if not output_file.endswith('.safetensors'): + # diffusers + if self.is_flux: + # only save the unet + transformer: FluxTransformer2DModel = unwrap_model(self.unet) + transformer.save_pretrained( + save_directory=os.path.join(output_file, 'transformer'), + safe_serialization=True, + ) + elif self.is_lumina2: + # only save the unet + transformer: Lumina2Transformer2DModel = unwrap_model( + self.unet) + transformer.save_pretrained( + save_directory=os.path.join(output_file, 'transformer'), + safe_serialization=True, + ) + + else: + + self.pipeline.save_pretrained( + save_directory=output_file, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_file, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + else: + save_ldm_model_from_diffusers( + sd=self, + output_file=output_file, + meta=meta, + save_dtype=save_dtype, + sd_version=version_string, + ) + if self.config_file is not None: + output_path_no_ext = os.path.splitext(output_file)[0] + output_config_path = f"{output_path_no_ext}.yaml" + shutil.copyfile(self.config_file, output_config_path) + + def prepare_optimizer_params( + self, + unet=False, + text_encoder=False, + text_encoder_lr=None, + unet_lr=None, + refiner_lr=None, + refiner=False, + default_lr=1e-6, + ): + # todo maybe only get locon ones? + # not all items are saved, to make it match, we need to match out save mappings + # and not train anything not mapped. Also add learning rate + version = 'sd1' + if self.is_xl: + version = 'sdxl' + if self.is_v2: + version = 'sd2' + mapping_filename = f"stable_diffusion_{version}.json" + mapping_path = os.path.join(KEYMAPS_ROOT, mapping_filename) + with open(mapping_path, 'r') as f: + mapping = json.load(f) + ldm_diffusers_keymap = mapping['ldm_diffusers_keymap'] + + trainable_parameters = [] + + # we use state dict to find params + + if unet: + named_params = self.named_parameters( + vae=False, unet=unet, text_encoder=False, state_dict_keys=True) + unet_lr = unet_lr if unet_lr is not None else default_lr + params = [] + if self.is_pixart or self.is_auraflow or self.is_flux or self.is_v3 or self.is_lumina2: + for param in named_params.values(): + if param.requires_grad: + params.append(param) + else: + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": unet_lr} + trainable_parameters.append(param_data) + print_acc(f"Found {len(params)} trainable parameter in unet") + + if text_encoder: + named_params = self.named_parameters( + vae=False, unet=False, text_encoder=text_encoder, state_dict_keys=True) + text_encoder_lr = text_encoder_lr if text_encoder_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": text_encoder_lr} + trainable_parameters.append(param_data) + + print_acc( + f"Found {len(params)} trainable parameter in text encoder") + + if refiner: + named_params = self.named_parameters(vae=False, unet=False, text_encoder=False, refiner=True, + state_dict_keys=True) + refiner_lr = refiner_lr if refiner_lr is not None else default_lr + params = [] + for key, diffusers_key in ldm_diffusers_keymap.items(): + diffusers_key = f"refiner_{diffusers_key}" + if diffusers_key in named_params and diffusers_key not in DO_NOT_TRAIN_WEIGHTS: + if named_params[diffusers_key].requires_grad: + params.append(named_params[diffusers_key]) + param_data = {"params": params, "lr": refiner_lr} + trainable_parameters.append(param_data) + + print_acc(f"Found {len(params)} trainable parameter in refiner") + + return trainable_parameters + + def save_device_state(self): + # saves the current device state for all modules + # this is useful for when we want to alter the state and restore it + if self.is_lumina2: + unet_has_grad = self.unet.x_embedder.weight.requires_grad + elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: + unet_has_grad = self.unet.proj_out.weight.requires_grad + else: + unet_has_grad = self.unet.conv_in.weight.requires_grad + + self.device_state = { + **empty_preset, + 'vae': { + 'training': self.vae.training, + 'device': self.vae.device, + }, + 'unet': { + 'training': self.unet.training, + 'device': self.unet.device, + 'requires_grad': unet_has_grad, + }, + } + if isinstance(self.text_encoder, list): + self.device_state['text_encoder']: List[dict] = [] + for encoder in self.text_encoder: + if isinstance(encoder, LlamaModel): + te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad + else: + try: + te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad + except: + te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + self.device_state['text_encoder'].append({ + 'training': encoder.training, + 'device': encoder.device, + # todo there has to be a better way to do this + 'requires_grad': te_has_grad + }) + else: + if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): + te_has_grad = self.text_encoder.encoder.block[ + 0].layer[0].SelfAttention.q.weight.requires_grad + elif isinstance(self.text_encoder, Gemma2Model): + te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad + elif isinstance(self.text_encoder, Qwen2Model): + te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad + elif isinstance(self.text_encoder, LlamaModel): + te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad + else: + te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + + self.device_state['text_encoder'] = { + 'training': self.text_encoder.training, + 'device': self.text_encoder.device, + 'requires_grad': te_has_grad + } + if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.image_proj_model.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + elif isinstance(self.adapter, ControlNetModel): + requires_grad = self.adapter.conv_in.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ClipVisionAdapter): + requires_grad = self.adapter.embedder.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, CustomAdapter): + requires_grad = self.adapter.training + adapter_device = self.adapter.device + elif isinstance(self.adapter, ReferenceAdapter): + # todo update this!! + requires_grad = True + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") + self.device_state['adapter'] = { + 'training': self.adapter.training, + 'device': adapter_device, + 'requires_grad': requires_grad, + } + + if self.refiner_unet is not None: + self.device_state['refiner_unet'] = { + 'training': self.refiner_unet.training, + 'device': self.refiner_unet.device, + 'requires_grad': self.refiner_unet.conv_in.weight.requires_grad, + } + + def restore_device_state(self): + # restores the device state for all modules + # this is useful for when we want to alter the state and restore it + if self.device_state is None: + return + self.set_device_state(self.device_state) + self.device_state = None + + def set_device_state(self, state): + if state['vae']['training']: + self.vae.train() + else: + self.vae.eval() + self.vae.to(state['vae']['device']) + if state['unet']['training']: + self.unet.train() + else: + self.unet.eval() + self.unet.to(state['unet']['device']) + if state['unet']['requires_grad']: + self.unet.requires_grad_(True) + else: + self.unet.requires_grad_(False) + if isinstance(self.text_encoder, list): + for i, encoder in enumerate(self.text_encoder): + if isinstance(state['text_encoder'], list): + if state['text_encoder'][i]['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder'][i]['device']) + encoder.requires_grad_( + state['text_encoder'][i]['requires_grad']) + else: + if state['text_encoder']['training']: + encoder.train() + else: + encoder.eval() + encoder.to(state['text_encoder']['device']) + encoder.requires_grad_( + state['text_encoder']['requires_grad']) + else: + if state['text_encoder']['training']: + self.text_encoder.train() + else: + self.text_encoder.eval() + self.text_encoder.to(state['text_encoder']['device']) + self.text_encoder.requires_grad_( + state['text_encoder']['requires_grad']) + + if self.adapter is not None: + self.adapter.to(state['adapter']['device']) + self.adapter.requires_grad_(state['adapter']['requires_grad']) + if state['adapter']['training']: + self.adapter.train() + else: + self.adapter.eval() + + if self.refiner_unet is not None: + self.refiner_unet.to(state['refiner_unet']['device']) + self.refiner_unet.requires_grad_( + state['refiner_unet']['requires_grad']) + if state['refiner_unet']['training']: + self.refiner_unet.train() + else: + self.refiner_unet.eval() + flush() + + def set_device_state_preset(self, device_state_preset: DeviceStatePreset): + # sets a preset for device state + + # save current state first + self.save_device_state() + + active_modules = [] + training_modules = [] + if device_state_preset in ['cache_latents']: + active_modules = ['vae'] + if device_state_preset in ['cache_clip']: + active_modules = ['clip'] + if device_state_preset in ['generate']: + active_modules = ['vae', 'unet', + 'text_encoder', 'adapter', 'refiner_unet'] + + state = copy.deepcopy(empty_preset) + # vae + state['vae'] = { + 'training': 'vae' in training_modules, + 'device': self.vae_device_torch if 'vae' in active_modules else 'cpu', + 'requires_grad': 'vae' in training_modules, + } + + # unet + state['unet'] = { + 'training': 'unet' in training_modules, + 'device': self.device_torch if 'unet' in active_modules else 'cpu', + 'requires_grad': 'unet' in training_modules, + } + + if self.refiner_unet is not None: + state['refiner_unet'] = { + 'training': 'refiner_unet' in training_modules, + 'device': self.device_torch if 'refiner_unet' in active_modules else 'cpu', + 'requires_grad': 'refiner_unet' in training_modules, + } + + # text encoder + if isinstance(self.text_encoder, list): + state['text_encoder'] = [] + for i, encoder in enumerate(self.text_encoder): + state['text_encoder'].append({ + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + }) + else: + state['text_encoder'] = { + 'training': 'text_encoder' in training_modules, + 'device': self.te_device_torch if 'text_encoder' in active_modules else 'cpu', + 'requires_grad': 'text_encoder' in training_modules, + } + + if self.adapter is not None: + state['adapter'] = { + 'training': 'adapter' in training_modules, + 'device': self.device_torch if 'adapter' in active_modules else 'cpu', + 'requires_grad': 'adapter' in training_modules, + } + + self.set_device_state(state) + + def text_encoder_to(self, *args, **kwargs): + if isinstance(self.text_encoder, list): + for encoder in self.text_encoder: + encoder.to(*args, **kwargs) + else: + self.text_encoder.to(*args, **kwargs) diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py new file mode 100644 index 00000000..b52017a2 --- /dev/null +++ b/toolkit/models/wan21.py @@ -0,0 +1,56 @@ + +import torch +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds +from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline + +class Wan21(BaseModel): + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + # these must be implemented in child classes + + def load_model(self): + # override this in child classes + raise NotImplementedError( + "load_model must be implemented in child classes") + + def get_generation_pipeline(self): + # override this in child classes + raise NotImplementedError( + "get_generation_pipeline must be implemented in child classes") + + def generate_single_image( + self, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # override this in child classes + raise NotImplementedError( + "generate_single_image must be implemented in child classes") + + def get_noise_prediction( + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + raise NotImplementedError( + "get_noise_prediction must be implemented in child classes") + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + raise NotImplementedError( + "get_prompt_embeds must be implemented in child classes") diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 00b38574..38e15825 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -29,7 +29,7 @@ from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae from toolkit import train_tools -from toolkit.config_modules import ModelConfig, GenerateImageConfig +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch from toolkit.metadata import get_meta_for_safetensors from toolkit.models.decorator import Decorator from toolkit.paths import REPOS_ROOT, KEYMAPS_ROOT @@ -177,16 +177,17 @@ class StableDiffusion: self.network = None self.adapter: Union['ControlNetModel', 'T2IAdapter', 'IPAdapter', 'ReferenceAdapter', None] = None self.decorator: Union[Decorator, None] = None - self.is_xl = model_config.is_xl - self.is_v2 = model_config.is_v2 - self.is_ssd = model_config.is_ssd - self.is_v3 = model_config.is_v3 - self.is_vega = model_config.is_vega - self.is_pixart = model_config.is_pixart - self.is_auraflow = model_config.is_auraflow - self.is_flux = model_config.is_flux - self.is_flex2 = model_config.is_flex2 - self.is_lumina2 = model_config.is_lumina2 + self.arch: ModelArch = model_config.arch + # self.is_xl = model_config.is_xl + # self.is_v2 = model_config.is_v2 + # self.is_ssd = model_config.is_ssd + # self.is_v3 = model_config.is_v3 + # self.is_vega = model_config.is_vega + # self.is_pixart = model_config.is_pixart + # self.is_auraflow = model_config.is_auraflow + # self.is_flux = model_config.is_flux + # self.is_flex2 = model_config.is_flex2 + # self.is_lumina2 = model_config.is_lumina2 self.use_text_encoder_1 = model_config.use_text_encoder_1 self.use_text_encoder_2 = model_config.use_text_encoder_2 @@ -204,6 +205,47 @@ class StableDiffusion: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + + # properties for old arch for backwards compatibility + @property + def is_xl(self): + return self.arch == 'sdxl' + + @property + def is_v2(self): + return self.arch == 'sd2' + + @property + def is_ssd(self): + return self.arch == 'ssd' + + @property + def is_v3(self): + return self.arch == 'sd3' + + @property + def is_vega(self): + return self.arch == 'vega' + + @property + def is_pixart(self): + return self.arch == 'pixart' + + @property + def is_auraflow(self): + return self.arch == 'auraflow' + + @property + def is_flux(self): + return self.arch == 'flux' + + @property + def is_flex2(self): + return self.arch == 'flex2' + + @property + def is_lumina2(self): + return self.arch == 'lumina2' def load_model(self): if self.is_loaded: diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py new file mode 100644 index 00000000..b22d52c5 --- /dev/null +++ b/toolkit/util/get_model.py @@ -0,0 +1,9 @@ +from toolkit.stable_diffusion_model import StableDiffusion +from toolkit.config_modules import ModelConfig + +def get_model_class(config: ModelConfig): + if config.arch == "wan21": + from toolkit.models.wan21 import Wan21 + return Wan21 + else: + return StableDiffusion \ No newline at end of file From f5e40dfa62ffeb4a7e3e5b45efd9861b978aec75 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sat, 1 Mar 2025 16:12:52 -0700 Subject: [PATCH 2/8] WIP on wan --- .gitmodules | 8 ++++++++ repositories/wan21 | 1 + requirements.txt | 3 +-- toolkit/models/wan21.py | 39 +++++++++++++++++++++++++++++++++++---- 4 files changed, 45 insertions(+), 6 deletions(-) create mode 160000 repositories/wan21 diff --git a/.gitmodules b/.gitmodules index 657cf28b..ea80e2af 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,20 @@ [submodule "repositories/sd-scripts"] path = repositories/sd-scripts url = https://github.com/kohya-ss/sd-scripts.git + commit = b78c0e2a69e52ce6c79abc6c8c82d1a9cabcf05c [submodule "repositories/leco"] path = repositories/leco url = https://github.com/p1atdev/LECO + commit = 9294adf40218e917df4516737afb13f069a6789d [submodule "repositories/batch_annotator"] path = repositories/batch_annotator url = https://github.com/ostris/batch-annotator + commit = 420e142f6ad3cc14b3ea0500affc2c6c7e7544bf [submodule "repositories/ipadapter"] path = repositories/ipadapter url = https://github.com/tencent-ailab/IP-Adapter.git + commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14 +[submodule "repositories/wan21"] + path = repositories/wan21 + url = https://github.com/Wan-Video/Wan2.1.git + commit = a326079926a4a347ecda8863dc40ba2d7680a294 \ No newline at end of file diff --git a/repositories/wan21 b/repositories/wan21 new file mode 160000 index 00000000..a3260799 --- /dev/null +++ b/repositories/wan21 @@ -0,0 +1 @@ +Subproject commit a326079926a4a347ecda8863dc40ba2d7680a294 diff --git a/requirements.txt b/requirements.txt index f521b379..4040e760 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ torch==2.5.1 torchvision==0.20.1 safetensors -# https://github.com/huggingface/diffusers/pull/10921 -git+https://github.com/huggingface/diffusers@refs/pull/10921/head +git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec transformers lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index b52017a2..f30c1cd6 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -3,7 +3,38 @@ import torch from toolkit.config_modules import GenerateImageConfig, ModelConfig from toolkit.models.base_model import BaseModel from toolkit.prompt_utils import PromptEmbeds -from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanPipeline +from toolkit.paths import REPOS_ROOT +import sys +import os + +import gc +import logging +import math +import os +import random +import sys +import types +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + + +WAN_ROOT = os.path.join(REPOS_ROOT, "wan21") +sys.path.append(WAN_ROOT) + +if True: + from wan.text2video import WanT2V + from wan.distributed.fsdp import shard_model + from wan.modules.model import WanModel + from wan.modules.t5 import T5EncoderModel + from wan.modules.vae import WanVAE + from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, retrieve_timesteps) + from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler class Wan21(BaseModel): def __init__( @@ -21,9 +52,9 @@ class Wan21(BaseModel): # these must be implemented in child classes def load_model(self): - # override this in child classes - raise NotImplementedError( - "load_model must be implemented in child classes") + self.pipeline = Wan21( + + ) def get_generation_pipeline(self): # override this in child classes From e7dbb20f68746dc03cf30feebfd7a2002b231410 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 4 Mar 2025 00:29:19 -0700 Subject: [PATCH 3/8] Removed wan submodule for now --- .gitmodules | 4 ---- repositories/wan21 | 1 - 2 files changed, 5 deletions(-) delete mode 160000 repositories/wan21 diff --git a/.gitmodules b/.gitmodules index ea80e2af..a98073dc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -14,7 +14,3 @@ path = repositories/ipadapter url = https://github.com/tencent-ailab/IP-Adapter.git commit = 5a18b1f3660acaf8bee8250692d6fb3548a19b14 -[submodule "repositories/wan21"] - path = repositories/wan21 - url = https://github.com/Wan-Video/Wan2.1.git - commit = a326079926a4a347ecda8863dc40ba2d7680a294 \ No newline at end of file diff --git a/repositories/wan21 b/repositories/wan21 deleted file mode 160000 index a3260799..00000000 --- a/repositories/wan21 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit a326079926a4a347ecda8863dc40ba2d7680a294 From c57434ad7b6c2f672e91c42c4679966d47c10c84 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 4 Mar 2025 00:32:24 -0700 Subject: [PATCH 4/8] Removed wan submodule stuff for now --- toolkit/models/wan21.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index f30c1cd6..b9a98400 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -1,4 +1,4 @@ - +# WIP, coming soon ish import torch from toolkit.config_modules import GenerateImageConfig, ModelConfig from toolkit.models.base_model import BaseModel @@ -23,19 +23,6 @@ import torch.distributed as dist from tqdm import tqdm -WAN_ROOT = os.path.join(REPOS_ROOT, "wan21") -sys.path.append(WAN_ROOT) - -if True: - from wan.text2video import WanT2V - from wan.distributed.fsdp import shard_model - from wan.modules.model import WanModel - from wan.modules.t5 import T5EncoderModel - from wan.modules.vae import WanVAE - from wan.utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) - from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - class Wan21(BaseModel): def __init__( self, From 6f6fb9081283d9a437fdb35f44329378d6fb1e08 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 4 Mar 2025 18:43:52 -0700 Subject: [PATCH 5/8] Added cogview4. Loss still needs work. --- extensions_built_in/sd_trainer/SDTrainer.py | 12 +- jobs/process/BaseSDTrainProcess.py | 51 ++- requirements.txt | 4 +- testing/test_vae.py | 23 +- toolkit/config_modules.py | 2 +- toolkit/lora_special.py | 12 +- toolkit/models/base_model.py | 117 ++--- toolkit/models/cogview4.py | 458 +++++++++++++++++++ toolkit/models/wan21.py | 14 +- toolkit/samplers/custom_flowmatch_sampler.py | 106 +++-- toolkit/stable_diffusion_model.py | 11 +- toolkit/util/get_model.py | 3 + 12 files changed, 661 insertions(+), 152 deletions(-) create mode 100644 toolkit/models/cogview4.py diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 827e05c4..3fe2f2f5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess): elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) - + + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + # forward ODE target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() else: target = noise diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b1493dbd..77122af9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -668,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # # prepare all the models stuff for accelerator (hopefully we dont miss any) self.sd.vae = self.accelerator.prepare(self.sd.vae) if self.sd.unet is not None: - self.sd.unet_unwrapped = self.sd.unet self.sd.unet = self.accelerator.prepare(self.sd.unet) # todo always tdo it? self.modules_being_trained.append(self.sd.unet) @@ -1105,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess): if timestep_type is None: timestep_type = self.train_config.timestep_type + patch_size = 1 + if self.sd.is_flux: + # flux is a patch size of 1, but latents are divided by 2, so we need to double it + patch_size = 2 + elif hasattr(self.sd.unet.config, 'patch_size'): + patch_size = self.sd.unet.config.patch_size + self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, timestep_type=timestep_type, - latents=latents + latents=latents, + patch_size=patch_size, ) else: self.sd.noise_scheduler.set_timesteps( @@ -1403,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.name_or_path = latest_save_path self.load_training_state_from_metadata(latest_save_path) - # get the noise scheduler - arch = 'sd' - if self.model_config.is_pixart: - arch = 'pixart' - if self.model_config.is_flux: - arch = 'flux' - if self.model_config.is_lumina2: - arch = 'lumina2' - sampler = get_sampler( - self.train_config.noise_scheduler, - { - "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", - }, - arch=arch, - ) + ModelClass = get_model_class(self.model_config) + # if the model class has get_train_scheduler static method + if hasattr(ModelClass, 'get_train_scheduler'): + sampler = ModelClass.get_train_scheduler() + else: + # get the noise scheduler + arch = 'sd' + if self.model_config.is_pixart: + arch = 'pixart' + if self.model_config.is_flux: + arch = 'flux' + if self.model_config.is_lumina2: + arch = 'lumina2' + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + arch=arch, + ) if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') @@ -1425,7 +1437,6 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.refiner_name_or_path = previous_refiner_save self.load_training_state_from_metadata(previous_refiner_save) - ModelClass = get_model_class(self.model_config) self.sd = ModelClass( device=self.device, model_config=model_config_to_load, @@ -1562,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess): # if is_lycoris: # preset = PRESET['full'] # NetworkClass.apply_preset(preset) + + if hasattr(self.sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = self.sd.target_lora_modules self.network = NetworkClass( text_encoder=text_encoder, @@ -1590,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess): network_config=self.network_config, network_type=self.network_config.type, transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, **network_kwargs ) diff --git a/requirements.txt b/requirements.txt index 4040e760..d25678d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch==2.5.1 torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec -transformers +git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56 +transformers==4.49.0 lycoris-lora==1.8.3 flatten_json pyyaml diff --git a/testing/test_vae.py b/testing/test_vae.py index 44b31f63..463ab555 100644 --- a/testing/test_vae.py +++ b/testing/test_vae.py @@ -29,7 +29,7 @@ def paramiter_count(model): return int(paramiter_count) -def calculate_metrics(vae, images, max_imgs=-1): +def calculate_metrics(vae, images, max_imgs=-1, save_output=False): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") vae = vae.to(device) lpips_model = lpips.LPIPS(net='alex').to(device) @@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1): # ]) # needs values between -1 and 1 to_tensor = ToTensor() + + # remove _reconstructed.png files + images = [img for img in images if not img.endswith("_reconstructed.png")] if max_imgs > 0 and len(images) > max_imgs: images = images[:max_imgs] @@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1): avg_rfid = 0 avg_psnr = sum(psnr_scores) / len(psnr_scores) avg_lpips = sum(lpips_scores) / len(lpips_scores) + + if save_output: + filename_no_ext = os.path.splitext(os.path.basename(img_path))[0] + folder = os.path.dirname(img_path) + save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png") + reconstructed = (reconstructed + 1) / 2 + reconstructed = reconstructed.clamp(0, 1) + reconstructed = transforms.ToPILImage()(reconstructed[0].cpu()) + reconstructed.save(save_path) return avg_rfid, avg_psnr, avg_lpips @@ -91,18 +103,23 @@ def main(): parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + # boolean store true + parser.add_argument("--save_output", action="store_true", help="Save the output images") args = parser.parse_args() if os.path.isfile(args.vae_path): vae = AutoencoderKL.from_single_file(args.vae_path) else: - vae = AutoencoderKL.from_pretrained(args.vae_path) + try: + vae = AutoencoderKL.from_pretrained(args.vae_path) + except: + vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.eval() vae = vae.to(device) print(f"Model has {paramiter_count(vae)} parameters") images = load_images(args.image_folder) - avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output) # print(f"Average rFID: {avg_rfid}") print(f"Average PSNR: {avg_psnr}") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a9aa8cf8..e92f7cbe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -513,7 +513,7 @@ class ModelConfig: self.te_name_or_path = kwargs.get("te_name_or_path", None) - self.arch: ModelArch = kwargs.get("model_arch", None) + self.arch: ModelArch = kwargs.get("arch", None) # handle migrating to new model arch if self.arch is None: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 1b308cd4..84ac02b1 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): transformer_only: bool = False, peft_format: bool = False, is_assistant_adapter: bool = False, + is_transformer: bool = False, **kwargs ) -> None: """ @@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.network_config: NetworkConfig = kwargs.get("network_config", None) self.peft_format = peft_format + self.is_transformer = is_transformer + # always do peft for flux only for now - if self.is_flux or self.is_v3 or self.is_lumina2: + if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer: # don't do peft format for lokr if self.network_type.lower() != "lokr": self.peft_format = True @@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2: + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: unet_prefix = f"lora_transformer" if self.peft_format: unet_prefix = "transformer" @@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if self.transformer_only and self.is_v3 and is_unet: if "transformer_blocks" not in lora_name: skip = True + + # handle custom models + if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True if (is_linear or is_conv2d) and not skip: diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index adc4e882..cae29ffd 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -168,11 +168,17 @@ class BaseModel: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + self.is_transformer = False # properties for old arch for backwards compatibility @property def unet(self): return self.model + + # set unet to model + @unet.setter + def unet(self, value): + self.model = value @property def unet_unwrapped(self): @@ -235,6 +241,7 @@ class BaseModel: def generate_single_image( self, + pipeline, gen_config: GenerateImageConfig, conditional_embeds: PromptEmbeds, unconditional_embeds: PromptEmbeds, @@ -257,6 +264,25 @@ class BaseModel: def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: raise NotImplementedError( "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") + + def save_model(self, output_path, meta, save_dtype): + # todo handle dtype without overloading anything (vram, cpu, etc) + unwrap_model(self.pipeline).save_pretrained( + save_directory=output_path, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) # end must be implemented in child classes def te_train(self): @@ -512,6 +538,7 @@ class BaseModel: self.device_torch, dtype=self.unet.dtype) img = self.generate_single_image( + pipeline, gen_config, conditional_embeds, unconditional_embeds, @@ -603,7 +630,8 @@ class BaseModel: self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor + timesteps: torch.IntTensor, + **kwargs, ) -> torch.FloatTensor: original_samples_chunks = torch.chunk( original_samples, original_samples.shape[0], dim=0) @@ -1071,7 +1099,7 @@ class BaseModel: for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): named_params[name] = param if unet: - if self.is_flux or self.is_lumina2: + if self.is_flux or self.is_lumina2 or self.is_transformer: for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): named_params[name] = param else: @@ -1105,59 +1133,11 @@ class BaseModel: return named_params def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): - version_string = '1' - if self.is_v2: - version_string = '2' - if self.is_xl: - version_string = 'sdxl' - if self.is_ssd: - # overwrite sdxl because both wil be true here - version_string = 'ssd' - if self.is_ssd and self.is_vega: - version_string = 'vega' - # if output file does not end in .safetensors, then it is a directory and we are - # saving in diffusers format - if not output_file.endswith('.safetensors'): - # diffusers - if self.is_flux: - # only save the unet - transformer: FluxTransformer2DModel = unwrap_model(self.unet) - transformer.save_pretrained( - save_directory=os.path.join(output_file, 'transformer'), - safe_serialization=True, - ) - elif self.is_lumina2: - # only save the unet - transformer: Lumina2Transformer2DModel = unwrap_model( - self.unet) - transformer.save_pretrained( - save_directory=os.path.join(output_file, 'transformer'), - safe_serialization=True, - ) - - else: - - self.pipeline.save_pretrained( - save_directory=output_file, - safe_serialization=True, - ) - # save out meta config - meta_path = os.path.join(output_file, 'aitk_meta.yaml') - with open(meta_path, 'w') as f: - yaml.dump(meta, f) - - else: - save_ldm_model_from_diffusers( - sd=self, - output_file=output_file, - meta=meta, - save_dtype=save_dtype, - sd_version=version_string, - ) - if self.config_file is not None: - output_path_no_ext = os.path.splitext(output_file)[0] - output_config_path = f"{output_path_no_ext}.yaml" - shutil.copyfile(self.config_file, output_config_path) + self.save_model( + output_path=output_file, + meta=meta, + save_dtype=save_dtype + ) def prepare_optimizer_params( self, @@ -1240,12 +1220,7 @@ class BaseModel: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it - if self.is_lumina2: - unet_has_grad = self.unet.x_embedder.weight.requires_grad - elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: - unet_has_grad = self.unet.proj_out.weight.requires_grad - else: - unet_has_grad = self.unet.conv_in.weight.requires_grad + unet_has_grad = self.get_model_has_grad() self.device_state = { **empty_preset, @@ -1262,13 +1237,7 @@ class BaseModel: if isinstance(self.text_encoder, list): self.device_state['text_encoder']: List[dict] = [] for encoder in self.text_encoder: - if isinstance(encoder, LlamaModel): - te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad - else: - try: - te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad - except: - te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + te_has_grad = self.get_te_has_grad() self.device_state['text_encoder'].append({ 'training': encoder.training, 'device': encoder.device, @@ -1276,17 +1245,7 @@ class BaseModel: 'requires_grad': te_has_grad }) else: - if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): - te_has_grad = self.text_encoder.encoder.block[ - 0].layer[0].SelfAttention.q.weight.requires_grad - elif isinstance(self.text_encoder, Gemma2Model): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - elif isinstance(self.text_encoder, Qwen2Model): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - elif isinstance(self.text_encoder, LlamaModel): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - else: - te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + te_has_grad = self.get_te_has_grad() self.device_state['text_encoder'] = { 'training': self.text_encoder.training, diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py new file mode 100644 index 00000000..51d87a55 --- /dev/null +++ b/toolkit/models/cogview4.py @@ -0,0 +1,458 @@ +import weakref +from diffusers import CogView4Pipeline +import torch +import yaml + +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +import torch +import diffusers +from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline +from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from transformers import GlmModel, AutoTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler +from typing import TYPE_CHECKING +from toolkit.accelerator import unwrap_model + +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# remove this after a bug is fixed in diffusers code. This is a workaround. + + +class FakeModel: + def __init__(self, model): + self.model_ref = weakref.ref(model) + pass + + @property + def device(self): + return self.model_ref().device + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.25, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 0.75, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "time_shift_type": "linear", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + + +class CogView4(BaseModel): + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['CogView4Transformer2DModel'] + + # cache for holding noise + self.effective_noise = None + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def load_model(self): + dtype = self.torch_dtype + base_model_path = "THUDM/CogView4-6B" + model_path = self.model_config.name_or_path + + # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + self.print_and_status_update("Loading CogView4 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading GlmModel") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = GlmModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing GlmModel") + quantize(text_encoder, weights=qfloat8) + freeze(text_encoder) + flush() + + # hack to fix diffusers bug workaround + text_encoder.model = FakeModel(text_encoder) + + self.print_and_status_update("Loading transformer") + transformer = CogView4Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for CogViewModels models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for CogViewModels models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for CogViewModels models currently") + + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = CogView4.get_train_scheduler() + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + self.print_and_status_update("Making pipe") + pipe: CogView4Pipeline = CogView4Pipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + self.pipeline = pipe + self.model = transformer + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + def get_generation_pipeline(self): + scheduler = CogView4.get_train_scheduler() + pipeline = CogView4Pipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + return pipeline + + def generate_single_image( + self, + pipeline: CogView4Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional + # they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine. + if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]: + pad_len = unconditional_embeds.text_embeds.shape[1] - \ + conditional_embeds.text_embeds.shape[1] + conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len, + conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1) + elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]: + pad_len = conditional_embeds.text_embeds.shape[1] - \ + unconditional_embeds.text_embeds.shape[1] + unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len, + unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1) + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # target_size = (height, width) + target_size = latent_model_input.shape[-2:] + # multiply by 8 + target_size = (target_size[0] * 8, target_size[1] * 8) + crops_coords_top_left = torch.tensor( + [(0, 0)], dtype=self.torch_dtype, device=self.device_torch) + + original_size = torch.tensor( + [target_size], dtype=self.torch_dtype, device=self.device_torch) + target_size = original_size.clone() + noise_pred_cond = self.model( + hidden_states=latent_model_input, # torch.Size([1, 16, 128, 128]) + encoder_hidden_states=text_embeddings.text_embeds, # torch.Size([1, 16, 4096]) + timestep=timestep, + original_size=original_size, # [[1024., 1024.]] + target_size=target_size, # [[1024., 1024.]] + crop_coords=crops_coords_top_left, # [[0., 0.]] + return_dict=False, + )[0] + return noise_pred_cond + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + dtype=self.torch_dtype, + ) + return PromptEmbeds(prompt_embeds) + + def get_model_has_grad(self): + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: CogView4Transformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + effective_noise = self.effective_noise + batch = kwargs.get('batch') + if batch is None: + raise ValueError("Batch is not provided") + if noise is None: + raise ValueError("Noise is not provided") + # return batch.latents + return (noise - batch.latents).detach() + # return (effective_noise - batch.latents).detach() + + + def _get_low_res_latents(self, latents): + # todo prevent needing to do this and grab the tensor another way. + with torch.no_grad(): + # Decode latents to image space + images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype) + + # Downsample by a factor of 2 using bilinear interpolation + B, C, H, W = images.shape + low_res_images = torch.nn.functional.interpolate( + images, + size=(H // 4, W // 4), + mode="bilinear", + align_corners=False + ) + + # Upsample back to original resolution to match expected VAE input dimensions + upsampled_low_res_images = torch.nn.functional.interpolate( + low_res_images, + size=(H, W), + mode="bilinear", + align_corners=False + ) + + # Encode the low-resolution images back to latent space + low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype) + return low_res_latents + + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents_chunks = None + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # # Flowmatching interpolation between original and noise + # if t > relay_start_point: + # # Standard flowmatching - direct linear interpolation + # noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx] + # effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise + # else: + # # Relay flowmatching case - only compute low_res_latents if needed + # if low_res_latents_chunks is None: + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Calculate the relay ratio (0 to 1) + # t_ratio = t.float() / relay_start_point + # t_ratio = torch.clamp(t_ratio, 0.0, 1.0) + + # # First blend between original and low-res based on t_ratio + # z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx] + + # added_lor_res_noise = z0_t - original_samples_chunks[idx] + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx] + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents + + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx] + # lrln = lrln * (1 - t_01) + + # # make the noise an interpolation between noise and low_res_latents with + # # being noise at t_01=1 and low_res_latents at t_01=0 + # # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln + # new_noise = noise_chunks[idx] + lrln + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(new_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index b9a98400..045e9b1c 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -36,12 +36,11 @@ class Wan21(BaseModel): super().__init__(device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs) self.is_flow_matching = True + raise NotImplementedError("Wan21 is not implemented yet") # these must be implemented in child classes def load_model(self): - self.pipeline = Wan21( - - ) + pass def get_generation_pipeline(self): # override this in child classes @@ -50,6 +49,7 @@ class Wan21(BaseModel): def generate_single_image( self, + pipeline, gen_config: GenerateImageConfig, conditional_embeds: PromptEmbeds, unconditional_embeds: PromptEmbeds, @@ -72,3 +72,11 @@ class Wan21(BaseModel): def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: raise NotImplementedError( "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 2a7a1cfd..f0dba4e7 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) # flatten second half to max - hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() + hbsmntw_weighing[num_timesteps // + 2:] = hbsmntw_weighing[num_timesteps // 2:].max() # Create linear timesteps from 1000 to 0 timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') @@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(self.timesteps == t).nonzero().item() + for t in timesteps] # Get the weights for the timesteps if v2: @@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): sigmas = self.sigmas.to(device=device, dtype=dtype) schedule_timesteps = self.timesteps.to(device) timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero().item() + for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: @@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise t_01 = (timesteps / 1000).to(original_samples.device) + # forward ODE noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + # reverse ODE + # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples return noisy_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: return sample - def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None): + def set_train_timesteps( + self, + num_timesteps, + device, + timestep_type='linear', + latents=None, + patch_size=1 + ): self.timestep_type = timestep_type if timestep_type == 'linear': timesteps = torch.linspace(1000, 0, num_timesteps, device=device) @@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): self.timesteps = timesteps.to(device=device) return timesteps - elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift': + elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']: # matches inference dynamic shifting timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps + self._sigma_to_t(self.sigma_max), self._sigma_to_t( + self.sigma_min), num_timesteps ) sigmas = timesteps / self.config.num_train_timesteps - - if latents is None: - raise ValueError('latents is None') - - h = latents.shape[2] // 2 # Divide by ph - w = latents.shape[3] // 2 # Divide by pw - image_seq_len = h * w - # todo need to know the mu for the shift - mu = calculate_shift( - image_seq_len, - self.config.get("base_image_seq_len", 256), - self.config.get("max_image_seq_len", 4096), - self.config.get("base_shift", 0.5), - self.config.get("max_shift", 1.16), - ) - sigmas = self.time_shift(mu, 1.0, sigmas) + if self.config.use_dynamic_shifting: + if latents is None: + raise ValueError('latents is None') - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + # for flux we double up the patch size before sending her to simulate the latent reduction + h = latents.shape[2] + w = latents.shape[3] + image_seq_len = h * w // (patch_size**2) + + mu = calculate_shift( + image_seq_len, + self.config.get("base_image_seq_len", 256), + self.config.get("max_image_seq_len", 4096), + self.config.get("base_shift", 0.5), + self.config.get("max_shift", 1.16), + ) + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + + sigmas = torch.from_numpy(sigmas).to( + dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat( + [sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat( + [sigmas, torch.zeros(1, device=sigmas.device)]) self.timesteps = timesteps.to(device=device) self.sigmas = sigmas - + self.timesteps = timesteps.to(device=device) return timesteps - + elif timestep_type == 'lognorm_blend': # disgtribute timestepd to the center/early and blend in linear alpha = 0.75 @@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): t1 = ((1 - t1/t1.max()) * 1000) # add half of linear - t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device) + t2 = torch.linspace(1000, 0, int( + num_timesteps * (1 - alpha)), device=device) timesteps = torch.cat((t1, t2)) # Sort the timesteps in descending order diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 1e850bb7..bf93e84c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -160,7 +160,6 @@ class StableDiffusion: self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] - self.unet_unwrapped: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler @@ -205,6 +204,8 @@ class StableDiffusion: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + # todo update this based on the model + self.is_transformer = False # properties for old arch for backwards compatibility @property @@ -246,6 +247,10 @@ class StableDiffusion: @property def is_lumina2(self): return self.arch == 'lumina2' + + @property + def unet_unwrapped(self): + return unwrap_model(self.unet) def load_model(self): if self.is_loaded: @@ -977,7 +982,6 @@ class StableDiffusion: if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: # pixart and sd3 dont use a unet self.unet = pipe.transformer - self.unet_unwrapped = pipe.transformer else: self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) @@ -1776,7 +1780,8 @@ class StableDiffusion: self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor + timesteps: torch.IntTensor, + **kwargs, ) -> torch.FloatTensor: original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py index b22d52c5..4d1668f8 100644 --- a/toolkit/util/get_model.py +++ b/toolkit/util/get_model.py @@ -5,5 +5,8 @@ def get_model_class(config: ModelConfig): if config.arch == "wan21": from toolkit.models.wan21 import Wan21 return Wan21 + elif config.arch == "cogview4": + from toolkit.models.cogview4 import CogView4 + return CogView4 else: return StableDiffusion \ No newline at end of file From aa44828c0caf068b9889686ebea190539040a222 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 5 Mar 2025 09:43:00 -0700 Subject: [PATCH 6/8] WIP more work on cogview4 --- requirements.txt | 2 +- toolkit/models/cogview4.py | 111 ++++++++----------- toolkit/samplers/custom_flowmatch_sampler.py | 2 +- 3 files changed, 51 insertions(+), 64 deletions(-) diff --git a/requirements.txt b/requirements.txt index d25678d2..cef9b658 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch==2.5.1 torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56 +git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7 transformers==4.49.0 lycoris-lora==1.8.3 flatten_json diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py index 51d87a55..902886bb 100644 --- a/toolkit/models/cogview4.py +++ b/toolkit/models/cogview4.py @@ -20,7 +20,6 @@ from transformers import GlmModel, AutoTokenizer from diffusers import FlowMatchEulerDiscreteScheduler from typing import TYPE_CHECKING from toolkit.accelerator import unwrap_model - from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler if TYPE_CHECKING: @@ -71,7 +70,7 @@ class CogView4(BaseModel): self.is_flow_matching = True self.is_transformer = True self.target_lora_modules = ['CogView4Transformer2DModel'] - + # cache for holding noise self.effective_noise = None @@ -86,7 +85,6 @@ class CogView4(BaseModel): base_model_path = "THUDM/CogView4-6B" model_path = self.model_config.name_or_path - # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) self.print_and_status_update("Loading CogView4 model") # base_model_path = "black-forest-labs/FLUX.1-schnell" base_model_path = self.model_config.name_or_path_original @@ -213,19 +211,6 @@ class CogView4(BaseModel): generator: torch.Generator, extra: dict, ): - # there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional - # they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine. - if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]: - pad_len = unconditional_embeds.text_embeds.shape[1] - \ - conditional_embeds.text_embeds.shape[1] - conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len, - conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1) - elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]: - pad_len = conditional_embeds.text_embeds.shape[1] - \ - unconditional_embeds.text_embeds.shape[1] - unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len, - unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1) - img = pipeline( prompt_embeds=conditional_embeds.text_embeds.to( self.device_torch, dtype=self.torch_dtype), @@ -259,12 +244,12 @@ class CogView4(BaseModel): [target_size], dtype=self.torch_dtype, device=self.device_torch) target_size = original_size.clone() noise_pred_cond = self.model( - hidden_states=latent_model_input, # torch.Size([1, 16, 128, 128]) - encoder_hidden_states=text_embeddings.text_embeds, # torch.Size([1, 16, 4096]) + hidden_states=latent_model_input, + encoder_hidden_states=text_embeddings.text_embeds, timestep=timestep, - original_size=original_size, # [[1024., 1024.]] - target_size=target_size, # [[1024., 1024.]] - crop_coords=crops_coords_top_left, # [[0., 0.]] + original_size=original_size, + target_size=target_size, + crop_coords=crops_coords_top_left, return_dict=False, )[0] return noise_pred_cond @@ -283,9 +268,9 @@ class CogView4(BaseModel): def get_te_has_grad(self): return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad - + def save_model(self, output_path, meta, save_dtype): - # only save the unet + # only save the unet transformer: CogView4Transformer2DModel = unwrap_model(self.model) transformer.save_pretrained( save_directory=os.path.join(output_path, 'transformer'), @@ -295,7 +280,7 @@ class CogView4(BaseModel): meta_path = os.path.join(output_path, 'aitk_meta.yaml') with open(meta_path, 'w') as f: yaml.dump(meta, f) - + def get_loss_target(self, *args, **kwargs): noise = kwargs.get('noise') effective_noise = self.effective_noise @@ -305,25 +290,27 @@ class CogView4(BaseModel): if noise is None: raise ValueError("Noise is not provided") # return batch.latents + # return (batch.latents - noise).detach() return (noise - batch.latents).detach() + # return (batch.latents).detach() # return (effective_noise - batch.latents).detach() - - + def _get_low_res_latents(self, latents): - # todo prevent needing to do this and grab the tensor another way. + # todo prevent needing to do this and grab the tensor another way. with torch.no_grad(): # Decode latents to image space - images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype) - + images = self.decode_latents( + latents, device=latents.device, dtype=latents.dtype) + # Downsample by a factor of 2 using bilinear interpolation B, C, H, W = images.shape low_res_images = torch.nn.functional.interpolate( images, - size=(H // 4, W // 4), + size=(H // 2, W // 2), mode="bilinear", align_corners=False ) - + # Upsample back to original resolution to match expected VAE input dimensions upsampled_low_res_images = torch.nn.functional.interpolate( low_res_images, @@ -331,12 +318,12 @@ class CogView4(BaseModel): mode="bilinear", align_corners=False ) - + # Encode the low-resolution images back to latent space - low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype) + low_res_latents = self.encode_images( + upsampled_low_res_images, device=latents.device, dtype=latents.dtype) return low_res_latents - - + # def add_noise( # self, # original_samples: torch.FloatTensor, @@ -345,19 +332,19 @@ class CogView4(BaseModel): # **kwargs, # ) -> torch.FloatTensor: # relay_start_point = 500 - + # # Store original samples for loss calculation # self.original_samples = original_samples - + # # Prepare chunks for batch processing # original_samples_chunks = torch.chunk( # original_samples, original_samples.shape[0], dim=0) # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) - + # # Get the low res latents only if needed # low_res_latents_chunks = None - + # # Handle case where timesteps is a single value for all samples # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) @@ -368,7 +355,7 @@ class CogView4(BaseModel): # for idx in range(original_samples.shape[0]): # t = timesteps_chunks[idx] # t_01 = (t / 1000).to(original_samples_chunks[idx].device) - + # # Flowmatching interpolation between original and noise # if t > relay_start_point: # # Standard flowmatching - direct linear interpolation @@ -379,30 +366,29 @@ class CogView4(BaseModel): # if low_res_latents_chunks is None: # low_res_latents = self._get_low_res_latents(original_samples) # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) - + # # Calculate the relay ratio (0 to 1) # t_ratio = t.float() / relay_start_point # t_ratio = torch.clamp(t_ratio, 0.0, 1.0) - + # # First blend between original and low-res based on t_ratio # z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx] - + # added_lor_res_noise = z0_t - original_samples_chunks[idx] - + # # Then apply flowmatching interpolation between this blended state and noise # noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx] - + # # For prediction target, we need to store the effective "source" # effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise) - + # noisy_latents_chunks.append(noisy_latents) # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation - + # return noisy_latents - # def add_noise( # self, # original_samples: torch.FloatTensor, @@ -411,20 +397,20 @@ class CogView4(BaseModel): # **kwargs, # ) -> torch.FloatTensor: # relay_start_point = 500 - + # # Store original samples for loss calculation # self.original_samples = original_samples - + # # Prepare chunks for batch processing # original_samples_chunks = torch.chunk( # original_samples, original_samples.shape[0], dim=0) # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) - + # # Get the low res latents only if needed # low_res_latents = self._get_low_res_latents(original_samples) # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) - + # # Handle case where timesteps is a single value for all samples # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) @@ -435,24 +421,25 @@ class CogView4(BaseModel): # for idx in range(original_samples.shape[0]): # t = timesteps_chunks[idx] # t_01 = (t / 1000).to(original_samples_chunks[idx].device) - + # lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx] - # lrln = lrln * (1 - t_01) - - # # make the noise an interpolation between noise and low_res_latents with + # # lrln = lrln * (1 - t_01) + + # # make the noise an interpolation between noise and low_res_latents with # # being noise at t_01=1 and low_res_latents at t_01=0 - # # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln - # new_noise = noise_chunks[idx] + lrln - + # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln + # # new_noise = noise_chunks[idx] + lrln + # # new_noise = noise_chunks[idx] + lrln + # # Then apply flowmatching interpolation between this blended state and noise # noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise - + # # For prediction target, we need to store the effective "source" # effective_noise_chunks.append(new_noise) - + # noisy_latents_chunks.append(noisy_latents) # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation - + # return noisy_latents diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index f0dba4e7..1e0ae2ab 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -89,7 +89,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): ) -> torch.Tensor: t_01 = (timesteps / 1000).to(original_samples.device) # forward ODE - noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise # reverse ODE # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples return noisy_model_input From 4fe33f51c14b5ae21f1ec80b20beb3b12c861dbd Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 5 Mar 2025 13:44:40 -0700 Subject: [PATCH 7/8] Fix issue with picking layers for quantization, adjust layers fo better quantization of cogview4 --- toolkit/models/cogview4.py | 24 ++++++++++++-- toolkit/stable_diffusion_model.py | 3 +- toolkit/util/quantize.py | 55 +++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 toolkit/util/quantize.py diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py index 902886bb..62af5498 100644 --- a/toolkit/models/cogview4.py +++ b/toolkit/models/cogview4.py @@ -15,7 +15,8 @@ from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch import torch import diffusers from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline -from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize from transformers import GlmModel, AutoTokenizer from diffusers import FlowMatchEulerDiscreteScheduler from typing import TYPE_CHECKING @@ -142,12 +143,29 @@ class CogView4(BaseModel): flush() if self.model_config.quantize: + quantization_args = self.model_config.quantize_kwargs + if 'exclude' not in quantization_args: + quantization_args['exclude'] = [] + if 'include' not in quantization_args: + quantization_args['include'] = [] + + # Be more specific with the include pattern to exactly match transformer blocks + quantization_args['include'] += ["transformer_blocks.*"] + + # Exclude all LayerNorm layers within transformer blocks + quantization_args['exclude'] += [ + "transformer_blocks.*.norm1", + "transformer_blocks.*.norm2", + "transformer_blocks.*.norm2_context", + "transformer_blocks.*.attn1.norm_q", + "transformer_blocks.*.attn1.norm_k" + ] + # patch the state dict method patch_dequantization_on_save(transformer) quantization_type = qfloat8 self.print_and_status_update("Quantizing transformer") - quantize(transformer, weights=quantization_type, - **self.model_config.quantize_kwargs) + quantize(transformer, weights=quantization_type, **quantization_args) freeze(transformer) transformer.to(self.device_torch) else: diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index bf93e84c..65736178 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -64,7 +64,8 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT from huggingface_hub import hf_hub_download from toolkit.models.flux import add_model_gpu_splitter_to_flux, bypass_flux_guidance, restore_flux_guidance -from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from optimum.quanto import freeze, qfloat8, QTensor, qint4 +from toolkit.util.quantize import quantize from toolkit.accelerator import get_accelerator, unwrap_model from typing import TYPE_CHECKING from toolkit.print import print_acc diff --git a/toolkit/util/quantize.py b/toolkit/util/quantize.py new file mode 100644 index 00000000..fd7b3178 --- /dev/null +++ b/toolkit/util/quantize.py @@ -0,0 +1,55 @@ +from fnmatch import fnmatch +from typing import Any, Dict, List, Optional, Union +import torch + +from optimum.quanto.quantize import _quantize_submodule +from optimum.quanto.tensor import Optimizer, qtype + +# the quantize function in quanto had a bug where it was using exclude instead of include + + +def quantize( + model: torch.nn.Module, + weights: Optional[Union[str, qtype]] = None, + activations: Optional[Union[str, qtype]] = None, + optimizer: Optional[Optimizer] = None, + include: Optional[Union[str, List[str]]] = None, + exclude: Optional[Union[str, List[str]]] = None, +): + """Quantize the specified model submodules + + Recursively quantize the submodules of the specified parent model. + + Only modules that have quantized counterparts will be quantized. + + If include patterns are specified, the submodule name must match one of them. + + If exclude patterns are specified, the submodule must not match one of them. + + Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See + https://docs.python.org/3/library/fnmatch.html for more details. + + Note: quantization happens in-place and modifies the original model and its descendants. + + Args: + model (`torch.nn.Module`): the model whose submodules will be quantized. + weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization. + activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization. + include (`Optional[Union[str, List[str]]]`): + Patterns constituting the allowlist. If provided, module names must match at + least one pattern from the allowlist. + exclude (`Optional[Union[str, List[str]]]`): + Patterns constituting the denylist. If provided, module names must not match + any patterns from the denylist. + """ + if include is not None: + include = [include] if isinstance(include, str) else include + if exclude is not None: + exclude = [exclude] if isinstance(exclude, str) else exclude + for name, m in model.named_modules(): + if include is not None and not any(fnmatch(name, pattern) for pattern in include): + continue + if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude): + continue + _quantize_submodule(model, name, m, weights=weights, + activations=activations, optimizer=optimizer) From 763128ea42246fc714de65c0715ceca478c417c1 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Wed, 5 Mar 2025 14:46:11 -0700 Subject: [PATCH 8/8] Note about cogview --- toolkit/models/cogview4.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py index 62af5498..593fa977 100644 --- a/toolkit/models/cogview4.py +++ b/toolkit/models/cogview4.py @@ -1,3 +1,6 @@ +# DONT USE THIS!. IT DOES NOT WORK YET! +# Will revisit this when they release more info on how it was trained. + import weakref from diffusers import CogView4Pipeline import torch