diff --git a/requirements.txt b/requirements.txt
index d25678d..cef9b65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 torch==2.5.1
 torchvision==0.20.1
 safetensors
-git+https://github.com/huggingface/diffusers@97fda1b75c70705b245a462044fedb47abb17e56
+git+https://github.com/huggingface/diffusers@24c062aaa19f5626d03d058daf8afffa2dfd49f7
 transformers==4.49.0
 lycoris-lora==1.8.3
 flatten_json
diff --git a/toolkit/models/cogview4.py b/toolkit/models/cogview4.py
index 51d87a5..902886b 100644
--- a/toolkit/models/cogview4.py
+++ b/toolkit/models/cogview4.py
@@ -20,7 +20,6 @@ from transformers import GlmModel, AutoTokenizer
 from diffusers import FlowMatchEulerDiscreteScheduler
 from typing import TYPE_CHECKING
 from toolkit.accelerator import unwrap_model
-
 from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
 
 if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class CogView4(BaseModel):
         self.is_flow_matching = True
         self.is_transformer = True
         self.target_lora_modules = ['CogView4Transformer2DModel']
-        
+
         # cache for holding noise
         self.effective_noise = None
 
@@ -86,7 +85,6 @@ class CogView4(BaseModel):
         base_model_path = "THUDM/CogView4-6B"
         model_path = self.model_config.name_or_path
 
-        # pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
         self.print_and_status_update("Loading CogView4 model")
         # base_model_path = "black-forest-labs/FLUX.1-schnell"
         base_model_path = self.model_config.name_or_path_original
@@ -213,19 +211,6 @@ class CogView4(BaseModel):
         generator: torch.Generator,
         extra: dict,
     ):
-        # there is a bug in the check in diffusers code that requires the prompt embeds to be the same length for conditional and unconditional
-        # they are processed in 2 passes and the encoding code doesnt do this. So it shouldnt be needed. But, we will zero pad the shorter one. for now. Just inference here, so it should be fine.
-        if conditional_embeds.text_embeds.shape[1] < unconditional_embeds.text_embeds.shape[1]:
-            pad_len = unconditional_embeds.text_embeds.shape[1] - \
-                conditional_embeds.text_embeds.shape[1]
-            conditional_embeds.text_embeds = torch.cat([conditional_embeds.text_embeds, torch.zeros(conditional_embeds.text_embeds.shape[0], pad_len,
-                conditional_embeds.text_embeds.shape[2], device=conditional_embeds.text_embeds.device, dtype=conditional_embeds.text_embeds.dtype)], dim=1)
-        elif conditional_embeds.text_embeds.shape[1] > unconditional_embeds.text_embeds.shape[1]:
-            pad_len = conditional_embeds.text_embeds.shape[1] - \
-                unconditional_embeds.text_embeds.shape[1]
-            unconditional_embeds.text_embeds = torch.cat([unconditional_embeds.text_embeds, torch.zeros(unconditional_embeds.text_embeds.shape[0], pad_len,
-                unconditional_embeds.text_embeds.shape[2], device=unconditional_embeds.text_embeds.device, dtype=unconditional_embeds.text_embeds.dtype)], dim=1)
-
         img = pipeline(
             prompt_embeds=conditional_embeds.text_embeds.to(
                 self.device_torch, dtype=self.torch_dtype),
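Note: the block deleted in the hunk above worked around a diffusers check requiring the conditional and unconditional prompt embeds to share a sequence length; it is dropped here, presumably because the updated diffusers pin fixes that check. If the workaround is ever needed again, `torch.nn.functional.pad` expresses it more compactly. A minimal sketch, assuming embeds shaped `[batch, seq_len, dim]`; `pad_to_match` is a hypothetical helper, not part of this codebase:

```python
import torch
import torch.nn.functional as F

def pad_to_match(cond: torch.Tensor, uncond: torch.Tensor):
    """Zero-pad the shorter of two [batch, seq_len, dim] tensors along seq_len."""
    diff = cond.shape[1] - uncond.shape[1]
    if diff > 0:
        # pad tuple is (last-dim left, last-dim right, seq-dim left, seq-dim right)
        uncond = F.pad(uncond, (0, 0, 0, diff))
    elif diff < 0:
        cond = F.pad(cond, (0, 0, 0, -diff))
    return cond, uncond
```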
@@ -259,12 +244,12 @@ class CogView4(BaseModel):
             [target_size], dtype=self.torch_dtype, device=self.device_torch)
         target_size = original_size.clone()
         noise_pred_cond = self.model(
-            hidden_states=latent_model_input,  # torch.Size([1, 16, 128, 128])
-            encoder_hidden_states=text_embeddings.text_embeds,  # torch.Size([1, 16, 4096])
+            hidden_states=latent_model_input,
+            encoder_hidden_states=text_embeddings.text_embeds,
             timestep=timestep,
-            original_size=original_size,  # [[1024., 1024.]]
-            target_size=target_size,  # [[1024., 1024.]]
-            crop_coords=crops_coords_top_left,  # [[0., 0.]]
+            original_size=original_size,
+            target_size=target_size,
+            crop_coords=crops_coords_top_left,
             return_dict=False,
         )[0]
         return noise_pred_cond
@@ -283,9 +268,9 @@ class CogView4(BaseModel):
 
     def get_te_has_grad(self):
         return self.text_encoder.layers[0].mlp.down_proj.weight.requires_grad
-    
+
     def save_model(self, output_path, meta, save_dtype):
-        # only save the unet
+        # only save the transformer
         transformer: CogView4Transformer2DModel = unwrap_model(self.model)
         transformer.save_pretrained(
             save_directory=os.path.join(output_path, 'transformer'),
@@ -295,7 +280,7 @@ class CogView4(BaseModel):
         meta_path = os.path.join(output_path, 'aitk_meta.yaml')
         with open(meta_path, 'w') as f:
             yaml.dump(meta, f)
-    
+
     def get_loss_target(self, *args, **kwargs):
         noise = kwargs.get('noise')
         effective_noise = self.effective_noise
@@ -305,25 +290,27 @@ class CogView4(BaseModel):
         if noise is None:
             raise ValueError("Noise is not provided")
         # return batch.latents
+        # return (batch.latents - noise).detach()
         return (noise - batch.latents).detach()
+        # return (batch.latents).detach()
         # return (effective_noise - batch.latents).detach()
-    
-    
+
     def _get_low_res_latents(self, latents):
-        # todo prevent needing to do this and grab the tensor another way. 
+        # todo prevent needing to do this and grab the tensor another way.
         with torch.no_grad():
             # Decode latents to image space
-            images = self.decode_latents(latents, device=latents.device, dtype=latents.dtype)
-            
+            images = self.decode_latents(
+                latents, device=latents.device, dtype=latents.dtype)
+
             # Downsample by a factor of 2 using bilinear interpolation
             B, C, H, W = images.shape
             low_res_images = torch.nn.functional.interpolate(
                 images,
-                size=(H // 4, W // 4),
+                size=(H // 2, W // 2),
                 mode="bilinear",
                 align_corners=False
             )
-            
+
             # Upsample back to original resolution to match expected VAE input dimensions
             upsampled_low_res_images = torch.nn.functional.interpolate(
                 low_res_images,
@@ -331,12 +318,12 @@ class CogView4(BaseModel):
                 mode="bilinear",
                 align_corners=False
             )
-            
+
             # Encode the low-resolution images back to latent space
-            low_res_latents = self.encode_images(upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
+            low_res_latents = self.encode_images(
+                upsampled_low_res_images, device=latents.device, dtype=latents.dtype)
             return low_res_latents
-    
-    
+
     # def add_noise(
     #     self,
    #     original_samples: torch.FloatTensor,
@@ -345,19 +332,19 @@ class CogView4(BaseModel):
     #     noise: torch.FloatTensor,
     #     timesteps: torch.IntTensor,
     #     **kwargs,
     # ) -> torch.FloatTensor:
     #     relay_start_point = 500
-    
+
     #     # Store original samples for loss calculation
     #     self.original_samples = original_samples
-    
+
     #     # Prepare chunks for batch processing
     #     original_samples_chunks = torch.chunk(
     #         original_samples, original_samples.shape[0], dim=0)
     #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
     #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
-    
+
     #     # Get the low res latents only if needed
     #     low_res_latents_chunks = None
-    
+
     #     # Handle case where timesteps is a single value for all samples
     #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
     #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -368,7 +355,7 @@ class CogView4(BaseModel):
     #     for idx in range(original_samples.shape[0]):
     #         t = timesteps_chunks[idx]
     #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)
-    
+
     #         # Flowmatching interpolation between original and noise
     #         if t > relay_start_point:
     #             # Standard flowmatching - direct linear interpolation
@@ -379,30 +366,29 @@ class CogView4(BaseModel):
     #             if low_res_latents_chunks is None:
     #                 low_res_latents = self._get_low_res_latents(original_samples)
     #                 low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
-    
+
     #             # Calculate the relay ratio (0 to 1)
     #             t_ratio = t.float() / relay_start_point
     #             t_ratio = torch.clamp(t_ratio, 0.0, 1.0)
-    
+
     #             # First blend between original and low-res based on t_ratio
     #             z0_t = (1 - t_ratio) * original_samples_chunks[idx] + t_ratio * low_res_latents_chunks[idx]
-    
+
    #             added_low_res_noise = z0_t - original_samples_chunks[idx]
-    
+
     #             # Then apply flowmatching interpolation between this blended state and noise
     #             noisy_latents = (1 - t_01) * z0_t + t_01 * noise_chunks[idx]
-    
+
     #             # For prediction target, we need to store the effective "source"
     #             effective_noise_chunks.append(noise_chunks[idx] + added_low_res_noise)
-    
+
     #             noisy_latents_chunks.append(noisy_latents)
     #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
     #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation
-    
+
     #     return noisy_latents
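For readers following the experiment: the commented-out `add_noise` variant above implements a relay-style schedule in which, below `relay_start_point`, the clean latents are first blended toward their low-res reconstruction before the usual flow-matching interpolation, and the blur that was mixed in is folded into an "effective noise" used as the loss target. The sketch below condenses that variant into standalone form; it is illustrative only (the per-sample chunking and the `t > relay_start_point` branch are omitted), and `relay_add_noise` is a hypothetical name:

```python
import torch

def relay_add_noise(x0, noise, t, low_res_latents, relay_start=500):
    # Broadcast per-sample integer timesteps over [B, C, H, W] latents.
    t = t.float().view(-1, 1, 1, 1).to(x0.device)
    t_01 = t / 1000.0
    # Relay ratio: 0 at t=0 (pure data), 1 at t >= relay_start.
    t_ratio = torch.clamp(t / relay_start, 0.0, 1.0)
    # Blend the clean latents toward their low-res reconstruction.
    z0_t = (1 - t_ratio) * x0 + t_ratio * low_res_latents
    # Fold the added blur into the effective "source" so the
    # velocity target stays consistent with the corrupted input.
    effective_noise = noise + (z0_t - x0)
    # Standard forward ODE, starting from the blended state.
    noisy = (1 - t_01) * z0_t + t_01 * noise
    return noisy, effective_noise
```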
-
     # def add_noise(
     #     self,
     #     original_samples: torch.FloatTensor,
@@ -411,20 +397,20 @@ class CogView4(BaseModel):
     #     noise: torch.FloatTensor,
     #     timesteps: torch.IntTensor,
     #     **kwargs,
     # ) -> torch.FloatTensor:
     #     relay_start_point = 500
-    
+
     #     # Store original samples for loss calculation
     #     self.original_samples = original_samples
-    
+
     #     # Prepare chunks for batch processing
     #     original_samples_chunks = torch.chunk(
     #         original_samples, original_samples.shape[0], dim=0)
     #     noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
     #     timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
-    
+
     #     # Get the low res latents only if needed
     #     low_res_latents = self._get_low_res_latents(original_samples)
     #     low_res_latents_chunks = torch.chunk(low_res_latents, low_res_latents.shape[0], dim=0)
-    
+
     #     # Handle case where timesteps is a single value for all samples
     #     if len(timesteps_chunks) == 1 and len(timesteps_chunks) != len(original_samples_chunks):
     #         timesteps_chunks = [timesteps_chunks[0]] * len(original_samples_chunks)
@@ -435,24 +421,24 @@ class CogView4(BaseModel):
     #     for idx in range(original_samples.shape[0]):
     #         t = timesteps_chunks[idx]
     #         t_01 = (t / 1000).to(original_samples_chunks[idx].device)
-    
+
     #         lrln = low_res_latents_chunks[idx] - original_samples_chunks[idx]
-    #         lrln = lrln * (1 - t_01)
-
-    #         # make the noise an interpolation between noise and low_res_latents with 
+    #         # lrln = lrln * (1 - t_01)
+
+    #         # make the noise an interpolation between noise and low_res_latents with
     #         # being noise at t_01=1 and low_res_latents at t_01=0
-    #         # new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
-    #         new_noise = noise_chunks[idx] + lrln
-
+    #         new_noise = t_01 * noise_chunks[idx] + (1 - t_01) * lrln
+    #         # new_noise = noise_chunks[idx] + lrln
+
     #         # Then apply flowmatching interpolation between this blended state and noise
     #         noisy_latents = (1 - t_01) * original_samples_chunks[idx] + t_01 * new_noise
-    
+
     #         # For prediction target, we need to store the effective "source"
     #         effective_noise_chunks.append(new_noise)
-    
+
     #         noisy_latents_chunks.append(noisy_latents)
     #     noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
     #     self.effective_noise = torch.cat(effective_noise_chunks, dim=0)  # Store for loss calculation
-    
+
     #     return noisy_latents
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
index f0dba4e..1e0ae2a 100644
--- a/toolkit/samplers/custom_flowmatch_sampler.py
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -89,7 +89,7 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
     ) -> torch.Tensor:
         t_01 = (timesteps / 1000).to(original_samples.device)
         # forward ODE
-        noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+        noisy_model_input = (1.0 - t_01) * original_samples + t_01 * noise
         # reverse ODE
         # noisy_model_input = (1 - t_01) * noise + t_01 * original_samples
         return noisy_model_input
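As a sanity check on the flow-matching pieces touched here: `add_noise` interpolates linearly from data at `t_01 = 0` to noise at `t_01 = 1`, and `get_loss_target` returns the matching velocity, `noise - latents`. A minimal sketch tying the two together, assuming latents shaped `[B, C, H, W]` and integer timesteps in `[0, 1000)`; the `.view(-1, 1, 1, 1)` broadcast for per-sample timesteps is an assumption, and `flowmatch_forward` is a hypothetical name:

```python
import torch

def flowmatch_forward(x0, noise, timesteps):
    # t_01 = 0 -> pure data, t_01 = 1 -> pure noise.
    t_01 = (timesteps.float() / 1000.0).view(-1, 1, 1, 1).to(x0.device)
    noisy = (1.0 - t_01) * x0 + t_01 * noise
    # Velocity along the path: d(noisy)/d(t_01) = noise - x0,
    # which is exactly what get_loss_target returns.
    target = (noise - x0).detach()
    return noisy, target
```

The `1.0 - t_01` change in the scheduler is behavior-preserving; it just keeps the arithmetic explicitly in floating point.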