mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Flux training should work now... maybe
This commit is contained in:
@@ -166,6 +166,10 @@ class StableDiffusion:
|
||||
|
||||
self.config_file = None
|
||||
|
||||
self.is_rectified_flow = False
|
||||
if self.is_flux or self.is_v3 or self.is_auraflow:
|
||||
self.is_rectified_flow = True
|
||||
|
||||
def load_model(self):
|
||||
if self.is_loaded:
|
||||
return
|
||||
@@ -448,7 +452,7 @@ class StableDiffusion:
|
||||
|
||||
elif self.model_config.is_flux:
|
||||
print("Loading Flux model")
|
||||
base_model_path = "/home/jaret/Dev/models/hf/FLUX.1-schnell/"
|
||||
base_model_path = "black-forest-labs/FLUX.1-schnell"
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
|
||||
print("Loading vae")
|
||||
vae = AutoencoderKL.from_pretrained(base_model_path, subfolder="vae", torch_dtype=dtype)
|
||||
@@ -1223,14 +1227,6 @@ class StableDiffusion:
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor
|
||||
) -> torch.FloatTensor:
|
||||
# we handle adding noise for the various schedulers here. Some
|
||||
# schedulers reference timesteps while others reference idx
|
||||
# so we need to handle both cases
|
||||
# get scheduler class name
|
||||
scheduler_class_name = self.noise_scheduler.__class__.__name__
|
||||
|
||||
# todo handle if timestep is single value
|
||||
|
||||
original_samples_chunks = torch.chunk(original_samples, original_samples.shape[0], dim=0)
|
||||
noise_chunks = torch.chunk(noise, noise.shape[0], dim=0)
|
||||
timesteps_chunks = torch.chunk(timesteps, timesteps.shape[0], dim=0)
|
||||
@@ -1582,7 +1578,7 @@ class StableDiffusion:
|
||||
# sigmas,
|
||||
# mu=mu,
|
||||
# )
|
||||
latent_model_input = self.pipeline._pack_latents(
|
||||
latent_model_input_packed = self.pipeline._pack_latents(
|
||||
latent_model_input,
|
||||
batch_size=latent_model_input.shape[0],
|
||||
num_channels_latents=latent_model_input.shape[1], # 16
|
||||
@@ -1592,7 +1588,7 @@ class StableDiffusion:
|
||||
|
||||
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
|
||||
hidden_states=latent_model_input_packed.to(self.device_torch, self.torch_dtype), # [1, 4096, 64]
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
# todo make sure this doesnt change
|
||||
timestep=timestep / 1000, # timestep is 1000 scale
|
||||
@@ -1609,9 +1605,12 @@ class StableDiffusion:
|
||||
noise_pred = self.pipeline._unpack_latents(
|
||||
noise_pred,
|
||||
height=height, # 1024
|
||||
width=height, # 1024
|
||||
width=width, # 1024
|
||||
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||
)
|
||||
|
||||
# todo we do this on sd3 training. I think we do it here too? No paper
|
||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
||||
elif self.is_v3:
|
||||
noise_pred = self.unet(
|
||||
hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
|
||||
@@ -2039,25 +2038,25 @@ class StableDiffusion:
|
||||
if unet:
|
||||
if self.is_flux:
|
||||
# Just train the middle 2 blocks of each transformer block
|
||||
block_list = []
|
||||
num_transformer_blocks = 2
|
||||
start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
|
||||
for i in range(num_transformer_blocks):
|
||||
block_list.append(self.unet.transformer_blocks[start_block + i])
|
||||
# block_list = []
|
||||
# num_transformer_blocks = 2
|
||||
# start_block = len(self.unet.transformer_blocks) // 2 - (num_transformer_blocks // 2)
|
||||
# for i in range(num_transformer_blocks):
|
||||
# block_list.append(self.unet.transformer_blocks[start_block + i])
|
||||
#
|
||||
# num_single_transformer_blocks = 4
|
||||
# start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
|
||||
# for i in range(num_single_transformer_blocks):
|
||||
# block_list.append(self.unet.single_transformer_blocks[start_block + i])
|
||||
#
|
||||
# for block in block_list:
|
||||
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
|
||||
num_single_transformer_blocks = 4
|
||||
start_block = len(self.unet.single_transformer_blocks) // 2 - (num_single_transformer_blocks // 2)
|
||||
for i in range(num_single_transformer_blocks):
|
||||
block_list.append(self.unet.single_transformer_blocks[start_block + i])
|
||||
|
||||
for block in block_list:
|
||||
for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
# for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
# for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
# named_params[name] = param
|
||||
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
else:
|
||||
for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||
named_params[name] = param
|
||||
|
||||
Reference in New Issue
Block a user