mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Training working for Qwen Image
This commit is contained in:
@@ -97,17 +97,10 @@ class QwenImageModel(BaseModel):
|
||||
subfolder=transformer_subfolder,
|
||||
torch_dtype=dtype
|
||||
)
|
||||
# transformer.to(self.quantize_device, dtype=dtype)
|
||||
|
||||
if self.model_config.quantize:
|
||||
# patch the state dict method
|
||||
patch_dequantization_on_save(transformer)
|
||||
# quantization_type = get_qtype(self.model_config.qtype)
|
||||
# self.print_and_status_update("Quantizing transformer")
|
||||
# quantize(transformer, weights=quantization_type,
|
||||
# **self.model_config.quantize_kwargs)
|
||||
# freeze(transformer)
|
||||
# transformer.to(self.device_torch)
|
||||
# move and quantize only certain pieces at a time.
|
||||
quantization_type = get_qtype(self.model_config.qtype)
|
||||
all_blocks = list(transformer.transformer_blocks)
|
||||
@@ -229,11 +222,10 @@ class QwenImageModel(BaseModel):
|
||||
gen_config.width = int(gen_config.width // sc * sc)
|
||||
gen_config.height = int(gen_config.height // sc * sc)
|
||||
img = pipeline(
|
||||
image=control_img,
|
||||
prompt_embeds=conditional_embeds.text_embeds,
|
||||
prompt_embeds_mask=conditional_embeds.attention_mask,
|
||||
prompt_embeds_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
negative_prompt_embeds=unconditional_embeds.text_embeds,
|
||||
negative_prompt_embeds_mask=unconditional_embeds.attention_mask,
|
||||
negative_prompt_embeds_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=torch.int64),
|
||||
height=gen_config.height,
|
||||
width=gen_config.width,
|
||||
num_inference_steps=gen_config.num_inference_steps,
|
||||
@@ -251,16 +243,33 @@ class QwenImageModel(BaseModel):
|
||||
text_embeddings: PromptEmbeds,
|
||||
**kwargs
|
||||
):
|
||||
batch_size, num_channels_latents, height, width = latent_model_input.shape
|
||||
|
||||
latent_model_input = latent_model_input.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latent_model_input = latent_model_input.permute(0, 2, 4, 1, 3, 5)
|
||||
latent_model_input = latent_model_input.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
prompt_embeds_mask = text_embeddings.attention_mask.to(self.device_torch, dtype=torch.int64)
|
||||
|
||||
img_shapes = [(1, height // 2, width // 2)] * batch_size
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch),
|
||||
encoder_hidden_states_mask=text_embeddings.attention_mask.to(self.device_torch),
|
||||
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
return_dict=False,
|
||||
**kwargs,
|
||||
)[0]
|
||||
|
||||
# unpack the noise prediction
|
||||
noise_pred = noise_pred.view(batch_size, height // 2, width // 2, num_channels_latents, 2, 2)
|
||||
noise_pred = noise_pred.permute(0, 3, 1, 4, 2, 5)
|
||||
noise_pred = noise_pred.reshape(batch_size, num_channels_latents, height, width)
|
||||
|
||||
return noise_pred
|
||||
|
||||
def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
|
||||
@@ -320,4 +329,45 @@ class QwenImageModel(BaseModel):
|
||||
for key, value in state_dict.items():
|
||||
new_key = key.replace("diffusion_model.", "transformer.")
|
||||
new_sd[new_key] = value
|
||||
return new_sd
|
||||
return new_sd
|
||||
|
||||
def encode_images(
|
||||
self,
|
||||
image_list: List[torch.Tensor],
|
||||
device=None,
|
||||
dtype=None
|
||||
):
|
||||
if device is None:
|
||||
device = self.vae_device_torch
|
||||
if dtype is None:
|
||||
dtype = self.vae_torch_dtype
|
||||
|
||||
# Move to vae to device if on cpu
|
||||
if self.vae.device == 'cpu':
|
||||
self.vae.to(device)
|
||||
self.vae.eval()
|
||||
self.vae.requires_grad_(False)
|
||||
# move to device and dtype
|
||||
image_list = [image.to(device, dtype=dtype) for image in image_list]
|
||||
images = torch.stack(image_list).to(device, dtype=dtype)
|
||||
# it uses wan vae, so add dim for frame count
|
||||
|
||||
images = images.unsqueeze(2)
|
||||
latents = self.vae.encode(images).latent_dist.sample()
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
|
||||
latents = (latents - latents_mean) * latents_std
|
||||
latents = latents.to(device, dtype=dtype)
|
||||
|
||||
|
||||
latents = latents.squeeze(2) # remove the frame count dimension
|
||||
|
||||
return latents
|
||||
Reference in New Issue
Block a user