mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Setup to retrain guidance embedding for flux. Use defualt timestep distribution for flux
This commit is contained in:
@@ -330,7 +330,12 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
elif self.sd.is_rectified_flow:
|
elif self.sd.is_rectified_flow:
|
||||||
# only if preconditioning model outputs
|
# only if preconditioning model outputs
|
||||||
# if not preconditioning, (target = noise - batch.latents) is used
|
# if not preconditioning, (target = noise - batch.latents)
|
||||||
|
|
||||||
|
|
||||||
|
# target = noise - batch.latents
|
||||||
|
|
||||||
|
# if preconditioning outputs, target latents
|
||||||
target = batch.latents.detach()
|
target = batch.latents.detach()
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|||||||
@@ -959,15 +959,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||||
|
|
||||||
# do flow matching
|
# do flow matching
|
||||||
if self.sd.is_rectified_flow:
|
# if self.sd.is_rectified_flow:
|
||||||
u = compute_density_for_timestep_sampling(
|
# u = compute_density_for_timestep_sampling(
|
||||||
weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
||||||
batch_size=batch_size,
|
# batch_size=batch_size,
|
||||||
logit_mean=0.0,
|
# logit_mean=0.0,
|
||||||
logit_std=1.0,
|
# logit_std=1.0,
|
||||||
mode_scale=1.29,
|
# mode_scale=1.29,
|
||||||
)
|
# )
|
||||||
timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
|
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
|
||||||
# convert the timestep_indices to a timestep
|
# convert the timestep_indices to a timestep
|
||||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||||
timesteps = torch.stack(timesteps, dim=0)
|
timesteps = torch.stack(timesteps, dim=0)
|
||||||
|
|||||||
@@ -464,7 +464,13 @@ class StableDiffusion:
|
|||||||
subfolder = None
|
subfolder = None
|
||||||
transformer_path = os.path.join(transformer_path, 'transformer')
|
transformer_path = os.path.join(transformer_path, 'transformer')
|
||||||
|
|
||||||
transformer = FluxTransformer2DModel.from_pretrained(transformer_path, subfolder=subfolder, torch_dtype=dtype)
|
transformer = FluxTransformer2DModel.from_pretrained(
|
||||||
|
transformer_path,
|
||||||
|
subfolder=subfolder,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
low_cpu_mem_usage=False,
|
||||||
|
device_map=None
|
||||||
|
)
|
||||||
transformer.to(self.device_torch, dtype=dtype)
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
flush()
|
flush()
|
||||||
|
|
||||||
@@ -1609,7 +1615,6 @@ class StableDiffusion:
|
|||||||
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
vae_scale_factor=VAE_SCALE_FACTOR * 2, # should be 16 not sure why
|
||||||
)
|
)
|
||||||
|
|
||||||
# todo we do this on sd3 training. I think we do it here too? No paper
|
|
||||||
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
|
||||||
elif self.is_v3:
|
elif self.is_v3:
|
||||||
noise_pred = self.unet(
|
noise_pred = self.unet(
|
||||||
@@ -2053,6 +2058,12 @@ class StableDiffusion:
|
|||||||
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
# for name, param in block.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
# named_params[name] = param
|
# named_params[name] = param
|
||||||
|
|
||||||
|
# train the guidance embedding
|
||||||
|
if self.unet.config.guidance_embeds:
|
||||||
|
transformer: FluxTransformer2DModel = self.unet
|
||||||
|
for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
named_params[name] = param
|
||||||
|
|
||||||
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
named_params[name] = param
|
named_params[name] = param
|
||||||
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
|
||||||
|
|||||||
Reference in New Issue
Block a user