Actually use the correct timestep sampling instead of calculating it and moving on lol. Tested a few with it and it seems to work better.

This commit is contained in:
Jaret Burkett
2024-08-11 11:10:37 -06:00
parent ec1ea7aa0e
commit fbed8568fb

View File

@@ -55,4 +55,6 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
return timesteps