Reworked timestep distribution on flowmatch sampler when training.

This commit is contained in:
Jaret Burkett
2024-08-08 06:01:45 -06:00
parent acafe9984f
commit e69a520616
2 changed files with 27 additions and 4 deletions

View File

@@ -906,6 +906,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
)
elif self.train_config.noise_scheduler == 'flowmatch':
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps, device=self.device_torch
)
else:
self.sd.noise_scheduler.set_timesteps(
num_train_timesteps, device=self.device_torch