Added ability to set timesteps to linear for flowmatching schedule

This commit is contained in:
Jaret Burkett
2024-08-11 13:06:08 -06:00
parent f8f0657b68
commit a6aa4b2c7d
3 changed files with 19 additions and 11 deletions

View File

@@ -908,7 +908,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
elif self.train_config.noise_scheduler == 'flowmatch':
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps, device=self.device_torch
num_train_timesteps,
device=self.device_torch,
linear=self.train_config.linear_timesteps
)
else:
self.sd.noise_scheduler.set_timesteps(

View File

@@ -356,6 +356,7 @@ class TrainConfig:
# adds an additional loss to the network to encourage it output a normalized standard deviation
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
class ModelConfig:

View File

@@ -44,17 +44,22 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device):
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = ((1 - t) * 1000)
# Scale and reverse the values to go from 1000 to 0
timesteps = ((1 - t) * 1000)
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device)
return timesteps
return timesteps