Added ability to set timesteps to linear for flowmatching schedule

This commit is contained in:
Jaret Burkett
2024-08-11 13:06:08 -06:00
parent f8f0657b68
commit a6aa4b2c7d
3 changed files with 19 additions and 11 deletions

View File

@@ -44,17 +44,22 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device):
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = ((1 - t) * 1000)
# Scale and reverse the values to go from 1000 to 0
timesteps = ((1 - t) * 1000)
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
self.timesteps = timesteps.to(device=device)
return timesteps
return timesteps