diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
index 8176edec..f834a8a3 100644
--- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
+++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
@@ -97,17 +97,10 @@ class QwenImageModel(BaseModel):
             subfolder=transformer_subfolder,
             torch_dtype=dtype
         )
-        # transformer.to(self.quantize_device, dtype=dtype)
 
         if self.model_config.quantize:
             # patch the state dict method
             patch_dequantization_on_save(transformer)
-            # quantization_type = get_qtype(self.model_config.qtype)
-            # self.print_and_status_update("Quantizing transformer")
-            # quantize(transformer, weights=quantization_type,
-            #          **self.model_config.quantize_kwargs)
-            # freeze(transformer)
-            # transformer.to(self.device_torch)
             # move and quantize only certain pieces at a time.
             quantization_type = get_qtype(self.model_config.qtype)
             all_blocks = list(transformer.transformer_blocks)
@@ -229,11 +222,10 @@ class QwenImageModel(BaseModel):
         gen_config.width = int(gen_config.width // sc * sc)
         gen_config.height = int(gen_config.height // sc * sc)
         img = pipeline(
-            image=control_img,
             prompt_embeds=conditional_embeds.text_embeds,
-            prompt_embeds_mask=conditional_embeds.attention_mask,
+            prompt_embeds_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
             negative_prompt_embeds=unconditional_embeds.text_embeds,
-            negative_prompt_embeds_mask=unconditional_embeds.attention_mask,
+            negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
             height=gen_config.height,
             width=gen_config.width,
             num_inference_steps=gen_config.num_inference_steps,
@@ -251,16 +243,33 @@ class QwenImageModel(BaseModel):
         text_embeddings: PromptEmbeds,
         **kwargs
     ):
+        batch_size, num_channels_latents, height, width = latent_model_input.shape
+
+        latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
+        latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
+
+        img_shapes = [(1, height // 2, width // 2)] * batch_size
+
         noise_pred = self.transformer(
             hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
             timestep=timestep / 1000,
             guidance=None,
-            encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch),
-            encoder_hidden_states_mask=text_embeddings.attention_mask.to(self.device_torch),
+            encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+            encoder_hidden_states_mask=prompt_embeds_mask,
+            img_shapes=img_shapes,
+            txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
             return_dict=False,
             **kwargs,
         )[0]
 
+        # unpack the noise prediction
+        noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
+        noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
+        noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
+
         return noise_pred
 
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
@@ -320,4 +329,45 @@ class QwenImageModel(BaseModel):
         for key, value in state_dict.items():
             new_key = key.replace("diffusion_model.", "transformer.")
             new_sd[new_key] = value
-        return new_sd
\ No newline at end of file
+        return new_sd
+
+    def encode_images(
+        self,
+        image_list: List[torch.Tensor],
+        device=None,
+        dtype=None
+    ):
+        if device is None:
+            device = self.vae_device_torch
+        if dtype is None:
+            dtype = self.vae_torch_dtype
+
+        # Move to vae to device if on cpu
+        if self.vae.device == 'cpu':
+            self.vae.to(device)
+        self.vae.eval()
+        self.vae.requires_grad_(False)
+        # move to device and dtype
+        image_list = [image.to(device, dtype=dtype) for image in image_list]
+        images = torch.stack(image_list).to(device, dtype=dtype)
+        # it uses wan vae, so add dim for frame count
+
+        images = images.unsqueeze(2)
+        latents = self.vae.encode(images).latent_dist.sample()
+
+        latents_mean = (
+            torch.tensor(self.vae.config.latents_mean)
+            .view(1, self.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+            latents.device, latents.dtype
+        )
+
+        latents = (latents - latents_mean) * latents_std
+        latents = latents.to(device, dtype=dtype)
+
+
+        latents = latents.squeeze(2)  # remove the frame count dimension
+
+        return latents
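
Note on the get_noise_prediction change above: the transformer consumes latents packed into 2x2 spatial patches, so its output has to be unpacked the same way. A minimal standalone sketch (not part of the patch; tensor sizes are illustrative assumptions) showing that the pack/unpack pair used above round-trips losslessly:

    import torch

    # Illustrative sizes (assumptions): 16 latent channels, 64x64 latent grid.
    batch_size, num_channels_latents, height, width = 2, 16, 64, 64
    latents = torch.randn(batch_size, num_channels_latents, height, width)

    # Pack: fold each 2x2 spatial patch into the channel dimension, producing
    # a sequence of (height//2 * width//2) tokens of size channels * 4.
    packed = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    packed = packed.permute(0, 2, 4, 1, 3, 5)
    packed = packed.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

    # Unpack: invert the permutation to recover the original (B, C, H, W) layout.
    unpacked = packed.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
    unpacked = unpacked.permute(0, 3, 1, 4, 2, 5)
    unpacked = unpacked.reshape(batch_size, num_channels_latents, height, width)

    assert torch.equal(latents, unpacked)  # pack followed by unpack is lossless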
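The transformer call also gains img_shapes and txt_seq_lens. A small illustration (made-up mask and resolution, not taken from the patch) of how those values are derived from the padded prompt mask and the packed latent grid:

    import torch

    # Made-up padded mask for two prompts of different lengths (1 = real token, 0 = padding).
    prompt_embeds_mask = torch.tensor([[1, 1, 1, 0, 0],
                                       [1, 1, 1, 1, 1]], dtype=torch.int64)

    # Per-sample count of real text tokens, as passed to txt_seq_lens above.
    txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()
    print(txt_seq_lens)  # [3, 5]

    # One (frames, height // 2, width // 2) entry per sample, matching the packed latent grid.
    height, width = 64, 64  # latent height/width, illustrative
    img_shapes = [(1, height // 2, width // 2)] * prompt_embeds_mask.shape[0]
    print(img_shapes)  # [(1, 32, 32), (1, 32, 32)]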
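For the new encode_images path, latents from the Wan-style video VAE are standardized with per-channel statistics from the VAE config before the single frame dimension is dropped. A rough sketch under assumed shapes and placeholder statistics (the real values come from self.vae.config.latents_mean and latents_std):

    import torch

    # Placeholder per-channel statistics (assumptions, not the real Qwen-Image VAE config values).
    z_dim = 16
    latents_mean = torch.randn(z_dim)
    latents_std = torch.rand(z_dim) + 0.5

    # Latents from a video VAE: (batch, channels, frames, height, width).
    latents = torch.randn(2, z_dim, 1, 32, 32)

    # Reshape the stats so they broadcast over the channel dimension, as encode_images does.
    mean = latents_mean.view(1, z_dim, 1, 1, 1)
    inv_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1)

    normalized = (latents - mean) * inv_std  # per-channel standardization
    normalized = normalized.squeeze(2)       # drop the single frame dimension
    print(normalized.shape)                  # torch.Size([2, 16, 32, 32])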