diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 827e05c4..3fe2f2f5 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -380,9 +380,19 @@ class SDTrainer(BaseSDTrainProcess): elif self.sd.prediction_type == 'v_prediction': # v-parameterization training target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps) - + + elif hasattr(self.sd, 'get_loss_target'): + target = self.sd.get_loss_target( + noise=noise, + batch=batch, + timesteps=timesteps, + ).detach() + elif self.sd.is_flow_matching: + # forward ODE target = (noise - batch.latents).detach() + # reverse ODE + # target = (batch.latents - noise).detach() else: target = noise diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index b1493dbd..77122af9 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -668,7 +668,6 @@ class BaseSDTrainProcess(BaseTrainProcess): # # prepare all the models stuff for accelerator (hopefully we dont miss any) self.sd.vae = self.accelerator.prepare(self.sd.vae) if self.sd.unet is not None: - self.sd.unet_unwrapped = self.sd.unet self.sd.unet = self.accelerator.prepare(self.sd.unet) # todo always tdo it? self.modules_being_trained.append(self.sd.unet) @@ -1105,11 +1104,19 @@ class BaseSDTrainProcess(BaseTrainProcess): if timestep_type is None: timestep_type = self.train_config.timestep_type + patch_size = 1 + if self.sd.is_flux: + # flux is a patch size of 1, but latents are divided by 2, so we need to double it + patch_size = 2 + elif hasattr(self.sd.unet.config, 'patch_size'): + patch_size = self.sd.unet.config.patch_size + self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, timestep_type=timestep_type, - latents=latents + latents=latents, + patch_size=patch_size, ) else: self.sd.noise_scheduler.set_timesteps( @@ -1403,21 +1410,26 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.name_or_path = latest_save_path self.load_training_state_from_metadata(latest_save_path) - # get the noise scheduler - arch = 'sd' - if self.model_config.is_pixart: - arch = 'pixart' - if self.model_config.is_flux: - arch = 'flux' - if self.model_config.is_lumina2: - arch = 'lumina2' - sampler = get_sampler( - self.train_config.noise_scheduler, - { - "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", - }, - arch=arch, - ) + ModelClass = get_model_class(self.model_config) + # if the model class has get_train_scheduler static method + if hasattr(ModelClass, 'get_train_scheduler'): + sampler = ModelClass.get_train_scheduler() + else: + # get the noise scheduler + arch = 'sd' + if self.model_config.is_pixart: + arch = 'pixart' + if self.model_config.is_flux: + arch = 'flux' + if self.model_config.is_lumina2: + arch = 'lumina2' + sampler = get_sampler( + self.train_config.noise_scheduler, + { + "prediction_type": "v_prediction" if self.model_config.is_v_pred else "epsilon", + }, + arch=arch, + ) if self.train_config.train_refiner and self.model_config.refiner_name_or_path is not None and self.network_config is None: previous_refiner_save = self.get_latest_save_path(self.job.name + '_refiner') @@ -1425,7 +1437,6 @@ class BaseSDTrainProcess(BaseTrainProcess): model_config_to_load.refiner_name_or_path = previous_refiner_save self.load_training_state_from_metadata(previous_refiner_save) - ModelClass = get_model_class(self.model_config) self.sd = ModelClass( device=self.device, model_config=model_config_to_load, @@ -1562,6 +1573,9 @@ class BaseSDTrainProcess(BaseTrainProcess): # if is_lycoris: # preset = PRESET['full'] # NetworkClass.apply_preset(preset) + + if hasattr(self.sd, 'target_lora_modules'): + network_kwargs['target_lin_modules'] = self.sd.target_lora_modules self.network = NetworkClass( text_encoder=text_encoder, @@ -1590,6 +1604,7 @@ class BaseSDTrainProcess(BaseTrainProcess): network_config=self.network_config, network_type=self.network_config.type, transformer_only=self.network_config.transformer_only, + is_transformer=self.sd.is_transformer, **network_kwargs ) diff --git a/requirements.txt b/requirements.txt index 4040e760..d25678d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ torch==2.5.1 torchvision==0.20.1 safetensors -git+https://github.com/huggingface/diffusers@28f48f4051e80082cbe97f2d62b365dbb01040ec -transformers +git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56 +transformers==4.49.0 lycoris-lora==1.8.3 flatten_json pyyaml diff --git a/testing/test_vae.py b/testing/test_vae.py index 44b31f63..463ab555 100644 --- a/testing/test_vae.py +++ b/testing/test_vae.py @@ -29,7 +29,7 @@ def paramiter_count(model): return int(paramiter_count) -def calculate_metrics(vae, images, max_imgs=-1): +def calculate_metrics(vae, images, max_imgs=-1, save_output=False): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") vae = vae.to(device) lpips_model = lpips.LPIPS(net='alex').to(device) @@ -44,6 +44,9 @@ def calculate_metrics(vae, images, max_imgs=-1): # ]) # needs values between -1 and 1 to_tensor = ToTensor() + + # remove _reconstructed.png files + images = [img for img in images if not img.endswith("_reconstructed.png")] if max_imgs > 0 and len(images) > max_imgs: images = images[:max_imgs] @@ -82,6 +85,15 @@ def calculate_metrics(vae, images, max_imgs=-1): avg_rfid = 0 avg_psnr = sum(psnr_scores) / len(psnr_scores) avg_lpips = sum(lpips_scores) / len(lpips_scores) + + if save_output: + filename_no_ext = os.path.splitext(os.path.basename(img_path))[0] + folder = os.path.dirname(img_path) + save_path = os.path.join(folder, filename_no_ext + "_reconstructed.png") + reconstructed = (reconstructed + 1) / 2 + reconstructed = reconstructed.clamp(0, 1) + reconstructed = transforms.ToPILImage()(reconstructed[0].cpu()) + reconstructed.save(save_path) return avg_rfid, avg_psnr, avg_lpips @@ -91,18 +103,23 @@ def main(): parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model") parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images") parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.") + # boolean store true + parser.add_argument("--save_output", action="store_true", help="Save the output images") args = parser.parse_args() if os.path.isfile(args.vae_path): vae = AutoencoderKL.from_single_file(args.vae_path) else: - vae = AutoencoderKL.from_pretrained(args.vae_path) + try: + vae = AutoencoderKL.from_pretrained(args.vae_path) + except: + vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae") vae.eval() vae = vae.to(device) print(f"Model has {paramiter_count(vae)} parameters") images = load_images(args.image_folder) - avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs) + avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs, args.save_output) # print(f"Average rFID: {avg_rfid}") print(f"Average PSNR: {avg_psnr}") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index a9aa8cf8..e92f7cbe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -513,7 +513,7 @@ class ModelConfig: self.te_name_or_path = kwargs.get("te_name_or_path", None) - self.arch: ModelArch = kwargs.get("model_arch", None) + self.arch: ModelArch = kwargs.get("arch", None) # handle migrating to new model arch if self.arch is None: diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py index 1b308cd4..84ac02b1 100644 --- a/toolkit/lora_special.py +++ b/toolkit/lora_special.py @@ -178,6 +178,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): transformer_only: bool = False, peft_format: bool = False, is_assistant_adapter: bool = False, + is_transformer: bool = False, **kwargs ) -> None: """ @@ -237,9 +238,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): self.network_config: NetworkConfig = kwargs.get("network_config", None) self.peft_format = peft_format + self.is_transformer = is_transformer + # always do peft for flux only for now - if self.is_flux or self.is_v3 or self.is_lumina2: + if self.is_flux or self.is_v3 or self.is_lumina2 or is_transformer: # don't do peft format for lokr if self.network_type.lower() != "lokr": self.peft_format = True @@ -282,7 +285,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): unet_prefix = self.LORA_PREFIX_UNET if self.peft_format: unet_prefix = self.PEFT_PREFIX_UNET - if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2: + if is_pixart or is_v3 or is_auraflow or is_flux or is_lumina2 or self.is_transformer: unet_prefix = f"lora_transformer" if self.peft_format: unet_prefix = "transformer" @@ -341,6 +344,11 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork): if self.transformer_only and self.is_v3 and is_unet: if "transformer_blocks" not in lora_name: skip = True + + # handle custom models + if self.transformer_only and is_unet and hasattr(root_module, 'transformer_blocks'): + if "transformer_blocks" not in lora_name: + skip = True if (is_linear or is_conv2d) and not skip: diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index adc4e882..cae29ffd 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -168,11 +168,17 @@ class BaseModel: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + self.is_transformer = False # properties for old arch for backwards compatibility @property def unet(self): return self.model + + # set unet to model + @unet.setter + def unet(self, value): + self.model = value @property def unet_unwrapped(self): @@ -235,6 +241,7 @@ class BaseModel: def generate_single_image( self, + pipeline, gen_config: GenerateImageConfig, conditional_embeds: PromptEmbeds, unconditional_embeds: PromptEmbeds, @@ -257,6 +264,25 @@ class BaseModel: def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: raise NotImplementedError( "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") + + def save_model(self, output_path, meta, save_dtype): + # todo handle dtype without overloading anything (vram, cpu, etc) + unwrap_model(self.pipeline).save_pretrained( + save_directory=output_path, + safe_serialization=True, + ) + # save out meta config + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) # end must be implemented in child classes def te_train(self): @@ -512,6 +538,7 @@ class BaseModel: self.device_torch, dtype=self.unet.dtype) img = self.generate_single_image( + pipeline, gen_config, conditional_embeds, unconditional_embeds, @@ -603,7 +630,8 @@ class BaseModel: self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor + timesteps: torch.IntTensor, + **kwargs, ) -> torch.FloatTensor: original_samples_chunks = torch.chunk( original_samples, original_samples.shape[0], dim=0) @@ -1071,7 +1099,7 @@ class BaseModel: for name, param in self.text_encoder.named_parameters(recurse=True, prefix=f"{SD_PREFIX_TEXT_ENCODER}"): named_params[name] = param if unet: - if self.is_flux or self.is_lumina2: + if self.is_flux or self.is_lumina2 or self.is_transformer: for name, param in self.unet.named_parameters(recurse=True, prefix="transformer"): named_params[name] = param else: @@ -1105,59 +1133,11 @@ class BaseModel: return named_params def save(self, output_file: str, meta: OrderedDict, save_dtype=get_torch_dtype('fp16'), logit_scale=None): - version_string = '1' - if self.is_v2: - version_string = '2' - if self.is_xl: - version_string = 'sdxl' - if self.is_ssd: - # overwrite sdxl because both wil be true here - version_string = 'ssd' - if self.is_ssd and self.is_vega: - version_string = 'vega' - # if output file does not end in .safetensors, then it is a directory and we are - # saving in diffusers format - if not output_file.endswith('.safetensors'): - # diffusers - if self.is_flux: - # only save the unet - transformer: FluxTransformer2DModel = unwrap_model(self.unet) - transformer.save_pretrained( - save_directory=os.path.join(output_file, 'transformer'), - safe_serialization=True, - ) - elif self.is_lumina2: - # only save the unet - transformer: Lumina2Transformer2DModel = unwrap_model( - self.unet) - transformer.save_pretrained( - save_directory=os.path.join(output_file, 'transformer'), - safe_serialization=True, - ) - - else: - - self.pipeline.save_pretrained( - save_directory=output_file, - safe_serialization=True, - ) - # save out meta config - meta_path = os.path.join(output_file, 'aitk_meta.yaml') - with open(meta_path, 'w') as f: - yaml.dump(meta, f) - - else: - save_ldm_model_from_diffusers( - sd=self, - output_file=output_file, - meta=meta, - save_dtype=save_dtype, - sd_version=version_string, - ) - if self.config_file is not None: - output_path_no_ext = os.path.splitext(output_file)[0] - output_config_path = f"{output_path_no_ext}.yaml" - shutil.copyfile(self.config_file, output_config_path) + self.save_model( + output_path=output_file, + meta=meta, + save_dtype=save_dtype + ) def prepare_optimizer_params( self, @@ -1240,12 +1220,7 @@ class BaseModel: def save_device_state(self): # saves the current device state for all modules # this is useful for when we want to alter the state and restore it - if self.is_lumina2: - unet_has_grad = self.unet.x_embedder.weight.requires_grad - elif self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux: - unet_has_grad = self.unet.proj_out.weight.requires_grad - else: - unet_has_grad = self.unet.conv_in.weight.requires_grad + unet_has_grad = self.get_model_has_grad() self.device_state = { **empty_preset, @@ -1262,13 +1237,7 @@ class BaseModel: if isinstance(self.text_encoder, list): self.device_state['text_encoder']: List[dict] = [] for encoder in self.text_encoder: - if isinstance(encoder, LlamaModel): - te_has_grad = encoder.layers[0].mlp.gate_proj.weight.requires_grad - else: - try: - te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad - except: - te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad + te_has_grad = self.get_te_has_grad() self.device_state['text_encoder'].append({ 'training': encoder.training, 'device': encoder.device, @@ -1276,17 +1245,7 @@ class BaseModel: 'requires_grad': te_has_grad }) else: - if isinstance(self.text_encoder, T5EncoderModel) or isinstance(self.text_encoder, UMT5EncoderModel): - te_has_grad = self.text_encoder.encoder.block[ - 0].layer[0].SelfAttention.q.weight.requires_grad - elif isinstance(self.text_encoder, Gemma2Model): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - elif isinstance(self.text_encoder, Qwen2Model): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - elif isinstance(self.text_encoder, LlamaModel): - te_has_grad = self.text_encoder.layers[0].mlp.gate_proj.weight.requires_grad - else: - te_has_grad = self.text_encoder.text_model.final_layer_norm.weight.requires_grad + te_has_grad = self.get_te_has_grad() self.device_state['text_encoder'] = { 'training': self.text_encoder.training, diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py new file mode 100644 index 00000000..51d87a55 --- /dev/null +++ b/toolkit/models/cogview4.py @@ -0,0 +1,458 @@ +import weakref +from diffusers import CogView4Pipeline +import torch +import yaml + +from toolkit.basic import flush +from toolkit.config_modules import GenerateImageConfig, ModelConfig +from toolkit.dequantize import patch_dequantization_on_save +from toolkit.models.base_model import BaseModel +from toolkit.prompt_utils import PromptEmbeds + +import os +import copy +from toolkit.config_modules import ModelConfig, GenerateImageConfig, ModelArch +import torch +import diffusers +from diffusers import AutoencoderKL, CogView4Transformer2DModel, CogView4Pipeline +from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4 +from transformers import GlmModel, AutoTokenizer +from diffusers import FlowMatchEulerDiscreteScheduler +from typing import TYPE_CHECKING +from toolkit.accelerator import unwrap_model + +from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler + +if TYPE_CHECKING: + from toolkit.lora_special import LoRASpecialNetwork + +# remove this after a bug is fixed in diffusers code. This is a workaround. + + +class FakeModel: + def __init__(self, model): + self.model_ref = weakref.ref(model) + pass + + @property + def device(self): + return self.model_ref().device + + +scheduler_config = { + "base_image_seq_len": 256, + "base_shift": 0.25, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 0.75, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": None, + "time_shift_type": "linear", + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False +} + + +class CogView4(BaseModel): + def __init__( + self, + device, + model_config: ModelConfig, + dtype='bf16', + custom_pipeline=None, + noise_scheduler=None, + **kwargs + ): + super().__init__(device, model_config, dtype, + custom_pipeline, noise_scheduler, **kwargs) + self.is_flow_matching = True + self.is_transformer = True + self.target_lora_modules = ['CogView4Transformer2DModel'] + + # cache for holding noise + self.effective_noise = None + + # static method to get the scheduler + @staticmethod + def get_train_scheduler(): + scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config) + return scheduler + + def load_model(self): + dtype = self.torch_dtype + base_model_path = "THUDM/CogView4-6B" + model_path = self.model_config.name_or_path + + # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16) + self.print_and_status_update("Loading CogView4 model") + # base_model_path = "black-forest-labs/FLUX.1-schnell" + base_model_path = self.model_config.name_or_path_original + subfolder = 'transformer' + transformer_path = model_path + if os.path.exists(transformer_path): + subfolder = None + transformer_path = os.path.join(transformer_path, 'transformer') + # check if the path is a full checkpoint. + te_folder_path = os.path.join(model_path, 'text_encoder') + # if we have the te, this folder is a full checkpoint, use it as the base + if os.path.exists(te_folder_path): + base_model_path = model_path + + self.print_and_status_update("Loading GlmModel") + tokenizer = AutoTokenizer.from_pretrained( + base_model_path, subfolder="tokenizer", torch_dtype=dtype) + text_encoder = GlmModel.from_pretrained( + base_model_path, subfolder="text_encoder", torch_dtype=dtype) + + text_encoder.to(self.device_torch, dtype=dtype) + flush() + + if self.model_config.quantize_te: + self.print_and_status_update("Quantizing GlmModel") + quantize(text_encoder, weights=qfloat8) + freeze(text_encoder) + flush() + + # hack to fix diffusers bug workaround + text_encoder.model = FakeModel(text_encoder) + + self.print_and_status_update("Loading transformer") + transformer = CogView4Transformer2DModel.from_pretrained( + transformer_path, + subfolder=subfolder, + torch_dtype=dtype, + ) + + if self.model_config.split_model_over_gpus: + raise ValueError( + "Splitting model over gpus is not supported for CogViewModels models") + + transformer.to(self.quantize_device, dtype=dtype) + flush() + + if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None: + raise ValueError( + "Assistant LoRA is not supported for CogViewModels models currently") + + if self.model_config.lora_path is not None: + raise ValueError( + "Loading LoRA is not supported for CogViewModels models currently") + + flush() + + if self.model_config.quantize: + # patch the state dict method + patch_dequantization_on_save(transformer) + quantization_type = qfloat8 + self.print_and_status_update("Quantizing transformer") + quantize(transformer, weights=quantization_type, + **self.model_config.quantize_kwargs) + freeze(transformer) + transformer.to(self.device_torch) + else: + transformer.to(self.device_torch, dtype=dtype) + + flush() + + scheduler = CogView4.get_train_scheduler() + self.print_and_status_update("Loading VAE") + vae = AutoencoderKL.from_pretrained( + base_model_path, subfolder="vae", torch_dtype=dtype) + flush() + + self.print_and_status_update("Making pipe") + pipe: CogView4Pipeline = CogView4Pipeline( + scheduler=scheduler, + text_encoder=None, + tokenizer=tokenizer, + vae=vae, + transformer=None, + ) + pipe.text_encoder = text_encoder + pipe.transformer = transformer + + self.print_and_status_update("Preparing Model") + + text_encoder = pipe.text_encoder + tokenizer = pipe.tokenizer + + pipe.transformer = pipe.transformer.to(self.device_torch) + + flush() + text_encoder.to(self.device_torch) + text_encoder.requires_grad_(False) + text_encoder.eval() + pipe.transformer = pipe.transformer.to(self.device_torch) + flush() + self.pipeline = pipe + self.model = transformer + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + + def get_generation_pipeline(self): + scheduler = CogView4.get_train_scheduler() + pipeline = CogView4Pipeline( + vae=self.vae, + transformer=self.unet, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + scheduler=scheduler, + ) + return pipeline + + def generate_single_image( + self, + pipeline: CogView4Pipeline, + gen_config: GenerateImageConfig, + conditional_embeds: PromptEmbeds, + unconditional_embeds: PromptEmbeds, + generator: torch.Generator, + extra: dict, + ): + # there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional + # they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine. + if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]: + pad_len = unconditional_embeds.text_embeds.shape[1] - \ + conditional_embeds.text_embeds.shape[1] + conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len, + conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1) + elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]: + pad_len = conditional_embeds.text_embeds.shape[1] - \ + unconditional_embeds.text_embeds.shape[1] + unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len, + unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1) + + img = pipeline( + prompt_embeds=conditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + negative_prompt_embeds=unconditional_embeds.text_embeds.to( + self.device_torch, dtype=self.torch_dtype), + height=gen_config.height, + width=gen_config.width, + num_inference_steps=gen_config.num_inference_steps, + guidance_scale=gen_config.guidance_scale, + latents=gen_config.latents, + generator=generator, + **extra + ).images[0] + return img + + def get_noise_prediction( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, # 0 to 1000 scale + text_embeddings: PromptEmbeds, + **kwargs + ): + # target_size = (height, width) + target_size = latent_model_input.shape[-2:] + # multiply by 8 + target_size = (target_size[0] * 8, target_size[1] * 8) + crops_coords_top_left = torch.tensor( + [(0, 0)], dtype=self.torch_dtype, device=self.device_torch) + + original_size = torch.tensor( + [target_size], dtype=self.torch_dtype, device=self.device_torch) + target_size = original_size.clone() + noise_pred_cond = self.model( + hidden_states=latent_model_input, # torch.Size([1, 16, 128, 128]) + encoder_hidden_states=text_embeddings.text_embeds, # torch.Size([1, 16, 4096]) + timestep=timestep, + original_size=original_size, # [[1024., 1024.]] + target_size=target_size, # [[1024., 1024.]] + crop_coords=crops_coords_top_left, # [[0., 0.]] + return_dict=False, + )[0] + return noise_pred_cond + + def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: + prompt_embeds, _ = self.pipeline.encode_prompt( + prompt, + do_classifier_free_guidance=False, + device=self.device_torch, + dtype=self.torch_dtype, + ) + return PromptEmbeds(prompt_embeds) + + def get_model_has_grad(self): + return self.model.proj_out.weight.requires_grad + + def get_te_has_grad(self): + return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad + + def save_model(self, output_path, meta, save_dtype): + # only save the unet + transformer: CogView4Transformer2DModel = unwrap_model(self.model) + transformer.save_pretrained( + save_directory=os.path.join(output_path, 'transformer'), + safe_serialization=True, + ) + + meta_path = os.path.join(output_path, 'aitk_meta.yaml') + with open(meta_path, 'w') as f: + yaml.dump(meta, f) + + def get_loss_target(self, *args, **kwargs): + noise = kwargs.get('noise') + effective_noise = self.effective_noise + batch = kwargs.get('batch') + if batch is None: + raise ValueError("Batch is not provided") + if noise is None: + raise ValueError("Noise is not provided") + # return batch.latents + return (noise - batch.latents).detach() + # return (effective_noise - batch.latents).detach() + + + def _get_low_res_latents(self, latents): + # todo prevent needing to do this and grab the tensor another way. + with torch.no_grad(): + # Decode latents to image space + images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype) + + # Downsample by a factor of 2 using bilinear interpolation + B, C, H, W = images.shape + low_res_images = torch.nn.functional.interpolate( + images, + size=(H // 4, W // 4), + mode="bilinear", + align_corners=False + ) + + # Upsample back to original resolution to match expected VAE input dimensions + upsampled_low_res_images = torch.nn.functional.interpolate( + low_res_images, + size=(H, W), + mode="bilinear", + align_corners=False + ) + + # Encode the low-resolution images back to latent space + low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype) + return low_res_latents + + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents_chunks = None + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # # Flowmatching interpolation between original and noise + # if t > relay_start_point: + # # Standard flowmatching - direct linear interpolation + # noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx] + # effective_noise_chunks.append(noise_chunks[idx]) # Effective noise is just the noise + # else: + # # Relay flowmatching case - only compute low_res_latents if needed + # if low_res_latents_chunks is None: + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Calculate the relay ratio (0 to 1) + # t_ratio = t.float() / relay_start_point + # t_ratio = torch.clamp(t_ratio, 0.0, 1.0) + + # # First blend between original and low-res based on t_ratio + # z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx] + + # added_lor_res_noise = z0_t - original_samples_chunks[idx] + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx] + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents + + + # def add_noise( + # self, + # original_samples: torch.FloatTensor, + # noise: torch.FloatTensor, + # timesteps: torch.IntTensor, + # **kwargs, + # ) -> torch.FloatTensor: + # relay_start_point = 500 + + # # Store original samples for loss calculation + # self.original_samples = original_samples + + # # Prepare chunks for batch processing + # original_samples_chunks = torch.chunk( + # original_samples, original_samples.shape[0], dim=0) + # noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) + # timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0) + + # # Get the low res latents only if needed + # low_res_latents = self._get_low_res_latents(original_samples) + # low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0) + + # # Handle case where timesteps is a single value for all samples + # if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks): + # timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks) + + # noisy_latents_chunks = [] + # effective_noise_chunks = [] # Store the effective noise for each sample + + # for idx in range(original_samples.shape[0]): + # t = timesteps_chunks[idx] + # t_01 = (t / 1000).to(original_samples_chunks[idx].device) + + # lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx] + # lrln = lrln * (1 - t_01) + + # # make the noise an interpolation between noise and low_res_latents with + # # being noise at t_01=1 and low_res_latents at t_01=0 + # # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln + # new_noise = noise_chunks[idx] + lrln + + # # Then apply flowmatching interpolation between this blended state and noise + # noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise + + # # For prediction target, we need to store the effective "source" + # effective_noise_chunks.append(new_noise) + + # noisy_latents_chunks.append(noisy_latents) + + # noisy_latents = torch.cat(noisy_latents_chunks, dim=0) + # self.effective_noise = torch.cat(effective_noise_chunks, dim=0) # Store for loss calculation + + # return noisy_latents diff --git a/toolkit/models/wan21.py b/toolkit/models/wan21.py index b9a98400..045e9b1c 100644 --- a/toolkit/models/wan21.py +++ b/toolkit/models/wan21.py @@ -36,12 +36,11 @@ class Wan21(BaseModel): super().__init__(device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs) self.is_flow_matching = True + raise NotImplementedError("Wan21 is not implemented yet") # these must be implemented in child classes def load_model(self): - self.pipeline = Wan21( - - ) + pass def get_generation_pipeline(self): # override this in child classes @@ -50,6 +49,7 @@ class Wan21(BaseModel): def generate_single_image( self, + pipeline, gen_config: GenerateImageConfig, conditional_embeds: PromptEmbeds, unconditional_embeds: PromptEmbeds, @@ -72,3 +72,11 @@ class Wan21(BaseModel): def get_prompt_embeds(self, prompt: str) -> PromptEmbeds: raise NotImplementedError( "get_prompt_embeds must be implemented in child classes") + + def get_model_has_grad(self): + raise NotImplementedError( + "get_model_has_grad must be implemented in child classes") + + def get_te_has_grad(self): + raise NotImplementedError( + "get_te_has_grad must be implemented in child classes") diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 2a7a1cfd..f0dba4e7 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -44,7 +44,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) # flatten second half to max - hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() + hbsmntw_weighing[num_timesteps // + 2:] = hbsmntw_weighing[num_timesteps // 2:].max() # Create linear timesteps from 1000 to 0 timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') @@ -56,7 +57,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: # Get the indices of the timesteps - step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(self.timesteps == t).nonzero().item() + for t in timesteps] # Get the weights for the timesteps if v2: @@ -70,7 +72,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): sigmas = self.sigmas.to(device=device, dtype=dtype) schedule_timesteps = self.timesteps.to(device) timesteps = timesteps.to(device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + step_indices = [(schedule_timesteps == t).nonzero().item() + for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: @@ -84,27 +87,24 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): noise: torch.Tensor, timesteps: torch.Tensor, ) -> torch.Tensor: - ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 - ## Add noise according to flow matching. - ## zt = (1 - texp) * x + texp * z1 - - # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) - # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise - - # timestep needs to be in [0, 1], we store them in [0, 1000] - # noisy_sample = (1 - timestep) * latent + timestep * noise t_01 = (timesteps / 1000).to(original_samples.device) + # forward ODE noisy_model_input = (1 - t_01) * original_samples + t_01 * noise - - # n_dim = original_samples.ndim - # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) - # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + # reverse ODE + # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples return noisy_model_input def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: return sample - def set_train_timesteps(self, num_timesteps, device, timestep_type='linear', latents=None): + def set_train_timesteps( + self, + num_timesteps, + device, + timestep_type='linear', + latents=None, + patch_size=1 + ): self.timestep_type = timestep_type if timestep_type == 'linear': timesteps = torch.linspace(1000, 0, num_timesteps, device=device) @@ -124,42 +124,67 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): self.timesteps = timesteps.to(device=device) return timesteps - elif timestep_type == 'flux_shift' or timestep_type == 'lumina2_shift': + elif timestep_type in ['flux_shift', 'lumina2_shift', 'shift']: # matches inference dynamic shifting timesteps = np.linspace( - self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_timesteps + self._sigma_to_t(self.sigma_max), self._sigma_to_t( + self.sigma_min), num_timesteps ) sigmas = timesteps / self.config.num_train_timesteps - - if latents is None: - raise ValueError('latents is None') - - h = latents.shape[2] // 2 # Divide by ph - w = latents.shape[3] // 2 # Divide by pw - image_seq_len = h * w - # todo need to know the mu for the shift - mu = calculate_shift( - image_seq_len, - self.config.get("base_image_seq_len", 256), - self.config.get("max_image_seq_len", 4096), - self.config.get("base_shift", 0.5), - self.config.get("max_shift", 1.16), - ) - sigmas = self.time_shift(mu, 1.0, sigmas) + if self.config.use_dynamic_shifting: + if latents is None: + raise ValueError('latents is None') - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) + # for flux we double up the patch size before sending her to simulate the latent reduction + h = latents.shape[2] + w = latents.shape[3] + image_seq_len = h * w // (patch_size**2) + + mu = calculate_shift( + image_seq_len, + self.config.get("base_image_seq_len", 256), + self.config.get("max_image_seq_len", 4096), + self.config.get("base_shift", 0.5), + self.config.get("max_shift", 1.16), + ) + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=self.config.num_train_timesteps) + + sigmas = torch.from_numpy(sigmas).to( + dtype=torch.float32, device=device) timesteps = sigmas * self.config.num_train_timesteps - sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat( + [sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat( + [sigmas, torch.zeros(1, device=sigmas.device)]) self.timesteps = timesteps.to(device=device) self.sigmas = sigmas - + self.timesteps = timesteps.to(device=device) return timesteps - + elif timestep_type == 'lognorm_blend': # disgtribute timestepd to the center/early and blend in linear alpha = 0.75 @@ -173,7 +198,8 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): t1 = ((1 - t1/t1.max()) * 1000) # add half of linear - t2 = torch.linspace(1000, 0, int(num_timesteps * (1 - alpha)), device=device) + t2 = torch.linspace(1000, 0, int( + num_timesteps * (1 - alpha)), device=device) timesteps = torch.cat((t1, t2)) # Sort the timesteps in descending order diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 1e850bb7..bf93e84c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -160,7 +160,6 @@ class StableDiffusion: self.pipeline: Union[None, 'StableDiffusionPipeline', 'CustomStableDiffusionXLPipeline', 'PixArtAlphaPipeline'] self.vae: Union[None, 'AutoencoderKL'] self.unet: Union[None, 'UNet2DConditionModel'] - self.unet_unwrapped: Union[None, 'UNet2DConditionModel'] self.text_encoder: Union[None, 'CLIPTextModel', List[Union['CLIPTextModel', 'CLIPTextModelWithProjection']]] self.tokenizer: Union[None, 'CLIPTokenizer', List['CLIPTokenizer']] self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler @@ -205,6 +204,8 @@ class StableDiffusion: self.invert_assistant_lora = False self._after_sample_img_hooks = [] self._status_update_hooks = [] + # todo update this based on the model + self.is_transformer = False # properties for old arch for backwards compatibility @property @@ -246,6 +247,10 @@ class StableDiffusion: @property def is_lumina2(self): return self.arch == 'lumina2' + + @property + def unet_unwrapped(self): + return unwrap_model(self.unet) def load_model(self): if self.is_loaded: @@ -977,7 +982,6 @@ class StableDiffusion: if self.is_pixart or self.is_v3 or self.is_auraflow or self.is_flux or self.is_lumina2: # pixart and sd3 dont use a unet self.unet = pipe.transformer - self.unet_unwrapped = pipe.transformer else: self.unet: 'UNet2DConditionModel' = pipe.unet self.vae: 'AutoencoderKL' = pipe.vae.to(self.vae_device_torch, dtype=self.vae_torch_dtype) @@ -1776,7 +1780,8 @@ class StableDiffusion: self, original_samples: torch.FloatTensor, noise: torch.FloatTensor, - timesteps: torch.IntTensor + timesteps: torch.IntTensor, + **kwargs, ) -> torch.FloatTensor: original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0) noise_chunks = torch.chunk(noise, noise.shape[0], dim=0) diff --git a/toolkit/util/get_model.py b/toolkit/util/get_model.py index b22d52c5..4d1668f8 100644 --- a/toolkit/util/get_model.py +++ b/toolkit/util/get_model.py @@ -5,5 +5,8 @@ def get_model_class(config: ModelConfig): if config.arch == "wan21": from toolkit.models.wan21 import Wan21 return Wan21 + elif config.arch == "cogview4": + from toolkit.models.cogview4 import CogView4 + return CogView4 else: return StableDiffusion \ No newline at end of file