Training working for Qwen Image

Jaret Burkett
2025-08-04 21:14:30 +00:00
parent 9da8b5408e
commit 93202c7a2b


@@ -97,17 +97,10 @@ class QwenImageModel(BaseModel):
subfolder=transformer_subfolder,
torch_dtype=dtype
)
if self.model_config.quantize:
# patch the state dict method
patch_dequantization_on_save(transformer)
# move and quantize only certain pieces at a time.
quantization_type = get_qtype(self.model_config.qtype)
all_blocks = list(transformer.transformer_blocks)
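The quantize path above now works block-by-block instead of quantizing the whole transformer at once, so peak memory stays close to a single full-precision block. A minimal sketch of that pattern, assuming optimum-quanto's `quantize`/`freeze` (the helper name `quantize_blockwise` and the handling of the non-block modules are illustrative, not the exact code of this commit):

```python
# Illustrative sketch only; assumes optimum-quanto is installed.
from optimum.quanto import quantize, freeze, qfloat8

def quantize_blockwise(transformer, device, qtype=qfloat8):
    # Quantize one transformer block at a time so only a single
    # full-precision block ever sits on the GPU.
    for block in transformer.transformer_blocks:
        block.to(device)                # move just this block
        quantize(block, weights=qtype)  # swap weights for quantized tensors
        freeze(block)                   # materialize the quantized weights
    # remaining modules (embeddings, norms, output proj) move afterwards
    transformer.to(device)
```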
@@ -229,11 +222,10 @@ class QwenImageModel(BaseModel):
gen_config.width = int(gen_config.width // sc * sc)
gen_config.height = int(gen_config.height // sc * sc)
img = pipeline(
prompt_embeds=conditional_embeds.text_embeds,
prompt_embeds_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
negative_prompt_embeds=unconditional_embeds.text_embeds,
negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
height=gen_config.height,
width=gen_config.width,
num_inference_steps=gen_config.num_inference_steps,
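Both mask arguments are now cast to `torch.int64` before the pipeline call. The Qwen-Image pipeline derives per-prompt text sequence lengths by summing the mask, which is presumably why a bool or float mask coming out of the embed cache has to become an integer tensor first. A hedged sketch of the idea (tensor values are made up):

```python
import torch

# a padding mask as it might come out of the tokenizer / embed cache
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)

# cast the way the diff does before handing it to the pipeline
mask = attention_mask.to(dtype=torch.int64)

# per-prompt text sequence lengths, as the transformer consumes them
txt_seq_lens = mask.sum(dim=1).tolist()  # -> [3, 2]
```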
@@ -251,16 +243,33 @@ class QwenImageModel(BaseModel):
text_embeddings: PromptEmbeds,
**kwargs
):
batch_size, num_channels_latents, height, width = latent_model_input.shape
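# pack 2x2 latent patches into tokens: (B, C, H, W) -> (B, (H/2)*(W/2), C*4)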
latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
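# one (frames, H/2, W/2) shape per sample, in packed-patch units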
img_shapes = [(1, height // 2, width // 2)] * batch_size
noise_pred = self.transformer(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
encoder_hidden_states_mask=prompt_embeds_mask,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
**kwargs,
)[0]
# unpack the noise prediction
noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
return noise_pred
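The `view`/`permute`/`reshape` pairs above are exact inverses: every 2×2 latent patch becomes one sequence token of width `num_channels_latents * 4` on the way in, and the prediction is unpacked back to latent layout on the way out. A self-contained round-trip sketch of that packing (function names are illustrative):

```python
import torch

def pack_latents(x):
    # (B, C, H, W) -> (B, (H//2)*(W//2), C*4): one token per 2x2 patch
    b, c, h, w = x.shape
    x = x.view(b, c, h // 2, 2, w // 2, 2)
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack_latents(x, c, h, w):
    # inverse of pack_latents: (B, L, C*4) -> (B, C, H, W)
    b = x.shape[0]
    x = x.view(b, h // 2, w // 2, c, 2, 2)
    x = x.permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c, h, w)

latents = torch.randn(2, 16, 64, 64)
assert torch.equal(unpack_latents(pack_latents(latents), 16, 64, 64), latents)
```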
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
@@ -320,4 +329,45 @@ class QwenImageModel(BaseModel):
for key, value in state_dict.items():
new_key = key.replace("diffusion_model.", "transformer.")
new_sd[new_key] = value
return new_sd
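The loop above only swaps the checkpoint prefix: ComfyUI-style weights use `diffusion_model.`, while the diffusers pipeline expects `transformer.`. A tiny illustration (the key is made up):

```python
state_dict = {"diffusion_model.transformer_blocks.0.attn.to_q.weight": 0}
new_sd = {k.replace("diffusion_model.", "transformer."): v for k, v in state_dict.items()}
# -> {"transformer.transformer_blocks.0.attn.to_q.weight": 0}
```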
def encode_images(
self,
image_list: List[torch.Tensor],
device=None,
dtype=None
):
if device is None:
device = self.vae_device_torch
if dtype is None:
dtype = self.vae_torch_dtype
# Move the VAE to the device if it is on the CPU
if self.vae.device == 'cpu':
self.vae.to(device)
self.vae.eval()
self.vae.requires_grad_(False)
# move to device and dtype
image_list = [image.to(device, dtype=dtype) for image in image_list]
images = torch.stack(image_list).to(device, dtype=dtype)
# the model uses the Wan VAE, so add a frame-count dimension
images = images.unsqueeze(2)
latents = self.vae.encode(images).latent_dist.sample()
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = (latents - latents_mean) * latents_std
latents = latents.to(device, dtype=dtype)
latents = latents.squeeze(2) # remove the frame count dimension
return latents
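Note the convention in `encode_images`: `latents_std` is built as a reciprocal, so the normalization `(z - mean) / std` is written as a multiply. Any decode path has to invert this before `vae.decode`. A hedged sketch of that inverse, assuming the same Wan-style `latents_mean`/`latents_std`/`z_dim` config fields (the helper name is illustrative):

```python
import torch

def denormalize_latents(latents, vae):
    # inverse of the normalization done in encode_images:
    #   encode: z_norm = (z - mean) / std
    #   decode: z      = z_norm * std + mean
    mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1)
    std = torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1)
    mean = mean.to(latents.device, latents.dtype)
    std = std.to(latents.device, latents.dtype)
    # restore the frame-count dimension that encode_images squeezed out
    return latents.unsqueeze(2) * std + mean
```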