Reworked flux pred. Again

This commit is contained in:
Jaret Burkett
2024-08-08 13:06:34 -06:00
parent e69a520616
commit b3e03295ad


@@ -35,6 +35,7 @@ from toolkit.sampler import get_sampler
from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
from toolkit.sd_device_states_presets import empty_preset
from toolkit.train_tools import get_torch_dtype, apply_noise_offset
from einops import rearrange, repeat
import torch
from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
@@ -1600,28 +1601,22 @@ class StableDiffusion:
self.unet = self.unet.to(dtype=self.torch_dtype)
if self.is_flux:
with torch.no_grad():
VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1) # 16. Maybe don't subtract
# this is what diffusers does
text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
device=self.device_torch, dtype=self.text_encoder[0].dtype
)
# todo check these
# height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
# width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
height = latent_model_input.shape[2] * VAE_SCALE_FACTOR # 128
width = latent_model_input.shape[3] * VAE_SCALE_FACTOR # 128
width_latent = latent_model_input.shape[3]
height_latent = latent_model_input.shape[2]
latent_image_ids = self.pipeline._prepare_latent_image_ids(
batch_size=latent_model_input.shape[0],
height=height_latent,
width=width_latent,
device=self.device_torch,
dtype=self.torch_dtype,
bs, c, h, w = latent_model_input.shape
latent_model_input_packed = rearrange(
latent_model_input,
"b c (h ph) (w pw) -> b (h w) (c ph pw)",
ph=2,
pw=2
)
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch)
txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
# handle guidance
guidance_scale = 1.0 # ?
if self.unet.config.guidance_embeds:
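The hunk above replaces the pipeline's _prepare_latent_image_ids and packing helpers with direct einops calls: each 2x2 latent patch becomes one token, and every token gets a (0, row, col) position id. A minimal, self-contained sketch of the same construction, using dummy shapes (batch 1, 16 channels, 128x128 latent, 512 text tokens) in place of the toolkit's tensors:

import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 128, 128   # dummy latent shape; yields 4096 packed tokens
txt_len = 512                   # dummy text sequence length

latents = torch.randn(bs, c, h, w)

# pack each 2x2 latent patch into a single token: (b, c, h, w) -> (b, h/2*w/2, c*4)
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)   # torch.Size([1, 4096, 64])

# per-token position ids: channel 0 stays zero, channels 1/2 hold the patch row/column
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img_ids.shape)  # torch.Size([1, 4096, 3])

# text tokens get all-zero ids of matching sequence length
txt_ids = torch.zeros(bs, txt_len, 3)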
@@ -1630,44 +1625,17 @@ class StableDiffusion:
else:
guidance = None
# not sure how to handle this
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
# image_seq_len = latents.shape[1]
# mu = calculate_shift(
# image_seq_len,
# self.scheduler.config.base_image_seq_len,
# self.scheduler.config.max_image_seq_len,
# self.scheduler.config.base_shift,
# self.scheduler.config.max_shift,
# )
# timesteps, num_inference_steps = retrieve_timesteps(
# self.scheduler,
# num_inference_steps,
# device,
# timesteps,
# sigmas,
# mu=mu,
# )
latent_model_input_packed = self.pipeline._pack_latents(
latent_model_input,
batch_size=latent_model_input.shape[0],
num_channels_latents=latent_model_input.shape[1], # 16
height=height_latent, # 128
width=width_latent, # 128
)
cast_dtype = self.unet.dtype
# with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
noise_pred = self.unet(
hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype), # [1, 4096, 64]
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs the same for the model for testing)
# todo make sure this doesn't change
timestep=timestep / 1000, # timestep is 1000 scale
encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype), # [1, 512, 4096]
pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype), # [1, 768]
txt_ids=text_ids, # [1, 512, 3]
img_ids=latent_image_ids, # [1, 4096, 3]
txt_ids=txt_ids, # [1, 512, 3]
img_ids=img_ids, # [1, 4096, 3]
guidance=guidance,
return_dict=False,
**kwargs,
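The forward call feeds the timestep on a 0-1 scale (the transformer rescales it by 1000 internally, per the YiYi note above) and only passes a guidance tensor when the model config has guidance_embeds. The guidance assignment itself sits in the lines elided between hunks; the sketch below shows how diffusers typically builds it, with hypothetical placeholder values standing in for the toolkit's state:

import torch

bs = 1                            # hypothetical batch size
device, dtype = "cpu", torch.float32
guidance_scale = 1.0              # matches the hard-coded value above
guidance_embeds = True            # stands in for self.unet.config.guidance_embeds
timestep = torch.tensor([999.0])  # scheduler timesteps run on a 0-1000 scale

# guidance-distilled Flux embeds the guidance scale and expects one value per batch item;
# models without a guidance embedder take None instead
if guidance_embeds:
    guidance = torch.full((1,), guidance_scale, device=device, dtype=dtype).expand(bs)
else:
    guidance = None

# the transformer multiplies the timestep by 1000 internally, so it is passed on a 0-1 scale
timestep_in = timestep / 1000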
@@ -1676,12 +1644,14 @@ class StableDiffusion:
if isinstance(noise_pred, QTensor):
noise_pred = noise_pred.dequantize()
# unpack latents
noise_pred = self.pipeline._unpack_latents(
noise_pred = rearrange(
noise_pred,
height=height, # 1024
width=width, # 1024
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
h=latent_model_input.shape[2] // 2,
w=latent_model_input.shape[3] // 2,
ph=2,
pw=2,
c=latent_model_input.shape[1],
)
elif self.is_v3:
noise_pred = self.unet(
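Unpacking the prediction is now the exact inverse rearrange of the packing step, which is what lets the _unpack_latents call and its VAE_SCALE_FACTOR * 2 adjustment go away. A quick round-trip check under the same dummy shapes as the sketch above:

import torch
from einops import rearrange

bs, c, h, w = 1, 16, 128, 128
latents = torch.randn(bs, c, h, w)

# pack exactly as the forward pass does
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

# unpack with the inverse pattern applied to noise_pred
unpacked = rearrange(
    packed,
    "b (h w) (c ph pw) -> b c (h ph) (w pw)",
    h=h // 2, w=w // 2, ph=2, pw=2, c=c,
)

assert torch.equal(unpacked, latents)  # exact inverse, no resolution bookkeeping needed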