diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index a38fcce0..0a1d7f45 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -55,4 +55,6 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): # Sort the timesteps in descending order timesteps, _ = torch.sort(timesteps, descending=True) + self.timesteps = timesteps.to(device=device) + return timesteps