Added a new experimental linear timestep weighting technique

This commit is contained in:
Jaret Burkett
2024-09-02 09:22:13 -06:00
parent 7d9ab22405
commit d44d4eb61a
4 changed files with 19 additions and 5 deletions

View File

@@ -390,9 +390,12 @@ class SDTrainer(BaseSDTrainProcess):
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
# handle linear timesteps and only adjust the weight of the timesteps
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
# calculate the weights for the timesteps
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
timesteps,
v2=self.train_config.linear_timesteps2
).to(loss.device, dtype=loss.dtype)
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
loss = loss * timestep_weight

View File

@@ -913,7 +913,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.sd.noise_scheduler.set_train_timesteps(
num_train_timesteps,
device=self.device_torch,
linear=self.train_config.linear_timesteps
linear=self.train_config.linear_timesteps or self.train_config.linear_timesteps2
)
else:
self.sd.noise_scheduler.set_timesteps(

View File

@@ -359,6 +359,7 @@ class TrainConfig:
self.target_norm_std = kwargs.get('target_norm_std', None)
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
self.linear_timesteps = kwargs.get('linear_timesteps', False)
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
self.disable_sampling = kwargs.get('disable_sampling', False)

View File

@@ -25,19 +25,29 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
# Scale to make mean 1
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# only do half bell
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
self.linear_timesteps = timesteps
self.linear_timesteps_weights = bsmntw_weighing
self.linear_timesteps_weights2 = hbsmntw_weighing
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2: bool = False) -> torch.Tensor:
    """Return the precomputed loss weight for each requested timestep.

    Every value in ``timesteps`` is matched by exact equality against
    ``self.timesteps``; the position of the match is used to index the
    weight table.

    Args:
        timesteps: Tensor of timestep values to look up. It is flattened
            before the lookup, so the result has one weight per element.
        v2: When true, read from ``self.linear_timesteps_weights2``
            instead of ``self.linear_timesteps_weights``.

    Returns:
        A 1-D tensor of weights, in the same order as ``timesteps``.

    Raises:
        ValueError: If any requested timestep does not match exactly one
            entry of ``self.timesteps``.
    """
    # Pick the table once instead of duplicating the indexing per branch.
    weight_table = self.linear_timesteps_weights2 if v2 else self.linear_timesteps_weights

    schedule = self.timesteps.reshape(1, -1)
    requested = timesteps.reshape(-1).to(schedule.device)

    # One (num_requested, num_schedule) comparison replaces a Python loop
    # that called .item() (a host sync) once per requested timestep.
    matches = schedule == requested.reshape(-1, 1)
    match_counts = matches.sum(dim=1)
    invalid = match_counts != 1
    if bool(invalid.any()):
        raise ValueError(
            "timesteps must each match exactly one entry of the scheduler's "
            f"timesteps; offending values: {requested[invalid].tolist()}"
        )

    # nonzero() lists matches in row-major order, so with exactly one match
    # per row the column entries are the schedule indices in request order.
    step_indices = matches.nonzero()[:, 1]
    return weight_table[step_indices.to(weight_table.device)].flatten()