mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fix issue with wan22 14b where timesteps were generated not in the current boundary.
This commit is contained in:
@@ -1177,13 +1177,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
if self.sd.is_multistage:
|
||||
with self.timer('adjust_multistage_timesteps'):
|
||||
# get our current sample range
|
||||
boundaries = [1000] + self.sd.multistage_boundaries
|
||||
boundaries = [1] + self.sd.multistage_boundaries
|
||||
boundary_max, boundary_min = boundaries[self.current_boundary_index], boundaries[self.current_boundary_index + 1]
|
||||
lo = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_max, device=self.sd.noise_scheduler.timesteps.device), right=False)
|
||||
hi = torch.searchsorted(self.sd.noise_scheduler.timesteps, -torch.tensor(boundary_min, device=self.sd.noise_scheduler.timesteps.device), right=True)
|
||||
first_idx = lo.item() if hi > lo else 0
|
||||
asc_timesteps = torch.flip(self.sd.noise_scheduler.timesteps, dims=[0])
|
||||
lo = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_max * 1000, device=asc_timesteps.device), right=False)
|
||||
hi = len(asc_timesteps) - torch.searchsorted(asc_timesteps, torch.tensor(boundary_min * 1000, device=asc_timesteps.device), right=True)
|
||||
first_idx = (lo - 1).item() if hi > lo else 0
|
||||
last_idx = (hi - 1).item() if hi > lo else 999
|
||||
|
||||
min_noise_steps = first_idx
|
||||
max_noise_steps = last_idx
|
||||
|
||||
@@ -1246,7 +1246,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
max_idx = max_noise_steps - 1
|
||||
if self.train_config.noise_scheduler == 'flowmatch':
|
||||
# flowmatch uses indices, so we need to use indices
|
||||
min_idx = 0
|
||||
min_idx = min_noise_steps
|
||||
max_idx = max_noise_steps
|
||||
timestep_indices = torch.randint(
|
||||
min_idx,
|
||||
|
||||
Reference in New Issue
Block a user