mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 17:51:41 +00:00
Setup to retrain guidance embedding for flux. Use defualt timestep distribution for flux
This commit is contained in:
@@ -959,15 +959,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
raise ValueError(f"Unknown content_or_style {content_or_style}")
|
||||
|
||||
# do flow matching
|
||||
if self.sd.is_rectified_flow:
|
||||
u = compute_density_for_timestep_sampling(
|
||||
weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
||||
batch_size=batch_size,
|
||||
logit_mean=0.0,
|
||||
logit_std=1.0,
|
||||
mode_scale=1.29,
|
||||
)
|
||||
timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
|
||||
# if self.sd.is_rectified_flow:
|
||||
# u = compute_density_for_timestep_sampling(
|
||||
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
|
||||
# batch_size=batch_size,
|
||||
# logit_mean=0.0,
|
||||
# logit_std=1.0,
|
||||
# mode_scale=1.29,
|
||||
# )
|
||||
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
|
||||
# convert the timestep_indices to a timestep
|
||||
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
|
||||
timesteps = torch.stack(timesteps, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user