diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 07e8240c..280b4738 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -390,9 +390,12 @@ class SDTrainer(BaseSDTrainProcess): loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none") # handle linear timesteps and only adjust the weight of the timesteps - if self.sd.is_flow_matching and self.train_config.linear_timesteps: + if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2): # calculate the weights for the timesteps - timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype) + timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps( + timesteps, + v2=self.train_config.linear_timesteps2 + ).to(loss.device, dtype=loss.dtype) timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach() loss = loss * timestep_weight diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 455f28bb..930a556e 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -913,7 +913,7 @@ class BaseSDTrainProcess(BaseTrainProcess): self.sd.noise_scheduler.set_train_timesteps( num_train_timesteps, device=self.device_torch, - linear=self.train_config.linear_timesteps + linear=self.train_config.linear_timesteps or self.train_config.linear_timesteps2 ) else: self.sd.noise_scheduler.set_timesteps( diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index b794dd2f..2b844324 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -359,6 +359,7 @@ class TrainConfig: self.target_norm_std = kwargs.get('target_norm_std', None) self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0) self.linear_timesteps = kwargs.get('linear_timesteps', False) + self.linear_timesteps2 = kwargs.get('linear_timesteps2', False) self.disable_sampling = kwargs.get('disable_sampling', False) diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py index 1cb2eac6..440eb4fa 100644 --- a/toolkit/samplers/custom_flowmatch_sampler.py +++ b/toolkit/samplers/custom_flowmatch_sampler.py @@ -25,19 +25,29 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): # Scale to make mean 1 bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) + # only do half bell + hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum()) + + # flatten second half to max + hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max() + # Create linear timesteps from 1000 to 0 timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu') self.linear_timesteps = timesteps self.linear_timesteps_weights = bsmntw_weighing + self.linear_timesteps_weights2 = hbsmntw_weighing pass - def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor: # Get the indices of the timesteps step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] # Get the weights for the timesteps - weights = self.linear_timesteps_weights[step_indices].flatten() + if v2: + weights = self.linear_timesteps_weights2[step_indices].flatten() + else: + weights = self.linear_timesteps_weights[step_indices].flatten() return weights