Set up retraining of the guidance embedding for Flux. Use the default timestep distribution for Flux.

This commit is contained in:
Jaret Burkett
2024-08-04 10:37:23 -06:00
parent 88acc28d7f
commit f321de7bdb
3 changed files with 28 additions and 12 deletions

View File

@@ -959,15 +959,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
raise ValueError(f"Unknown content_or_style {content_or_style}")
# do flow matching
if self.sd.is_rectified_flow:
u = compute_density_for_timestep_sampling(
weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
batch_size=batch_size,
logit_mean=0.0,
logit_std=1.0,
mode_scale=1.29,
)
timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# if self.sd.is_rectified_flow:
# u = compute_density_for_timestep_sampling(
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
# batch_size=batch_size,
# logit_mean=0.0,
# logit_std=1.0,
# mode_scale=1.29,
# )
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)