mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Clip max token embeddings to the max rope length for qwen image to solve for an error for super long captions > 1024
This commit is contained in:
@@ -251,34 +251,44 @@ class QwenImageModel(BaseModel):
|
||||
**kwargs
|
||||
):
|
||||
batch_size, num_channels_latents, height, width = latent_model_input.shape
|
||||
|
||||
|
||||
# pack image tokens
|
||||
latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
|
||||
latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
# make txt_seq_lens match the actual encoder_hidden_states length, and clamp the mask ---
|
||||
seq_len = text_embeddings.text_embeds.shape[1]
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :seq_len]
|
||||
txt_seq_lens = [seq_len] * batch_size
|
||||
|
||||
img_shapes = [(1, height // 2, width // 2)] * batch_size
|
||||
# clamp text length to RoPE capacity for this image size
|
||||
# img_shapes passed to the model
|
||||
img_h2, img_w2 = height // 2, width // 2
|
||||
img_shapes = [(1, img_h2, img_w2)] * batch_size
|
||||
|
||||
# QwenEmbedRope logic:
|
||||
max_vid_index = max(img_h2 // 2, img_w2 // 2)
|
||||
|
||||
rope_cap = 1024 - max_vid_index # available text positions in RoPE cache
|
||||
seq_len_actual = text_embeddings.text_embeds.shape[1]
|
||||
use_len = min(seq_len_actual, rope_cap)
|
||||
|
||||
enc_hs = text_embeddings.text_embeds[:, :use_len].to(self.device_torch, self.torch_dtype)
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)[:, :use_len]
|
||||
txt_seq_lens = [use_len] * batch_size
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
|
||||
encoder_hidden_states=enc_hs,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
return_dict=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
|
||||
# unpack
|
||||
noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
|
||||
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
|
||||
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
|
||||
|
||||
return noise_pred
|
||||
|
||||
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||
|
||||
Reference in New Issue
Block a user