mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added a new experimental linear weighing technique
This commit is contained in:
@@ -390,9 +390,12 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
|
||||
|
||||
# handle linear timesteps and only adjust the weight of the timesteps
|
||||
if self.sd.is_flow_matching and self.train_config.linear_timesteps:
|
||||
if self.sd.is_flow_matching and (self.train_config.linear_timesteps or self.train_config.linear_timesteps2):
|
||||
# calculate the weights for the timesteps
|
||||
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
|
||||
timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(
|
||||
timesteps,
|
||||
v2=self.train_config.linear_timesteps2
|
||||
).to(loss.device, dtype=loss.dtype)
|
||||
timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
|
||||
loss = loss * timestep_weight
|
||||
|
||||
|
||||
@@ -913,7 +913,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
self.sd.noise_scheduler.set_train_timesteps(
|
||||
num_train_timesteps,
|
||||
device=self.device_torch,
|
||||
linear=self.train_config.linear_timesteps
|
||||
linear=self.train_config.linear_timesteps or self.train_config.linear_timesteps2
|
||||
)
|
||||
else:
|
||||
self.sd.noise_scheduler.set_timesteps(
|
||||
|
||||
@@ -359,6 +359,7 @@ class TrainConfig:
|
||||
self.target_norm_std = kwargs.get('target_norm_std', None)
|
||||
self.target_norm_std_value = kwargs.get('target_norm_std_value', 1.0)
|
||||
self.linear_timesteps = kwargs.get('linear_timesteps', False)
|
||||
self.linear_timesteps2 = kwargs.get('linear_timesteps2', False)
|
||||
self.disable_sampling = kwargs.get('disable_sampling', False)
|
||||
|
||||
|
||||
|
||||
@@ -25,19 +25,29 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
# Scale to make mean 1
|
||||
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# only do half bell
|
||||
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# flatten second half to max
|
||||
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
|
||||
|
||||
self.linear_timesteps = timesteps
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
self.linear_timesteps_weights2 = hbsmntw_weighing
|
||||
pass
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
|
||||
# Get the indices of the timesteps
|
||||
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
# Get the weights for the timesteps
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
if v2:
|
||||
weights = self.linear_timesteps_weights2[step_indices].flatten()
|
||||
else:
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
Reference in New Issue
Block a user