Fixed issue that sometimes happens in qwen image where text seq length is wrong

This commit is contained in:
Jaret Burkett
2025-08-09 16:33:05 -06:00
parent ccd449ec49
commit f0105c33a7

View File

@@ -256,8 +256,11 @@ class QwenImageModel(BaseModel):
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
# make txt_seq_lens match the actual encoder_hidden_states length, and clamp the mask ---
seq_len = text_embeddings.text_embeds.shape[1]
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :seq_len]
txt_seq_lens = [seq_len] * batch_size
img_shapes = [(1, height // 2, width // 2)] * batch_size
noise_pred = self.transformer(
@@ -267,12 +270,11 @@ class QwenImageModel(BaseModel):
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
txt_seq_lens=txt_seq_lens,
return_dict=False,
**kwargs,
)[0]
# unpack the noise prediction
noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)