Added a new experimental linear weighing technique

This commit is contained in:
Jaret Burkett
2024-09-02 09:22:13 -06:00
parent 7d9ab22405
commit d44d4eb61a
4 changed files with 19 additions and 5 deletions

View File

@@ -25,19 +25,29 @@ class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
# Scale to make mean 1
bsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# only do half bell
hbsmntw_weighing = y_shifted * (num_timesteps / y_shifted.sum())
# flatten second half to max
hbsmntw_weighing[num_timesteps // 2:] = hbsmntw_weighing[num_timesteps // 2:].max()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device='cpu')
self.linear_timesteps = timesteps
self.linear_timesteps_weights = bsmntw_weighing
self.linear_timesteps_weights2 = hbsmntw_weighing
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
def get_weights_for_timesteps(self, timesteps: torch.Tensor, v2=False) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()
if v2:
weights = self.linear_timesteps_weights2[step_indices].flatten()
else:
weights = self.linear_timesteps_weights[step_indices].flatten()
return weights