Flux training should work now... maybe

This commit is contained in:
Jaret Burkett
2024-08-03 09:17:34 -06:00
parent 369aa143bc
commit 9beea1c268
4 changed files with 55 additions and 49 deletions

View File

@@ -166,6 +166,10 @@ class StableDiffusion:
self.config_file = None
self.is_rectified_flow = False
if self.is_flux or self.is_v3 or self.is_auraflow:
self.is_rectified_flow = True
def load_model(self):
if self.is_loaded:
return
@@ -448,7 +452,7 @@ class StableDiffusion:
elif self.model_config.is_flux:
print("Loading Flux model")
base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
base_model_path = "black-forest-labs/FLUX.1-schnell"
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
print("Loading vae")
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
@@ -1223,14 +1227,6 @@ class StableDiffusion:
noise: torch.FloatTensor,
timesteps: torch.IntTensor
) -> torch.FloatTensor:
# we handle adding noise for the various schedulers here. Some
# schedulers reference timesteps while others reference idx
# so we need to handle both cases
# get scheduler class name
scheduler_class_name = self.noise_scheduler.__class__.__name__
# todo handle if timestep is single value
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
@@ -1582,7 +1578,7 @@ class StableDiffusion:
# sigmas,
# mu=mu,
# )
latent_model_input = self.pipeline._pack_latents(
latent_model_input_packed = self.pipeline._pack_latents(
latent_model_input,
batch_size=latent_model_input.shape[0],
num_channels_latents=latent_model_input.shape[1], # 16
@@ -1592,7 +1588,7 @@ class StableDiffusion:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# todo make sure this doesnt change
timestep=timestep / 1000, # timestep is 1000 scale
@@ -1609,9 +1605,12 @@ class StableDiffusion:
noise_pred = self.pipeline._unpack_latents(
noise_pred,
height=height, # 1024
width=height, # 1024
width=width, # 1024
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
)
# todo we do this on sd3 training. I think we do it here too? No paper
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
elif self.is_v3:
noise_pred = self.unet(
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
@@ -2039,25 +2038,25 @@ class StableDiffusion:
if unet:
if self.is_flux:
# Just train the middle 2 blocks of each transformer block
block_list = []
num_transformer_blocks = 2
start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
for i in range(num_transformer_blocks):
block_list.append(self.unet.transformer_blocks[start_block + i])
# block_list = []
# num_transformer_blocks = 2
# start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
# for i in range(num_transformer_blocks):
# block_list.append(self.unet.transformer_blocks[start_block + i])
#
# num_single_transformer_blocks = 4
# start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
# for i in range(num_single_transformer_blocks):
# block_list.append(self.unet.single_transformer_blocks[start_block + i])
#
# for block in block_list:
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
num_single_transformer_blocks = 4
start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
for i in range(num_single_transformer_blocks):
block_list.append(self.unet.single_transformer_blocks[start_block + i])
for block in block_list:
for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
# for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
# for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
# named_params[name] = param
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param
else:
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
named_params[name] = param