# DON'T USE THIS! IT DOES NOT WORK YET!
# Will revisit this when they release more info on how it was trained.

import copy
import os
import weakref
from typing import TYPE_CHECKING

import diffusers
import torch
import yaml
from diffusers import (
    AutoencoderKL,
    CogView4Pipeline,
    CogView4Transformer2DModel,
    FlowMatchEulerDiscreteScheduler,
)
from optimum.quanto import freeze, qfloat8, QTensor, qint4
from transformers import GlmModel, AutoTokenizer

from toolkit.accelerator import unwrap_model
from toolkit.basic import flush
from toolkit.config_modules import GenerateImageConfig, ModelArch, ModelConfig
from toolkit.dequantize import patch_dequantization_on_save
from toolkit.models.base_model import BaseModel
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
from toolkit.util.quantize import quantize

if TYPE_CHECKING:
    from toolkit.lora_special import LoRASpecialNetwork


# remove this after a bug is fixed in diffusers code. This is a workaround.
class FakeModel:
    """Minimal stand-in that only exposes the ``device`` of a wrapped model.

    Workaround for a diffusers bug: ``load_model`` assigns an instance of this
    class to ``text_encoder.model`` so that ``text_encoder.model.device``
    resolves. Remove once the upstream bug is fixed.
    """

    def __init__(self, model):
        # Hold the model weakly: this object is attached to the very model it
        # wraps, so a strong reference would create a reference cycle.
        self.model_ref = weakref.ref(model)

    @property
    def device(self):
        """Return the ``device`` attribute of the wrapped model."""
        return self.model_ref().device


# Keyword arguments used to build the flow-match scheduler for CogView4.
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": 0.25,
    "invert_sigmas": False,
    "max_image_seq_len": 4096,
    "max_shift": 0.75,
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "time_shift_type": "linear",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False
}


class CogView4(BaseModel):
    """CogView4 model wrapper built on the toolkit ``BaseModel`` interface.

    The instance is flagged as a flow-matching, transformer-based model and
    targets ``CogView4Transformer2DModel`` modules for LoRA.
    """

    def __init__(
            self,
            device,
            model_config: ModelConfig,
            dtype='bf16',
            custom_pipeline=None,
            noise_scheduler=None,
            **kwargs
    ):
        """Forward all arguments to ``BaseModel`` and set CogView4 flags."""
        super().__init__(device, model_config, dtype,
                         custom_pipeline, noise_scheduler, **kwargs)
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ['CogView4Transformer2DModel']

        # cache for holding noise (only written by the disabled add_noise
        # experiments kept at the bottom of this class)
        self.effective_noise = None

    # static method to get the scheduler
    @staticmethod
    def get_train_scheduler():
        """Build a new scheduler from the module-level ``scheduler_config``."""
        scheduler = CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)
        return scheduler

    def load_model(self):
        """Load tokenizer, text encoder, transformer and VAE and build the pipe.

        On success sets ``self.pipeline``, ``self.model``, ``self.vae``,
        ``self.text_encoder`` and ``self.tokenizer``.

        Raises:
            ValueError: if the model config requests splitting over GPUs, an
                assistant/inference LoRA, or a LoRA path; none of these are
                supported here.
        """
        # Fail fast: reject unsupported options before loading any weights.
        if self.model_config.split_model_over_gpus:
            raise ValueError(
                "Splitting model over gpus is not supported for CogViewModels models")

        if self.model_config.assistant_lora_path is not None or self.model_config.inference_lora_path is not None:
            raise ValueError(
                "Assistant LoRA is not supported for CogViewModels models currently")

        if self.model_config.lora_path is not None:
            raise ValueError(
                "Loading LoRA is not supported for CogViewModels models currently")

        dtype = self.torch_dtype
        model_path = self.model_config.name_or_path

        self.print_and_status_update("Loading CogView4 model")
        # Source of the tokenizer, text encoder and VAE unless model_path
        # turns out to be a full local checkpoint (see below).
        base_model_path = self.model_config.name_or_path_original
        subfolder = 'transformer'
        transformer_path = model_path
        if os.path.exists(transformer_path):
            # Local folder: point straight at its transformer directory.
            subfolder = None
            transformer_path = os.path.join(transformer_path, 'transformer')
            # check if the path is a full checkpoint.
            te_folder_path = os.path.join(model_path, 'text_encoder')
            # if we have the te, this folder is a full checkpoint, use it as the base
            if os.path.exists(te_folder_path):
                base_model_path = model_path

        self.print_and_status_update("Loading GlmModel")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path, subfolder="tokenizer", torch_dtype=dtype)
        text_encoder = GlmModel.from_pretrained(
            base_model_path, subfolder="text_encoder", torch_dtype=dtype)

        text_encoder.to(self.device_torch, dtype=dtype)
        flush()

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing GlmModel")
            quantize(text_encoder, weights=qfloat8)
            freeze(text_encoder)
            flush()

        # hack to fix diffusers bug workaround
        text_encoder.model = FakeModel(text_encoder)

        self.print_and_status_update("Loading transformer")
        transformer = CogView4Transformer2DModel.from_pretrained(
            transformer_path,
            subfolder=subfolder,
            torch_dtype=dtype,
        )

        transformer.to(self.quantize_device, dtype=dtype)
        flush()

        if self.model_config.quantize:
            # Work on copies so the include/exclude patterns added below do
            # not accumulate inside the shared model config on every load.
            quantization_args = dict(self.model_config.quantize_kwargs or {})
            include_patterns = list(quantization_args.get('include') or [])
            exclude_patterns = list(quantization_args.get('exclude') or [])

            # Be more specific with the include pattern to exactly match transformer blocks
            include_patterns += ["transformer_blocks.*"]

            # Exclude all LayerNorm layers within transformer blocks
            exclude_patterns += [
                "transformer_blocks.*.norm1",
                "transformer_blocks.*.norm2",
                "transformer_blocks.*.norm2_context",
                "transformer_blocks.*.attn1.norm_q",
                "transformer_blocks.*.attn1.norm_k"
            ]
            quantization_args['include'] = include_patterns
            quantization_args['exclude'] = exclude_patterns

            # patch the state dict method
            patch_dequantization_on_save(transformer)
            quantization_type = qfloat8
            self.print_and_status_update("Quantizing transformer")
            quantize(transformer, weights=quantization_type,
                     **quantization_args)
            freeze(transformer)
            transformer.to(self.device_torch)
        else:
            transformer.to(self.device_torch, dtype=dtype)

        flush()

        scheduler = CogView4.get_train_scheduler()
        self.print_and_status_update("Loading VAE")
        vae = AutoencoderKL.from_pretrained(
            base_model_path, subfolder="vae", torch_dtype=dtype)
        flush()

        self.print_and_status_update("Making pipe")
        # Build the pipe without the two large modules, then attach the
        # already prepared instances afterwards.
        pipe: CogView4Pipeline = CogView4Pipeline(
            scheduler=scheduler,
            text_encoder=None,
            tokenizer=tokenizer,
            vae=vae,
            transformer=None,
        )
        pipe.text_encoder = text_encoder
        pipe.transformer = transformer

        self.print_and_status_update("Preparing Model")

        text_encoder = pipe.text_encoder
        tokenizer = pipe.tokenizer

        # The text encoder is never trained here: freeze it and use eval mode.
        text_encoder.to(self.device_torch)
        text_encoder.requires_grad_(False)
        text_encoder.eval()
        pipe.transformer = pipe.transformer.to(self.device_torch)
        flush()

        self.pipeline = pipe
        self.model = transformer
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer

    def get_generation_pipeline(self):
        """Return a new ``CogView4Pipeline`` built from the loaded components.

        A fresh scheduler is created for each pipeline.
        """
        scheduler = CogView4.get_train_scheduler()
        # NOTE(review): self.unet is not defined in this file; presumably
        # BaseModel provides it - confirm.
        pipeline = CogView4Pipeline(
            vae=self.vae,
            transformer=self.unet,
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer,
            scheduler=scheduler,
        )
        return pipeline

    def generate_single_image(
        self,
        pipeline: CogView4Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        """Run ``pipeline`` once and return the first generated image.

        Args:
            pipeline: pipeline to call.
            gen_config: supplies height, width, number of inference steps,
                guidance scale and optional starting latents.
            conditional_embeds: positive prompt embeddings.
            unconditional_embeds: negative prompt embeddings.
            generator: random generator passed through to the pipeline.
            extra: additional keyword arguments forwarded to the pipeline.
        """
        # Embeddings are moved to the model device/dtype before the call.
        prompt_embeds = conditional_embeds.text_embeds.to(
            self.device_torch, dtype=self.torch_dtype)
        negative_prompt_embeds = unconditional_embeds.text_embeds.to(
            self.device_torch, dtype=self.torch_dtype)
        img = pipeline(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            **extra
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs
    ):
        """Call the transformer and return the first element of its output.

        Args:
            latent_model_input: latents; the last two dimensions are read as
                (height, width).
            timestep: timestep tensor on a 0 to 1000 scale.
            text_embeddings: prompt embeddings; ``text_embeds`` is used.
            **kwargs: accepted for interface compatibility, unused here.
        """
        # target_size = (height, width) in pixels: latent size multiplied by 8
        latent_height, latent_width = latent_model_input.shape[-2:]
        pixel_size = (latent_height * 8, latent_width * 8)
        crops_coords_top_left = torch.tensor(
            [(0, 0)], dtype=self.torch_dtype, device=self.device_torch)
        original_size = torch.tensor(
            [pixel_size], dtype=self.torch_dtype, device=self.device_torch)
        # Original and target sizes are identical here.
        target_size = original_size.clone()
        noise_pred_cond = self.model(
            hidden_states=latent_model_input,
            encoder_hidden_states=text_embeddings.text_embeds,
            timestep=timestep,
            original_size=original_size,
            target_size=target_size,
            crop_coords=crops_coords_top_left,
            return_dict=False,
        )[0]
        return noise_pred_cond

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        """Encode ``prompt`` with the pipeline, without classifier-free guidance."""
        # The second value returned by encode_prompt is discarded.
        prompt_embeds, _ = self.pipeline.encode_prompt(
            prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            dtype=self.torch_dtype,
        )
        return PromptEmbeds(prompt_embeds)

    def get_model_has_grad(self):
        """Return ``requires_grad`` of the transformer's ``proj_out`` weight."""
        return self.model.proj_out.weight.requires_grad

    def get_te_has_grad(self):
        """Return ``requires_grad`` of one first-layer text encoder weight."""
        return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad

    def save_model(self, output_path, meta, save_dtype):
        """Save the transformer and a YAML metadata file under ``output_path``.

        Args:
            output_path: directory that receives ``transformer/`` and
                ``aitk_meta.yaml``.
            meta: metadata object dumped as YAML.
            save_dtype: unused by this implementation.
        """
        # only save the unet
        transformer: CogView4Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, 'transformer'),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, 'aitk_meta.yaml')
        with open(meta_path, 'w') as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        """Return the detached training target ``noise - batch.latents``.

        Keyword Args:
            noise: noise tensor.
            batch: object exposing ``latents``.

        Raises:
            ValueError: if ``batch`` or ``noise`` is missing or ``None``.
        """
        noise = kwargs.get('noise')
        batch = kwargs.get('batch')
        if batch is None:
            raise ValueError("Batch is not provided")
        if noise is None:
            raise ValueError("Noise is not provided")
        # Alternative targets tried while experimenting (all disabled):
        # return batch.latents
        # return (batch.latents - noise).detach()
        # return (batch.latents).detach()
        # return (self.effective_noise - batch.latents).detach()
        return (noise - batch.latents).detach()

    def _get_low_res_latents(self, latents):
        """Return latents of a blurred copy of the images behind ``latents``.

        The latents are decoded, downsampled by 2, upsampled back to the
        original size and re-encoded, all without gradient tracking. Only the
        disabled ``add_noise`` experiments below call this.
        """
        # todo prevent needing to do this and grab the tensor another way.
        with torch.no_grad():
            # Decode latents to image space
            images = self.decode_latents(
                latents, device=latents.device, dtype=latents.dtype)

            # Downsample by a factor of 2 using bilinear interpolation
            _, _, height, width = images.shape
            low_res_images = torch.nn.functional.interpolate(
                images,
                size=(height // 2, width // 2),
                mode="bilinear",
                align_corners=False
            )

            # Upsample back to original resolution to match expected VAE input dimensions
            upsampled_low_res_images = torch.nn.functional.interpolate(
                low_res_images,
                size=(height, width),
                mode="bilinear",
                align_corners=False
            )

            # Encode the low-resolution images back to latent space
            low_res_latents = self.encode_images(
                upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
            return low_res_latents

    # ------------------------------------------------------------------
    # Disabled add_noise experiments, kept for reference.
    # ------------------------------------------------------------------

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500
    #
    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples
    #
    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
    #
    #     # Get the low res latents only if needed
    #     low_res_latents_chunks = None
    #
    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
    #
    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample
    #
    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)
    #
    #         # Flowmatching interpolation between original and noise
    #         if t > relay_start_point:
    #             # Standard flowmatching - direct linear interpolation
    #             noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * noise_chunks[idx]
    #             effective_noise_chunks.append(noise_chunks[idx])  # Effective noise is just the noise
    #         else:
    #             # Relay flowmatching case - only compute low_res_latents if needed
    #             if low_res_latents_chunks is None:
    #                 low_res_latents = self._get_low_res_latents(original_samples)
    #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
    #
    #             # Calculate the relay ratio (0 to 1)
    #             t_ratio = t.float() / relay_start_point
    #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
    #
    #             # First blend between original and low-res based on t_ratio
    #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
    #
    #             added_lor_res_noise = z0_t - original_samples_chunks[idx]
    #
    #             # Then apply flowmatching interpolation between this blended state and noise
    #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
    #
    #             # For prediction target, we need to store the effective "source"
    #             effective_noise_chunks.append(noise_chunks[idx] + added_lor_res_noise)
    #
    #         noisy_latents_chunks.append(noisy_latents)
    #
    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation
    #
    #     return noisy_latents

    # def add_noise(
    #     self,
    #     original_samples: torch.FloatTensor,
    #     noise: torch.FloatTensor,
    #     timesteps: torch.IntTensor,
    #     **kwargs,
    # ) -> torch.FloatTensor:
    #     relay_start_point = 500
    #
    #     # Store original samples for loss calculation
    #     self.original_samples = original_samples
    #
    #     # Prepare chunks for batch processing
    #     original_samples_chunks = torch.chunk(
    #         original_samples, original_samples.shape[0], dim=0)
    #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
    #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
    #
    #     # Get the low res latents only if needed
    #     low_res_latents = self._get_low_res_latents(original_samples)
    #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
    #
    #     # Handle case where timesteps is a single value for all samples
    #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
    #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
    #
    #     noisy_latents_chunks = []
    #     effective_noise_chunks = []  # Store the effective noise for each sample
    #
    #     for idx in range(original_samples.shape[0]):
    #         t = timesteps_chunks[idx]
    #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)
    #
    #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
    #         # lrln = lrln * (1 - t_01)
    #
    #         # make the noise an interpolation between noise and low_res_latents with
    #         # being noise at t_01=1 and low_res_latents at t_01=0
    #         new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
    #         # new_noise = noise_chunks[idx] + lrln
    #         # new_noise = noise_chunks[idx] + lrln
    #
    #         # Then apply flowmatching interpolation between this blended state and noise
    #         # NOTE(review): this uses the whole batch (original_samples), not
    #         # original_samples_chunks[idx] - check before re-enabling.
    #         noisy_latents = (1 - t_01) * original_samples + t_01 * new_noise
    #
    #         # For prediction target, we need to store the effective "source"
    #         effective_noise_chunks.append(new_noise)
    #
    #         noisy_latents_chunks.append(noisy_latents)
    #
    #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
    #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation
    #
    #     return noisy_latents