Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Reworked flux pred. Again
@@ -35,6 +35,7 @@ from toolkit.sampler import get_sampler
 from toolkit.saving import save_ldm_model_from_diffusers, get_ldm_state_dict_from_diffusers
 from toolkit.sd_device_states_presets import empty_preset
 from toolkit.train_tools import get_torch_dtype, apply_noise_offset
+from einops import rearrange, repeat
 import torch
 from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline, \
     StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline, FluxWithCFGPipeline
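The new einops import carries the whole rework: rearrange replaces the pipeline's _pack_latents/_unpack_latents helpers and repeat batches the position ids. A minimal standalone sketch of the two ops, assuming the [1, 16, 128, 128] latent shape that the comments in the hunks below use:

    import torch
    from einops import rearrange, repeat

    latents = torch.randn(1, 16, 128, 128)  # [b, c, h, w] FLUX VAE latents

    # Pack each 2x2 latent patch into one token:
    # [1, 16, 128, 128] -> [1, 4096, 64], matching the "# [1, 4096, 64]" comment below.
    packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    assert packed.shape == (1, 4096, 64)

    # Broadcast per-position ids across the batch: [64, 64, 3] -> [1, 4096, 3].
    ids = repeat(torch.zeros(64, 64, 3), "h w c -> b (h w) c", b=1)
    assert ids.shape == (1, 4096, 3)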
@@ -1600,28 +1601,22 @@ class StableDiffusion:
         self.unet = self.unet.to(dtype=self.torch_dtype)
         if self.is_flux:
             with torch.no_grad():
-                VAE_SCALE_FACTOR = 2 ** (len(self.vae.config['block_out_channels']) - 1)  # 16 . Maybe dont subtract
-                # this is what diffusers does
-                text_ids = torch.zeros(latent_model_input.shape[0], text_embeddings.text_embeds.shape[1], 3).to(
-                    device=self.device_torch, dtype=self.text_encoder[0].dtype
-                )
-                # todo check these
-                # height = latent_model_input.shape[2] * VAE_SCALE_FACTOR
-                # width = latent_model_input.shape[3] * VAE_SCALE_FACTOR
-                height = latent_model_input.shape[2] * VAE_SCALE_FACTOR  # 128
-                width = latent_model_input.shape[3] * VAE_SCALE_FACTOR  # 128
-
-                width_latent = latent_model_input.shape[3]
-                height_latent = latent_model_input.shape[2]
-
-                latent_image_ids = self.pipeline._prepare_latent_image_ids(
-                    batch_size=latent_model_input.shape[0],
-                    height=height_latent,
-                    width=width_latent,
-                    device=self.device_torch,
-                    dtype=self.torch_dtype,
-                )
+                bs, c, h, w = latent_model_input.shape
+                latent_model_input_packed = rearrange(
+                    latent_model_input,
+                    "b c (h ph) (w pw) -> b (h w) (c ph pw)",
+                    ph=2,
+                    pw=2
+                )
+
+                img_ids = torch.zeros(h // 2, w // 2, 3)
+                img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+                img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+                img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(self.device_torch)
+
+                txt_ids = torch.zeros(bs, text_embeddings.text_embeds.shape[1], 3).to(self.device_torch)

                 # # handle guidance
                 guidance_scale = 1.0  # ?
                 if self.unet.config.guidance_embeds:
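Each packed token covers one 2x2 latent patch, and img_ids records its (batch, row, column) coordinates for the transformer's rotary position embedding; txt_ids stays all zeros, so text tokens carry no spatial position. A quick check of the id layout under the same assumed 128x128 latent shape:

    import torch
    from einops import repeat

    bs, h, w = 1, 128, 128
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] += torch.arange(h // 2)[:, None]  # channel 1: patch row
    img_ids[..., 2] += torch.arange(w // 2)[None, :]  # channel 2: patch column
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # Token k maps to patch (k // 64, k % 64), e.g. token 130 sits at row 2, col 2.
    assert img_ids.shape == (1, 4096, 3)
    assert torch.equal(img_ids[0, 130], torch.tensor([0.0, 2.0, 2.0]))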
@@ -1630,44 +1625,17 @@ class StableDiffusion:
                 else:
                     guidance = None

-                # not sure how to handle this
-
-                # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-                # image_seq_len = latents.shape[1]
-                # mu = calculate_shift(
-                #     image_seq_len,
-                #     self.scheduler.config.base_image_seq_len,
-                #     self.scheduler.config.max_image_seq_len,
-                #     self.scheduler.config.base_shift,
-                #     self.scheduler.config.max_shift,
-                # )
-                # timesteps, num_inference_steps = retrieve_timesteps(
-                #     self.scheduler,
-                #     num_inference_steps,
-                #     device,
-                #     timesteps,
-                #     sigmas,
-                #     mu=mu,
-                # )
-                latent_model_input_packed = self.pipeline._pack_latents(
-                    latent_model_input,
-                    batch_size=latent_model_input.shape[0],
-                    num_channels_latents=latent_model_input.shape[1],  # 16
-                    height=height_latent,  # 128
-                    width=width_latent,  # 128
-                )
-
                 cast_dtype = self.unet.dtype
                 # with torch.amp.autocast(device_type='cuda', dtype=cast_dtype):
                 noise_pred = self.unet(
                     hidden_states=latent_model_input_packed.to(self.device_torch, cast_dtype),  # [1, 4096, 64]
-                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
-                    # todo make sure this doesnt change
-                    timestep=timestep / 1000,  # timestep is 1000 scale
+                    timestep=timestep / 1000,  # timestep is 1000 scale
                     encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
                     pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
-                    txt_ids=text_ids,  # [1, 512, 3]
-                    img_ids=latent_image_ids,  # [1, 4096, 3]
+                    txt_ids=txt_ids,  # [1, 512, 3]
+                    img_ids=img_ids,  # [1, 4096, 3]
                     guidance=guidance,
                     return_dict=False,
                     **kwargs,
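Two scalar inputs here are easy to get wrong. The trainer's timestep lives on the scheduler's 0-1000 scale while the FLUX transformer expects roughly [0, 1], hence the divide by 1000; and guidance is a real tensor only when the checkpoint was distilled with guidance embeddings (flux-dev), otherwise None (flux-schnell). A sketch with `bs`, `device`, and `guidance_embeds` as stand-ins for values the surrounding code provides:

    import torch

    bs, device = 1, "cpu"                            # stand-ins
    timestep = torch.tensor([500.0], device=device)  # scheduler step, 1000 scale
    timestep_in = timestep / 1000                    # 0.5, the scale the model expects

    guidance_scale = 1.0
    guidance_embeds = True  # stand-in for self.unet.config.guidance_embeds
    if guidance_embeds:     # flux-dev style checkpoint
        guidance = torch.full((bs,), guidance_scale, device=device)
    else:                   # flux-schnell has no guidance input
        guidance = None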
@@ -1676,12 +1644,14 @@ class StableDiffusion:
                 if isinstance(noise_pred, QTensor):
                     noise_pred = noise_pred.dequantize()

                 # unpack latents
-                noise_pred = self.pipeline._unpack_latents(
+                noise_pred = rearrange(
                     noise_pred,
-                    height=height,  # 1024
-                    width=width,  # 1024
-                    vae_scale_factor=VAE_SCALE_FACTOR * 2,  # should be 16 not sure why
+                    "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+                    h=latent_model_input.shape[2] // 2,
+                    w=latent_model_input.shape[3] // 2,
+                    ph=2,
+                    pw=2,
+                    c=latent_model_input.shape[1],
                 )
         elif self.is_v3:
             noise_pred = self.unet(
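The rearrange unpack is the exact inverse of the pack from the first flux hunk, so the height/width/VAE_SCALE_FACTOR bookkeeping the old _unpack_latents call needed disappears; the latent shape itself pins every axis. A round-trip sketch:

    import torch
    from einops import rearrange

    latents = torch.randn(2, 16, 128, 128)
    packed = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    unpacked = rearrange(
        packed,
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=128 // 2, w=128 // 2, ph=2, pw=2, c=16,
    )
    assert torch.equal(latents, unpacked)  # pack followed by unpack is the identity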