From b3e03295ada47505511fd1d4b3c57e14aebdfe61 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 8 Aug 2024 13:06:34 -0600
Subject: [PATCH] Reworked flux pred. Again

---
 toolkit/stable_diffusion_model.py | 78 ++++++++++---------------------
 1 file changed, 24 insertions(+), 54 deletions(-)

diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 04f08de8..7410c118 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -35,6 +35,7 @@ from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+from einops import rearrange, repeat
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
@@ -1600,28 +1601,22 @@ class StableDiffusion:
         self.unet = self.unet.to(dtype=self.torch_dtype)
         if self.is_flux:
             with torch.no_grad():
-                VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)  # 16 . Maybe dont subtract
-                # this is what diffusers does
-                text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
-                    device=self.device_torch, dtype=self.text_encoder[0].dtype
-                )
-                # todo check these
-                # height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
-                # width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
-                height = latent_model_input.shape[2] * VAE_SCALE_FACTOR  # 128
-                width = latent_model_input.shape[3] * VAE_SCALE_FACTOR  # 128
-                width_latent = latent_model_input.shape[3]
-                height_latent = latent_model_input.shape[2]
-
-                latent_image_ids = self.pipeline._prepare_latent_image_ids(
-                    batch_size=latent_model_input.shape[0],
-                    height=height_latent,
-                    width=width_latent,
-                    device=self.device_torch,
-                    dtype=self.torch_dtype,
+                bs, c, h, w = latent_model_input.shape
+                latent_model_input_packed = rearrange(
+                    latent_model_input,
+                    "b c (h ph) (w pw) -> b (h w) (c ph pw)",
+                    ph=2,
+                    pw=2
                 )
+                img_ids = torch.zeros(h // 2, w // 2, 3)
+                img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+                img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+                img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch)
+
+                txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)
+
 
                 # # handle guidance
                 guidance_scale = 1.0  # ?
 
                 if self.unet.config.guidance_embeds:
@@ -1630,44 +1625,17 @@ class StableDiffusion:
                 else:
                     guidance = None
 
-                # not sure how to handle this
-
-                # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-                # image_seq_len = latents.shape[1]
-                # mu = calculate_shift(
-                #     image_seq_len,
-                #     self.scheduler.config.base_image_seq_len,
-                #     self.scheduler.config.max_image_seq_len,
-                #     self.scheduler.config.base_shift,
-                #     self.scheduler.config.max_shift,
-                # )
-                # timesteps, num_inference_steps = retrieve_timesteps(
-                #     self.scheduler,
-                #     num_inference_steps,
-                #     device,
-                #     timesteps,
-                #     sigmas,
-                #     mu=mu,
-                # )
-                latent_model_input_packed = self.pipeline._pack_latents(
-                    latent_model_input,
-                    batch_size=latent_model_input.shape[0],
-                    num_channels_latents=latent_model_input.shape[1],  # 16
-                    height=height_latent,  # 128
-                    width=width_latent,  # 128
-                )
-
                 cast_dtype = self.unet.dtype
                 # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
                 noise_pred = self.unet(
                     hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype),  # [1, 4096, 64]
                     # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                     # todo make sure this doesnt change
-                    timestep=timestep / 1000, # timestep is 1000 scale
+                    timestep=timestep / 1000,  # timestep is 1000 scale
                     encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
                     pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
-                    txt_ids=text_ids,  # [1, 512, 3]
-                    img_ids=latent_image_ids,  # [1, 4096, 3]
+                    txt_ids=txt_ids,  # [1, 512, 3]
+                    img_ids=img_ids,  # [1, 4096, 3]
                     guidance=guidance,
                     return_dict=False,
                     **kwargs,
@@ -1676,12 +1644,14 @@ class StableDiffusion:
                 if isinstance(noise_pred, QTensor):
                     noise_pred = noise_pred.dequantize()
 
-                # unpack latents
-                noise_pred = self.pipeline._unpack_latents(
+                noise_pred = rearrange(
                     noise_pred,
-                    height=height,  # 1024
-                    width=width,  # 1024
-                    vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
+                    "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+                    h=latent_model_input.shape[2] // 2,
+                    w=latent_model_input.shape[3] // 2,
+                    ph=2,
+                    pw=2,
+                    c=latent_model_input.shape[1],
                 )
         elif self.is_v3:
             noise_pred = self.unet(
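
Note (not part of the patch): the change above drops diffusers' FluxPipeline helpers (_pack_latents, _unpack_latents, _prepare_latent_image_ids) in favor of doing the 2x2 latent patch packing and positional-ID construction inline with einops. A minimal standalone sketch of that scheme follows; the shapes and variable names here are illustrative assumptions (a 1024px image through a stride-8, 16-channel VAE), not code from this commit.

# Illustrative sketch only; assumed shapes, not repo code.
import torch
from einops import rearrange, repeat

bs, c, h, w = 1, 16, 128, 128                 # example latent shape
latents = torch.randn(bs, c, h, w)

# Pack: fold every 2x2 patch of the latent into a single token -> [1, 4096, 64]
packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

# Positional ids: channel 1 carries the patch row index, channel 2 the column index
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)   # [1, 4096, 3]

# Unpack: invert the packing to recover the [b, c, h, w] latent layout
unpacked = rearrange(packed, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=h // 2, w=w // 2, ph=2, pw=2, c=c)
assert torch.equal(unpacked, latents)         # the round trip is lossless

Because rearrange only permutes elements, unpacking the transformer's prediction with the inverse pattern restores the standard latent layout exactly, so the rest of the training loop can treat noise_pred like any other [b, c, h, w] tensor.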