mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 19:21:39 +00:00
Added a new experimental linear weighing technique
This commit is contained in:
@@ -25,19 +25,29 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
|
||||
# Scale to make mean 1
|
||||
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# only do half bell
|
||||
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
|
||||
|
||||
# flatten second half to max
|
||||
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
|
||||
|
||||
# Create linear timesteps from 1000 to 0
|
||||
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
|
||||
|
||||
self.linear_timesteps = timesteps
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
self.linear_timesteps_weights2 = hbsmntw_weighing
|
||||
pass
|
||||
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
|
||||
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
|
||||
# Get the indices of the timesteps
|
||||
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
# Get the weights for the timesteps
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
if v2:
|
||||
weights = self.linear_timesteps_weights2[step_indices].flatten()
|
||||
else:
|
||||
weights = self.linear_timesteps_weights[step_indices].flatten()
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
Reference in New Issue
Block a user