Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-25 08:49:14 +00:00
Minor bug fixes
@@ -43,6 +43,7 @@ scheduler_config = {
 
 class QwenImageModel(BaseModel):
     arch = "qwen_image"
+    _qwen_image_keep_processor = False
 
     def __init__(
         self,
@@ -117,6 +118,9 @@ class QwenImageModel(BaseModel):
         )
 
         # remove the visual model as it is not needed for image generation
+        self.processor = None
+        if self._qwen_image_keep_processor:
+            self.processor = text_encoder.model.visual
         text_encoder.model.visual = None
 
         text_encoder.to(self.device_torch, dtype=dtype)
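
Note: the new _qwen_image_keep_processor flag lets a subclass (e.g. an edit/control model that still has to encode reference images) keep the text encoder's vision tower instead of discarding it. A minimal sketch of the same detach-and-optionally-keep pattern, assuming a Qwen2.5-VL-style wrapper that exposes the tower as model.visual; strip_visual_tower is a hypothetical helper, not code from this commit:

    def strip_visual_tower(text_encoder, keep_processor: bool = False):
        # Grab a reference to the vision tower before detaching it.
        visual = getattr(text_encoder.model, "visual", None)
        # Detach it either way: plain text-to-image never runs the tower,
        # so leaving it attached only wastes VRAM on the text-encoder device.
        text_encoder.model.visual = None
        # Hand the tower back only when the caller still needs it.
        return visual if keep_processor else None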
@@ -162,7 +166,6 @@ class QwenImageModel(BaseModel):
         text_encoder[0].to(self.device_torch)
         text_encoder[0].requires_grad_(False)
         text_encoder[0].eval()
-        pipe.transformer = pipe.transformer.to(self.device_torch)
         flush()
 
         # save it to the model class
@@ -210,7 +213,8 @@ class QwenImageModel(BaseModel):
             control_img = control_img.resize(
                 (gen_config.width, gen_config.height), Image.BILINEAR
             )
 
+        self.model.to(self.device_torch)
         # flush for low vram if we are doing that
         flush_between_steps = self.model_config.low_vram
         # Fix a bug in diffusers/torch
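
Note: flush() here frees GPU memory between steps on the low-VRAM path. The diff does not show its body; a common implementation of such a helper (an assumption, not taken from this commit) is:

    import gc
    import torch

    def flush():
        # Drop Python-side references, then return cached CUDA blocks to the
        # driver so the next sampling step starts from a smaller footprint.
        gc.collect()
        torch.cuda.empty_cache()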
@@ -247,20 +251,23 @@ class QwenImageModel(BaseModel):
         text_embeddings: PromptEmbeds,
         **kwargs
     ):
+        self.model.to(self.device_torch)
         batch_size, num_channels_latents, height, width = latent_model_input.shape
 
+        ps = self.transformer.config.patch_size
+
         # pack image tokens
-        latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // ps, ps, width // ps, ps)
         latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
-        latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+        latent_model_input = latent_model_input.reshape(batch_size, (height // ps) * (width // ps), num_channels_latents * (ps * ps))
 
         # clamp text length to RoPE capacity for this image size
         # img_shapes passed to the model
-        img_h2, img_w2 = height // 2, width // 2
+        img_h2, img_w2 = height // ps, width // ps
         img_shapes = [(1, img_h2, img_w2)] * batch_size
 
         # QwenEmbedRope logic:
-        max_vid_index = max(img_h2 // 2, img_w2 // 2)
+        max_vid_index = max(img_h2 // ps, img_w2 // ps)
 
         rope_cap = 1024 - max_vid_index  # available text positions in RoPE cache
         seq_len_actual = text_embeddings.text_embeds.shape[1]
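
Note: the clamp at the end of this hunk exists because, per the comment above, the RoPE cache holds 1024 positions shared between image and text tokens. Worked numbers for a hypothetical 1024x1024 generation, assuming 8x VAE downsampling and ps = 2 (only the 1024 cache size comes from the diff):

    ps = 2
    height, width = 1024 // 8, 1024 // 8             # 128 x 128 latent
    img_h2, img_w2 = height // ps, width // ps       # 64 x 64 packed grid
    max_vid_index = max(img_h2 // ps, img_w2 // ps)  # 32
    rope_cap = 1024 - max_vid_index                  # 992 text positions left
    seq_len_actual = 1200                            # hypothetical long prompt
    seq_len = min(seq_len_actual, rope_cap)          # truncate so text RoPE
                                                     # never overruns the cache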
@@ -283,7 +290,7 @@ class QwenImageModel(BaseModel):
         )[0]
 
         # unpack
-        noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
+        noise_pred = noise_pred.view(batch_size, height // ps, width // ps, num_channels_latents, ps, ps)
         noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
         noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
         return noise_pred
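
Note: with both the pack and the unpack keyed off ps, the two transforms stay exact inverses for any patch size. A self-contained round-trip check of the same view/permute/reshape sequence (toy shapes; only torch assumed):

    import torch

    def pack(x: torch.Tensor, ps: int) -> torch.Tensor:
        # (b, c, h, w) -> (b, (h/ps)*(w/ps), c*ps*ps), as in the hunks above
        b, c, h, w = x.shape
        x = x.view(b, c, h // ps, ps, w // ps, ps)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.reshape(b, (h // ps) * (w // ps), c * ps * ps)

    def unpack(tokens: torch.Tensor, c: int, h: int, w: int, ps: int) -> torch.Tensor:
        # inverse: (b, (h/ps)*(w/ps), c*ps*ps) -> (b, c, h, w)
        b = tokens.shape[0]
        x = tokens.view(b, h // ps, w // ps, c, ps, ps)
        x = x.permute(0, 3, 1, 4, 2, 5)
        return x.reshape(b, c, h, w)

    x = torch.randn(2, 16, 64, 64)  # toy latent batch
    assert torch.equal(unpack(pack(x, ps=2), 16, 64, 64, ps=2), x)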