diff --git a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
index 0163dc19..b257930a 100644
--- a/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
+++ b/extensions_built_in/diffusion_models/qwen_image/qwen_image.py
@@ -251,34 +251,44 @@ class QwenImageModel(BaseModel):
         **kwargs
     ):
         batch_size, num_channels_latents, height, width = latent_model_input.shape
-        
+
+        # pack image tokens
         latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
         latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
         latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-        
-        # make txt_seq_lens match the actual encoder_hidden_states length, and clamp the mask ---
-        seq_len = text_embeddings.text_embeds.shape[1]
-        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :seq_len]
-        txt_seq_lens = [seq_len] * batch_size
-        img_shapes = [(1, height // 2, width // 2)] * batch_size
+        # clamp text length to RoPE capacity for this image size
+        # img_shapes passed to the model
+        img_h2, img_w2 = height // 2, width // 2
+        img_shapes = [(1, img_h2, img_w2)] * batch_size
+
+        # QwenEmbedRope logic:
+        max_vid_index = max(img_h2 // 2, img_w2 // 2)
+
+        rope_cap = 1024 - max_vid_index # available text positions in RoPE cache
+        seq_len_actual = text_embeddings.text_embeds.shape[1]
+        use_len = min(seq_len_actual, rope_cap)
+
+        enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype)
+        prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len]
+        txt_seq_lens = [use_len] * batch_size
 
 
         noise_pred = self.transformer(
             hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
             timestep=timestep / 1000,
             guidance=None,
-            encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+            encoder_hidden_states=enc_hs,
             encoder_hidden_states_mask=prompt_embeds_mask,
             img_shapes=img_shapes,
             txt_seq_lens=txt_seq_lens,
             return_dict=False,
             **kwargs,
         )[0]
-        
+
+        # unpack
         noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
         noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
         noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
-
         return noise_pred
 
     def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
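
For context: the clamp mirrors the position bookkeeping in diffusers' `QwenEmbedRope`, which serves image and text rotary positions from a shared 1024-entry frequency cache and starts the text positions after the largest image axis index. Below is a minimal standalone sketch of that arithmetic, assuming `scale_rope=True` (hence the extra `// 2` on the packed grid) and the 1024-position cache named in the comment above; `clamp_text_len` and the example sizes are illustrative, not part of the patch.

```python
# Minimal sketch of the clamp in the diff above. Assumes diffusers'
# QwenEmbedRope with scale_rope=True and a 1024-entry RoPE frequency
# cache; clamp_text_len is a hypothetical helper, not part of the patch.
def clamp_text_len(latent_height: int, latent_width: int, seq_len: int,
                   rope_positions: int = 1024) -> int:
    """How many text tokens fit in the RoPE cache for this latent size."""
    img_h2, img_w2 = latent_height // 2, latent_width // 2  # 2x2-packed token grid
    max_vid_index = max(img_h2 // 2, img_w2 // 2)           # first text position
    rope_cap = rope_positions - max_vid_index               # slots left for text
    return min(seq_len, rope_cap)

# Example: a 166x166 latent packs to an 83x83 token grid, so
# max_vid_index = 41 and at most 1024 - 41 = 983 text tokens fit.
print(clamp_text_len(166, 166, seq_len=1200))  # -> 983
```

The old code sized `txt_seq_lens` from the raw embedding length alone, so a long prompt on a large image could presumably ask for more text positions than the cache holds; clamping both the embeddings and the mask to `use_len` keeps the two consistent with the `txt_seq_lens` passed to the transformer.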